Commit e791b9a5 authored by Yuxin Wu

check config.data/config.model in trainers

parent 784e2b7b
@@ -10,13 +10,14 @@ class MyModel(ModelDesc):
         return [InputDesc(...), InputDesc(...)]
 
     def _build_graph(self, inputs):
+        tensorA, tensorB = inputs
         # build the graph
 
     def _get_optimizer(self):
         return tf.train.GradientDescentOptimizer(0.1)
 ```
 
-Basically, `_get_inputs` should define the metainfo of all the possible placeholders your graph may need.
+`_get_inputs` should define the metainfo of all the inputs your graph may need.
 `_build_graph` should add tensors/operations to the graph, where
 the argument `inputs` is the list of input tensors matching `_get_inputs`.
......
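For readers skimming the diff: a fuller version of the `ModelDesc` pattern documented above could look like the sketch below. The `MNISTModel` class, its layer choices, and the exact `InputDesc(type, shape, name)` argument order are illustrative assumptions, not code from this commit.

```python
import tensorflow as tf
from tensorpack import ModelDesc, InputDesc

class MNISTModel(ModelDesc):
    def _get_inputs(self):
        # Metainfo of every input the graph may need; the trainer turns
        # these descriptions into placeholders or queue-fed tensors.
        return [InputDesc(tf.float32, (None, 28, 28), 'image'),
                InputDesc(tf.int32, (None,), 'label')]

    def _build_graph(self, inputs):
        # Tensors arrive in the same order as _get_inputs declares them.
        image, label = inputs
        logits = tf.layers.dense(tf.reshape(image, [-1, 28 * 28]), 10)
        self.cost = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=label, logits=logits), name='cost')

    def _get_optimizer(self):
        return tf.train.GradientDescentOptimizer(0.1)
```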
@@ -95,8 +95,9 @@ class TrainConfig(object):
                 monitors = [TFEventWriter(), JSONWriter(), ScalarPrinter()]
             self.monitors = monitors
 
+        if model is not None:
+            assert_type(model, ModelDesc)
         self.model = model
-        assert_type(self.model, ModelDesc)
 
         if session_init is None:
             session_init = JustCurrentSession()
......
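The `TrainConfig` hunk above makes `model` effectively optional: the type check runs only when a model is supplied, and (as the later hunks show) each trainer now asserts `config.model is not None` itself. A standalone sketch of that guard pattern, with stand-in names rather than tensorpack's actual classes:

```python
class BaseModel(object):
    """Stand-in for ModelDesc."""

def assert_type(obj, tp):
    assert isinstance(obj, tp), \
        "Expected {}, got {}".format(tp.__name__, type(obj).__name__)

class Config(object):
    def __init__(self, model=None):
        # Validate only when the argument was actually supplied; model may
        # legitimately be None here, and trainers that need it check later.
        if model is not None:
            assert_type(model, BaseModel)
        self.model = model
```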
@@ -52,7 +52,7 @@ class DistributedReplicatedTrainer(MultiGPUTrainerBase):
             config (TrainConfig): the train config.
             server (tf.train.Server): the server object with ps and workers
         """
+        assert config.data is not None and config.model is not None
         self.server = server
         server_def = server.server_def
         self.cluster = tf.train.ClusterSpec(server_def.cluster)
@@ -83,7 +83,8 @@ class DistributedReplicatedTrainer(MultiGPUTrainerBase):
     @staticmethod
     def _average_grads(tower_grads, devices):
         """
-        Average grad with round-robin device selection.
+        Average grads from towers.
+        The device where the average happens is chosen with round-robin.
 
         Args:
             tower_grads: Ngpu x Nvar x 2
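The rewritten docstring describes the round-robin placement; a minimal sketch of that technique (not tensorpack's actual implementation, and assuming `tower_grads` has the documented Ngpu x Nvar x 2 structure of (grad, var) pairs):

```python
import tensorflow as tf

def average_grads_round_robin(tower_grads, devices):
    # tower_grads: Ngpu x Nvar x 2, i.e. one list of (grad, var) per GPU.
    nr_var = len(tower_grads[0])
    avg_grads = []
    for i in range(nr_var):
        grads = [grad_and_vars[i][0] for grad_and_vars in tower_grads]
        var = tower_grads[0][i][1]
        # Round-robin: spread the per-variable reduction ops over devices
        # instead of placing every tf.add_n on the same one.
        with tf.device(devices[i % len(devices)]):
            avg = tf.multiply(tf.add_n(grads), 1.0 / len(grads))
        avg_grads.append((avg, var))
    return avg_grads
```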
@@ -111,6 +112,9 @@ class DistributedReplicatedTrainer(MultiGPUTrainerBase):
     def _apply_shadow_vars(avg_grads):
         """
         Replace variables in avg_grads by shadow variables.
+
+        Args:
+            avg_grads: list of (grad, var) tuples
         """
         ps_var_grads = []
         for grad, var in avg_grads:
......
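A rough sketch of what "replace variables by shadow variables" amounts to, assuming the shadow copies live on a parameter-server device; the `/shadow` naming and the `ps_device` argument are illustrative, not taken from this commit:

```python
import tensorflow as tf

def apply_shadow_vars(avg_grads, ps_device='/job:ps/task:0/cpu:0'):
    # Pair each averaged grad with a PS-side copy of its variable, so the
    # optimizer updates the shared copy instead of the worker-local one.
    ps_var_grads = []
    for grad, var in avg_grads:
        with tf.device(ps_device):
            shadow = tf.get_variable(
                var.op.name + '/shadow',          # hypothetical naming
                initializer=var.initial_value,    # same shape/dtype as var
                trainable=True)
        ps_var_grads.append((grad, shadow))
    return ps_var_grads
```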
@@ -58,6 +58,7 @@ def QueueInputTrainer(config, input_queue=None):
         config (TrainConfig): a `TrainConfig` instance. config.dataflow must exist.
         input_queue (tf.QueueBase): an input queue. Defaults to the :class:`QueueInput` default.
     """
+    assert (config.data is not None or config.dataflow is not None) and config.model is not None
     if config.data is not None:
         assert isinstance(config.data, QueueInput), config.data
     else:
......
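With the new assert in place, the input resolution visible in this hunk boils down to: accept an explicit `config.data` only if it is already a `QueueInput`, otherwise wrap `config.dataflow`. A hedged sketch, assuming a `QueueInput(dataflow, queue)` constructor:

```python
from tensorpack import QueueInput

def resolve_input(config, input_queue=None):
    # Precondition this commit adds: some input source and a model must exist.
    assert (config.data is not None or config.dataflow is not None) \
        and config.model is not None
    if config.data is not None:
        # An explicitly supplied InputSource must already be a QueueInput.
        assert isinstance(config.data, QueueInput), config.data
        return config.data
    return QueueInput(config.dataflow, input_queue)
```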
@@ -30,6 +30,7 @@ def _check_tf_version():
 def apply_prefetch_policy(config, gpu_prefetch=True):
+    assert (config.data is not None or config.dataflow is not None) and config.model is not None
     if config.data is None and config.dataflow is not None:
         # always use Queue prefetch
         config.data = QueueInput(config.dataflow)
......
@@ -27,6 +27,7 @@ class SimpleTrainer(Trainer):
             "Got nr_tower={}, but doesn't support multigpu!" \
             " Use Sync/AsyncMultiGPUTrainer instead.".format(len(config.tower))
+        assert (config.data is not None or config.dataflow is not None) and config.model is not None
         if config.dataflow is None:
             self._input_source = config.data
         else:
......