Commit e791b9a5 authored by Yuxin Wu

check config.data/config.model in trainers

parent 784e2b7b
......@@ -10,13 +10,14 @@ class MyModel(ModelDesc):
return [InputDesc(...), InputDesc(...)]
def _build_graph(self, inputs):
tensorA, tensorB = inputs
# build the graph
def _get_optimizer(self):
return tf.train.GradientDescentOptimizer(0.1)
```
-Basically, `_get_inputs` should define the metainfo of all the possible placeholders your graph may need.
+`_get_inputs` should define the metainfo of all the inputs your graph may need.
`_build_graph` should add tensors/operations to the graph, where
the argument `inputs` is the list of input tensors matching `_get_inputs`.
......
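For concreteness, a minimal runnable version of the `ModelDesc` sketch above, assuming an MNIST-like input; the layer choices and tensor names here are illustrative, not part of the commit:

```python
import tensorflow as tf
from tensorpack import ModelDesc, InputDesc

class MyModel(ModelDesc):
    def _get_inputs(self):
        # Metainfo (dtype, shape, name) of every input the graph needs.
        return [InputDesc(tf.float32, [None, 28, 28], 'image'),
                InputDesc(tf.int32, [None], 'label')]

    def _build_graph(self, inputs):
        # `inputs` are tensors matching _get_inputs, in the same order.
        image, label = inputs
        flat = tf.reshape(image, [-1, 28 * 28])
        logits = tf.layers.dense(flat, 10)
        self.cost = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=logits, labels=label),
            name='cost')

    def _get_optimizer(self):
        return tf.train.GradientDescentOptimizer(0.1)
```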
......@@ -95,8 +95,9 @@ class TrainConfig(object):
monitors = [TFEventWriter(), JSONWriter(), ScalarPrinter()]
self.monitors = monitors
if model is not None:
-            assert_type(model, ModelDesc)
self.model = model
+            assert_type(self.model, ModelDesc)
if session_init is None:
session_init = JustCurrentSession()
......
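Moving the check after the assignment validates the attribute actually stored on the config, and the `if model is not None` guard keeps `model` optional here, since the trainers below now assert `config.model` themselves. `assert_type` is not shown in this hunk; a sketch of what such a helper plausibly looks like (the real one lives in tensorpack's utils):

```python
def assert_type(v, tp):
    # Fail fast with a readable message when a config field
    # has the wrong type. A sketch, not the actual utility.
    assert isinstance(v, tp), \
        "Expected an instance of {}, got {}".format(
            tp.__name__, type(v).__name__)
```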
......@@ -52,7 +52,7 @@ class DistributedReplicatedTrainer(MultiGPUTrainerBase):
config (TrainConfig): the train config.
server (tf.train.Server): the server object with ps and workers
"""
+        assert config.data is not None and config.model is not None
self.server = server
server_def = server.server_def
self.cluster = tf.train.ClusterSpec(server_def.cluster)
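For context, a hedged sketch (TF1-style API) of creating the `tf.train.Server` that gets passed in; hostnames and ports are hypothetical:

```python
import tensorflow as tf

# A hypothetical one-ps, two-worker cluster.
cluster = tf.train.ClusterSpec({
    'ps': ['host0:2222'],
    'worker': ['host1:2222', 'host2:2222'],
})
server = tf.train.Server(cluster, job_name='worker', task_index=0)

# The trainer recovers the same cluster from the server object,
# exactly as the constructor above does:
recovered = tf.train.ClusterSpec(server.server_def.cluster)
```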
......@@ -83,7 +83,8 @@ class DistributedReplicatedTrainer(MultiGPUTrainerBase):
@staticmethod
def _average_grads(tower_grads, devices):
"""
-        Average grad with round-robin device selection.
+        Average grads from towers.
+        The device where the average happens is chosen with round-robin.
Args:
tower_grads: Ngpu x Nvar x 2
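A simplified sketch of what round-robin averaging can look like under the documented `Ngpu x Nvar x 2` layout; the real method carries extra bookkeeping omitted here:

```python
import tensorflow as tf

def average_grads(tower_grads, devices):
    # tower_grads: one list of (grad, var) pairs per tower,
    # all towers listing variables in the same order.
    nr_device = len(devices)
    averaged = []
    # zip(*tower_grads) regroups the pairs by variable.
    for i, grad_and_vars in enumerate(zip(*tower_grads)):
        var = grad_and_vars[0][1]
        # Round-robin: variable i is reduced on device i % nr_device.
        with tf.device(devices[i % nr_device]):
            grads = [g for g, _ in grad_and_vars]
            avg = tf.add_n(grads) / float(len(grads))
        averaged.append((avg, var))
    return averaged
```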
......@@ -111,6 +112,9 @@ class DistributedReplicatedTrainer(MultiGPUTrainerBase):
def _apply_shadow_vars(avg_grads):
"""
Replace variables in avg_grads by shadow variables.
+        Args:
+            avg_grads: list of (grad, var) tuples
"""
ps_var_grads = []
for grad, var in avg_grads:
......
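A hedged sketch of the substitution: each averaged gradient gets paired with a freshly created shadow copy of its variable, assumed to be placed on the ps job by an enclosing device setter; the real naming and placement logic is more involved:

```python
import tensorflow as tf

def apply_shadow_vars(avg_grads):
    # avg_grads: list of (grad, var) tuples from the towers.
    ps_var_grads = []
    for grad, var in avg_grads:
        # Create a same-shape, same-dtype shadow variable.
        shadow = tf.get_variable(
            'shadow/' + var.op.name,
            dtype=var.dtype.base_dtype,
            initializer=var.initial_value,
            trainable=True)
        ps_var_grads.append((grad, shadow))
    return ps_var_grads
```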
......@@ -58,6 +58,7 @@ def QueueInputTrainer(config, input_queue=None):
config (TrainConfig): a `TrainConfig` instance. config.dataflow must exist.
input_queue (tf.QueueBase): an input queue. Defaults to the :class:`QueueInput` default.
"""
+    assert (config.data is not None or config.dataflow is not None) and config.model is not None
if config.data is not None:
assert isinstance(config.data, QueueInput), config.data
else:
......
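With this assertion, either a raw `DataFlow` or an explicit `QueueInput` satisfies the trainer; a hypothetical usage sketch (`df` stands for any DataFlow):

```python
from tensorpack import TrainConfig, QueueInput, QueueInputTrainer

# Either hand over a raw dataflow ...
config = TrainConfig(model=MyModel(), dataflow=df)
# ... or an explicit QueueInput as config.data:
# config = TrainConfig(model=MyModel(), data=QueueInput(df))

QueueInputTrainer(config).train()
```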
......@@ -30,6 +30,7 @@ def _check_tf_version():
def apply_prefetch_policy(config, gpu_prefetch=True):
+    assert (config.data is not None or config.dataflow is not None) and config.model is not None
if config.data is None and config.dataflow is not None:
# always use Queue prefetch
config.data = QueueInput(config.dataflow)
......
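A hedged reconstruction of the rest of this branch, assuming only the obvious normalization (the GPU-prefetch handling behind `gpu_prefetch` is elided):

```python
def apply_prefetch_policy(config, gpu_prefetch=True):
    assert (config.data is not None or config.dataflow is not None) \
        and config.model is not None
    if config.data is None and config.dataflow is not None:
        # Always wrap a raw dataflow in Queue prefetch, so that
        # downstream code can rely on config.data being set.
        config.data = QueueInput(config.dataflow)
        # Assumption: the dataflow is now owned by the QueueInput.
        config.dataflow = None
    # ... GPU prefetch policy continues here ...
```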
......@@ -27,6 +27,7 @@ class SimpleTrainer(Trainer):
"Got nr_tower={}, but doesn't support multigpu!" \
" Use Sync/AsyncMultiGPUTrainer instead.".format(len(config.tower))
+        assert (config.data is not None or config.dataflow is not None) and config.model is not None
if config.dataflow is None:
self._input_source = config.data
else:
......