Commit 95cb6ba2 authored by Yuxin Wu

add trainer.create_session, add verbose option in RunOp

parent ee1af311
......@@ -173,6 +173,10 @@ class Callback(object):
"""
return self._chief_only
@chief_only.setter
def chief_only(self, v):
self._chief_only = v
def __str__(self):
return type(self).__name__
......
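This hunk makes `chief_only` a settable property instead of a read-only one. A minimal usage sketch (the import path is an assumption; `MyCallback` is a placeholder subclass):

```python
from tensorpack.callbacks import Callback  # import path is an assumption

class MyCallback(Callback):
    def _trigger(self):
        pass

cb = MyCallback()
assert cb.chief_only       # callbacks default to running only on the chief worker
cb.chief_only = False      # public setter replaces poking the private cb._chief_only
```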
......@@ -17,13 +17,15 @@ class RunOp(Callback):
""" Run an Op. """
def __init__(self, setup_func,
run_before=True, run_as_trigger=True, run_step=False):
run_before=True, run_as_trigger=True,
run_step=False, verbose=False):
"""
Args:
setup_func: a function that returns the Op in the graph
run_before (bool): run the Op before training
run_as_trigger (bool): run the Op on every trigger
run_step (bool): run the Op every step (along with training)
verbose (bool): print logs when the op is run.
Examples:
The `DQN Example
......@@ -34,22 +36,30 @@ class RunOp(Callback):
self.run_before = run_before
self.run_as_trigger = run_as_trigger
self.run_step = run_step
self.verbose = verbose
def _setup_graph(self):
self._op = self.setup_func()
def _before_train(self):
if self.run_before:
self._print()
self._op.run()
def _trigger(self):
if self.run_as_trigger:
self._print()
self._op.run()
def _before_run(self, _):
if self.run_step:
self._print()
return [self._op]
def _print(self):
if self.verbose:
logger.info("Running Op {} ...".format(self._op.name))
class RunUpdateOps(RunOp):
"""
......
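A usage sketch of the widened `RunOp` signature, with a hypothetical moving-average update as the op (the import path and the `trainer` object are assumptions, not taken from this diff):

```python
import tensorflow as tf
from tensorpack.callbacks import RunOp  # import path is an assumption

ema = tf.train.ExponentialMovingAverage(decay=0.99)

def make_ema_op():
    # RunOp calls this in _setup_graph, after the model's variables exist
    return ema.apply(tf.trainable_variables())

cb = RunOp(make_ema_op,
           run_before=False,      # skip the pre-training run
           run_as_trigger=True,   # run once per trigger (typically per epoch)
           verbose=True)          # logs "Running Op <name> ..." on every run
trainer.register_callback(cb)     # `trainer` is a placeholder
```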
......@@ -9,8 +9,6 @@ import six
from six.moves import range
import tensorflow as tf
from tensorflow.python.training.monitored_session \
import _HookedSession as HookedSession
from .predict import PredictorFactory
from .config import TrainConfig
......@@ -118,6 +116,7 @@ class Trainer(object):
self.monitors = Monitors(self.monitors)
self.register_callback(self.monitors)
# TODO cache per graph, avoid describing all towers
describe_model()
# some final operations that might modify the graph
......@@ -125,21 +124,24 @@ class Trainer(object):
self._callbacks = Callbacks(self._callbacks)
self._callbacks.setup_graph(weakref.proxy(self))
# create session
logger.info("Creating the session ...")
self.sess = self.config.session_creator.create_session()
self._monitored_sess = tf.train.MonitoredSession(
session_creator=ReuseSessionCreator(self.sess), hooks=None)
self._create_session()
logger.info("Initializing the session ...")
# init session
self.config.session_init.init(self.sess)
self.sess.graph.finalize()
logger.info("Graph Finalized.")
def _create_session(self):
"""
Setup self.sess (the raw tf.Session)
and self.hooked_sess (the session with hooks and coordinator)
"""
hooks = self._callbacks.get_hooks()
self.hooked_sess = HookedSession(self.sess, hooks)
self.sess = self.config.session_creator.create_session()
self.hooked_sess = tf.train.MonitoredSession(
session_creator=ReuseSessionCreator(self.sess), hooks=hooks)
@abstractmethod
def _setup(self):
......@@ -167,7 +169,7 @@ class Trainer(object):
logger.info("Start Epoch {} ...".format(self.epoch_num))
start_time = time.time()
for self.local_step in range(self.config.steps_per_epoch):
if self._monitored_sess.should_stop():
if self.hooked_sess.should_stop():
return
self.run_step() # implemented by subclass
self._callbacks.trigger_step()
......@@ -186,7 +188,7 @@ class Trainer(object):
raise
finally:
self._callbacks.after_train()
self._monitored_sess.close()
self.hooked_sess.close()
# Predictor related methods: TODO
def get_predictor(self, input_names, output_names, tower=0):
......
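The net effect of this refactor: `self.sess` stays the raw `tf.Session` (used for weight initialization and graph finalization), while `self.hooked_sess` wraps that same session in a `tf.train.MonitoredSession` carrying the callbacks' hooks, replacing the old `_HookedSession` plus separate `_monitored_sess` pair. A condensed sketch of the resulting control flow (the loop body is illustrative, not the trainer's exact code):

```python
def setup_and_train(trainer):
    # condensed from Trainer setup / main loop after this change
    trainer._create_session()                       # sets trainer.sess and trainer.hooked_sess
    trainer.config.session_init.init(trainer.sess)  # initialize via the raw session
    trainer.sess.graph.finalize()                   # no further graph mutation allowed

    for _ in range(trainer.config.steps_per_epoch):
        if trainer.hooked_sess.should_stop():       # a hook may request an early stop
            return
        trainer.run_step()                          # subclass runs ops via hooked_sess
```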
......@@ -165,8 +165,10 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
main_fetch = tf.group(*var_update_ops, name='main_fetches')
self.train_op = self.add_sync_queues_and_barrier(
'post_copy_barrier', [main_fetch])
self.register_callback(RunOp(
self.get_post_init_ops, run_before=True, run_as_trigger=False))
cb = RunOp(self.get_post_init_ops,
run_before=True, run_as_trigger=False, verbose=True)
cb.chief_only = False
self.register_callback(cb)
self._set_session_creator()
......@@ -251,4 +253,4 @@ class DistributedReplicatedTrainer(SingleCostFeedfreeTrainer):
post_init_ops.append(copy_to.assign(v.read_value()))
else:
logger.warn("Global varable {} doesn't match a corresponding local var".format(v.name))
return tf.group(*post_init_ops, name='post_init_ops')
return tf.group(*post_init_ops, name='sync_variables_from_ps')
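The renamed `sync_variables_from_ps` op is just a `tf.group` of per-variable assigns. A stripped-down sketch of the pattern (the name matching here is simplified, not the trainer's exact logic; the logger import path is an assumption):

```python
import tensorflow as tf
from tensorpack.utils import logger  # import path is an assumption

def sync_variables_from_ps():
    # copy each global (PS-side) variable into the local copy with the same name
    local_by_name = {v.name: v for v in tf.local_variables()}
    post_init_ops = []
    for v in tf.global_variables():
        copy_to = local_by_name.get(v.name)
        if copy_to is not None:
            post_init_ops.append(copy_to.assign(v.read_value()))
        else:
            logger.warn("Global variable {} doesn't match a corresponding local var".format(v.name))
    return tf.group(*post_init_ops, name='sync_variables_from_ps')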
......@@ -242,7 +242,7 @@ class QueueInput(FeedfreeInput):
def setup_training(self, trainer):
super(QueueInput, self).setup_training(trainer)
cb = StartProcOrThread(self.thread)
cb._chief_only = False
cb.chief_only = False
trainer.register_callback(cb)
def get_input_tensors(self):
......
......@@ -70,7 +70,7 @@ class MultiGPUTrainerBase(Trainer):
keys_to_freeze = TOWER_FREEZE_KEYS[:]
if var_strategy == 'replicated': # TODO ugly
logger.info("UPDATE_OPS from all GPUs will be kept in the collection.")
logger.info("In replicated mode, UPDATE_OPS from all GPUs will be run.")
keys_to_freeze.remove(tf.GraphKeys.UPDATE_OPS)
for idx, t in enumerate(towers):
......@@ -261,7 +261,7 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase, SingleCostFeedfreeTrain
self.train_op = tf.group(*train_ops, name='train_op')
self.register_callback(RunOp(
SyncMultiGPUTrainerReplicated.get_post_init_ops,
run_before=True, run_as_trigger=True))
run_before=True, run_as_trigger=True, verbose=True))
# Adopt from https://github.com/tensorflow/benchmarks/blob/master/scripts/tf_cnn_benchmarks/variable_mgr.py
......@@ -279,7 +279,7 @@ class SyncMultiGPUTrainerReplicated(MultiGPUTrainerBase, SingleCostFeedfreeTrain
split_name = split_name[1:]
copy_from = var_by_name['/'.join(split_name)]
post_init_ops.append(v.assign(copy_from.read_value()))
return tf.group(*post_init_ops, name='init_sync_vars')
return tf.group(*post_init_ops, name='sync_variables_from_tower0')
class AsyncMultiGPUTrainer(MultiGPUTrainerBase,
......
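For reference, the renamed `sync_variables_from_tower0` op broadcasts tower-0 weights to the replicas by stripping the tower prefix from each variable name before looking up the source. A self-contained sketch, under the assumption that tower-0 variables carry no `towerN/` prefix:

```python
import tensorflow as tf

def sync_variables_from_tower0():
    all_vars = tf.global_variables()
    var_by_name = {v.name: v for v in all_vars}
    post_init_ops = []
    for v in all_vars:
        split_name = v.name.split('/')
        if not split_name[0].startswith('tower'):
            continue                                       # tower-0 copy, nothing to do
        copy_from = var_by_name['/'.join(split_name[1:])]  # drop the 'towerN' prefix
        post_init_ops.append(v.assign(copy_from.read_value()))
    return tf.group(*post_init_ops, name='sync_variables_from_tower0')
```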