Commit e5f5da3c authored by Yuxin Wu

Let MovingAverageSummary bind the train_op instead of the session.

parent 93819550
@@ -55,6 +55,7 @@ before_script:
# Check that these private names can be imported because tensorpack is using them
- python -c "from tensorflow.python.client.session import _FetchHandler"
- python -c "from tensorflow.python.training.monitored_session import _HookedSession"
- python -c "import tensorflow as tf; tf.Operation._add_control_input"
script:
- flake8 .
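The new CI line only verifies that a private TensorFlow attribute is still importable, because the callback change below starts relying on it. A minimal equivalent of what the check asserts (an attribute lookup only, no graph required; `_add_control_input` is a private API and may change across TF versions):

    import tensorflow as tf

    # Mirror of the CI check: fail early if the private hook used by
    # MovingAverageSummary disappears from tf.Operation.
    assert hasattr(tf.Operation, '_add_control_input')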
@@ -19,29 +19,47 @@ class MovingAverageSummary(Callback):
This callback is enabled by default.
Maintain the moving average of summarized tensors in every step,
by ops added to the collection.
Note that it only __maintains__ the moving averages in the graph,
Note that it only __maintains__ the moving averages by updating
the relevant variables in the graph,
the actual summary should be done in other callbacks.
"""
def __init__(self, collection=MOVING_SUMMARY_OPS_KEY):
def __init__(self, collection=MOVING_SUMMARY_OPS_KEY, train_op=None):
"""
Args:
collection(str): the collection of EMA-maintaining ops.
The default value would work with
the tensors you added by :func:`tfutils.summary.add_moving_summary()`,
but you can use other collections as well.
train_op (tf.Operation or str): the (name of) training op to associate the maintaining ops with.
If not provided, the EMA-maintaining ops will be hooked to
`trainer.hooked_session` and be executed in every iteration.
Otherwise, the EMA-maintaining ops will be executed whenever
the training op is executed.
"""
self._collection = collection
self._train_op = train_op
def _setup_graph(self):
ops = tf.get_collection(self._collection)
logger.info("Maintain moving average summary of {} tensors in collection {}.".format(
len(ops), self._collection))
ops = [k.op for k in tf.get_collection(self._collection)]
if self._train_op is None:
logger.info("[MovingAverageSummary] {} operations in collection '{}' "
"will be run with session hooks.".format(len(ops), self._collection))
self.ema_op = tf.group(*ops, name='maintain_moving_average_summary')
self._fetch = tf.train.SessionRunArgs(fetches=self.ema_op)
self.ema_op = tf.group(*ops, name='maintain_moving_average_summary')
self._fetch = tf.train.SessionRunArgs(fetches=self.ema_op)
else:
if isinstance(self._train_op, tf.Tensor):
self._train_op = self._train_op.op
if not isinstance(self._train_op, tf.Operation):
self._train_op = self.graph.get_operation_by_name(self._train_op)
self._train_op._add_control_inputs(ops)
logger.info("[MovingAverageSummary] {} operations in collection '{}'"
" will be run together with operation '{}'.".format(
len(ops), self._collection, self._train_op.name))
def _before_run(self, _):
return self._fetch
if self._train_op is None:
return self._fetch
class MergeAllSummaries_RunAlone(Callback):
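As a rough usage sketch (not part of the commit), the two modes of the callback could be set up as follows, assuming the training op is created with name 'train_op' as in the trainer changes further down:

    from tensorpack.callbacks import MovingAverageSummary

    # Default mode: the EMA-maintaining ops are fetched through session-run
    # hooks on every iteration of the hooked session.
    cb_hooked = MovingAverageSummary()

    # New mode added by this commit: the EMA-maintaining ops become control
    # dependencies of the given training op (passed as an op or by name),
    # so they run whenever that op runs, with or without a hooked session.
    cb_bound = MovingAverageSummary(train_op='train_op')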
@@ -77,7 +77,7 @@ class DataParallelBuilder(GraphBuilder):
raise ValueError("Number of gradients from each tower is different! " + str(nvars))
@staticmethod
def build_on_towers(
def call_for_each_tower(
towers, func, devices=None, use_vs=None):
"""
Run `func` on all GPUs (towers) and return the results.
@@ -119,6 +119,10 @@ class DataParallelBuilder(GraphBuilder):
ret.append(func())
return ret
@staticmethod
def build_on_towers(*args, **kwargs):
return DataParallelBuilder.call_for_each_tower(*args, **kwargs)
class SyncMultiGPUParameterServerBuilder(DataParallelBuilder):
"""
@@ -405,4 +409,4 @@ class AsyncMultiGPUBuilder(DataParallelBuilder):
# will call apply_gradients (therefore gradproc) multiple times
train_ops.append(opt.apply_gradients(
grad_and_vars, name='apply_grad_{}'.format(i)))
return tf.group(*train_ops, name='train_op')
return tf.group(*train_ops, name='train_op')
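The rename above keeps `build_on_towers` as a thin forwarding wrapper, so existing call sites keep working. A hedged sketch of the equivalence (the import path and the no-op tower function are illustrative, not taken from the commit):

    from tensorpack.graph_builder.training import DataParallelBuilder

    def tower_fn():
        # Per-tower graph construction would go here.
        return None

    # New name and old name now do exactly the same thing:
    new_style = DataParallelBuilder.call_for_each_tower([0, 1], tower_fn)
    old_style = DataParallelBuilder.build_on_towers([0, 1], tower_fn)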
@@ -53,7 +53,7 @@ class SimpleTrainer(SingleCostTrainer):
with TrainTowerContext(''):
grads = self._make_get_grad_fn(input, get_cost_fn, get_opt_fn)()
opt = get_opt_fn()
self.train_op = opt.apply_gradients(grads, name='min_op')
self.train_op = opt.apply_gradients(grads, name='train_op')
return []
@@ -404,7 +404,7 @@ class HorovodTrainer(SingleCostTrainer):
grads = self.allreduce(grads)
opt = get_opt_fn()
self.train_op = opt.apply_gradients(grads, name='min_op')
self.train_op = opt.apply_gradients(grads, name='train_op')
def broadcast(self):
logger.info("Running horovod broadcast ...")
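With the apply-gradients op now named 'train_op' instead of 'min_op' in both trainers, other graph code can look it up by name. A minimal sketch, assuming apply_gradients() was called at the top level of the graph (otherwise the name carries a scope prefix):

    import tensorflow as tf

    graph = tf.get_default_graph()
    # The op created by opt.apply_gradients(..., name='train_op') above.
    train_op = graph.get_operation_by_name('train_op')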