Commit 9e995a8d authored by Yuxin Wu

Use `TrainLoop` to manage the loop, and delegate properties. Hide `trainer.config`

parent 7efe4939
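The core of the change is small: the loop counters and settings move out of `Trainer`/`TrainConfig` into a `TrainLoop` object, and `Trainer` re-exposes them as read-only properties. Below is a minimal, self-contained sketch of that delegation pattern; it is illustrative only, not tensorpack code, though the names mirror the diff:

    class TrainLoop(object):
        """Owns the loop state (simplified sketch of the class added below)."""
        def __init__(self):
            self._epoch_num = 0

        def config(self, steps_per_epoch, starting_epoch, max_epoch):
            self.starting_epoch = starting_epoch
            self.max_epoch = max_epoch
            self.steps_per_epoch = steps_per_epoch
            self._epoch_num = starting_epoch - 1

        @property
        def epoch_num(self):
            return self._epoch_num

    class Trainer(object):
        def __init__(self, steps_per_epoch, starting_epoch, max_epoch):
            self.loop = TrainLoop()
            self.loop.config(steps_per_epoch, starting_epoch, max_epoch)

    def _delegate_attr(name):
        # Install a read-only property Trainer.<name> -> Trainer.loop.<name>
        setattr(Trainer, name, property(lambda self: getattr(self.loop, name)))

    for name in ['epoch_num', 'starting_epoch', 'max_epoch', 'steps_per_epoch']:
        _delegate_attr(name)

    t = Trainer(steps_per_epoch=100, starting_epoch=1, max_epoch=10)
    assert t.max_epoch == 10 and t.epoch_num == 0   # both served by t.loop

Callbacks that used to read `self.trainer.config.steps_per_epoch` now read `self.trainer.steps_per_epoch`, which is why most hunks below are one-line renames.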
@@ -223,9 +223,9 @@ class EvalCallback(Callback):
         self.df = PrefetchDataZMQ(get_eval_dataflow(), 1)
         EVAL_TIMES = 5  # eval 5 times during training
-        interval = self.trainer.config.max_epoch // (EVAL_TIMES + 1)
+        interval = self.trainer.max_epoch // (EVAL_TIMES + 1)
         self.epochs_to_eval = set([interval * k for k in range(1, EVAL_TIMES)])
-        self.epochs_to_eval.add(self.trainer.config.max_epoch)
+        self.epochs_to_eval.add(self.trainer.max_epoch)
         get_tf_nms()  # just to make sure the nms part of graph is created

     def _eval(self):
...
@@ -45,7 +45,7 @@ class Callback(object):
     _chief_only = True

     def setup_graph(self, trainer):
-        self._steps_per_epoch = trainer.config.steps_per_epoch
+        self._steps_per_epoch = trainer.steps_per_epoch
         self.trainer = trainer
         self.graph = tf.get_default_graph()
         scope_name = type(self).__name__
...
@@ -124,7 +124,7 @@ class InferenceRunner(InferenceRunnerBase):
     def _setup_graph(self):
         assert self.trainer.model is not None
         # Use predict_tower in train config. either gpuid or -1
-        tower_id = self.trainer.config.predict_tower[0]
+        tower_id = self.trainer._config.predict_tower[0]
         device = '/gpu:{}'.format(tower_id) if tower_id >= 0 else '/cpu:0'
         input_callbacks = self._input_source.setup(self.trainer.model.get_inputs_desc())
...
@@ -256,13 +256,13 @@ class JSONWriter(TrainingMonitor):
                 pass
             else:
                 logger.info("Found training history from JSON, now starting from epoch number {}.".format(epoch))
-                self.trainer.config.starting_epoch = epoch
+                self.trainer.loop.starting_epoch = epoch
         else:
             self._stats = []
         self._stat_now = {}
         self._last_gs = -1
-        self._total = self.trainer.config.steps_per_epoch
+        self._total = self.trainer.steps_per_epoch

     def _trigger_step(self):
         # will do this in trigger_epoch
@@ -327,7 +327,7 @@ class ScalarPrinter(TrainingMonitor):
     def _setup_graph(self):
         self._dic = {}
-        self._total = self.trainer.config.steps_per_epoch
+        self._total = self.trainer.steps_per_epoch

     def _trigger_step(self):
         if self._enable_step:
...
@@ -67,7 +67,7 @@ class ProgressBar(Callback):
     def _before_train(self):
         self._last_updated = self.local_step

-        self._total = self.trainer.config.steps_per_epoch
+        self._total = self.trainer.steps_per_epoch
         self._tqdm_args = get_tqdm_kwargs(leave=True)
         self._fetches = get_op_or_tensor_by_name(self._names) or None
@@ -133,4 +133,4 @@ class MaintainStepCounter(Callback):
     def _after_run(self, _, __):
         # Keep python-side global_step in agreement with TF-side
-        self.trainer._global_step += 1
+        self.trainer.loop._global_step += 1
@@ -70,7 +70,7 @@ class MergeAllSummaries_RunWithOp(Callback):
             self._fetches = tf.train.SessionRunArgs(self.summary_op)
         else:
             self._fetches = None
-        self._total = self.trainer.config.steps_per_epoch
+        self._total = self.trainer.steps_per_epoch

     def _need_run(self):
         if self.local_step == self._total - 1:
...
@@ -30,6 +30,63 @@ class StopTraining(BaseException):
     pass


+class TrainLoop(object):
+    """
+    Manage the double for loop.
+    """
+
+    def __init__(self):
+        self._epoch_num = 0
+        self._global_step = 0
+        self._local_step = -1
+
+    def config(self, steps_per_epoch, starting_epoch, max_epoch):
+        """
+        Configure the loop given the settings.
+        """
+        self.starting_epoch = starting_epoch
+        self.max_epoch = max_epoch
+        self.steps_per_epoch = steps_per_epoch
+        self._epoch_num = starting_epoch - 1
+
+    def update_global_step(self):
+        """
+        Update the Python-side global_step from TF.
+        This must be called under an initialized default session.
+        """
+        self._global_step = get_global_step_value()
+
+    @property
+    def epoch_num(self):
+        """
+        The number of the currently ongoing epoch.
+
+        An epoch is defined to cover the moment before calling `before_epoch`
+        until after calling `trigger_epoch`, i.e., in the `trigger_epoch` of
+        epoch 3, `self.epoch_num` is 3. If you need to use `self.epoch_num`
+        in your callback, you'll need to know this.
+        """
+        return self._epoch_num
+
+    @property
+    def global_step(self):
+        """
+        The TensorFlow global_step, i.e. how many times ``hooked_sess.run`` has been called.
+
+        Note:
+            1. global_step is incremented **after** each ``hooked_sess.run`` returns from the TF runtime.
+            2. If you make zero or more than one call to ``hooked_sess.run`` in one
+               :meth:`run_step`, local_step and global_step may increment at different speeds.
+        """
+        return self._global_step
+
+    @property
+    def local_step(self):
+        """
+        The number of (tensorpack) steps that have finished in the current epoch.
+        """
+        return self._local_step
+
+
 class Trainer(object):
     """ Base class for a trainer.
@@ -39,7 +96,6 @@ class Trainer(object):
         sess (tf.Session): the current session in use.
         hooked_sess (tf.train.MonitoredSession): the session with hooks.
         monitors (Monitors): the monitors. Other callbacks can use it for logging.
-        local_step (int): the number of (tensorpack) steps that have finished in the current epoch.
     """
     # step attr only available after before_train?
@@ -51,33 +107,16 @@ class Trainer(object):
             config (TrainConfig): the train config.
         """
         assert isinstance(config, TrainConfig), type(config)
-        self.config = config
+        self._config = config
         self.model = config.model

-        self.local_step = -1
         self._callbacks = []
         self.monitors = []
-        self._epoch_num = None
-        self._global_step = 0
+
+        self.loop = TrainLoop()
+        self.loop.config(config.steps_per_epoch, config.starting_epoch, config.max_epoch)

         self._setup()   # subclass will setup the graph and InputSource

-    @property
-    def epoch_num(self):
-        """
-        The number of the currently ongoing epoch.
-
-        An epoch is defined to cover the moment before calling `before_epoch` until after calling `trigger_epoch`.
-        i.e., in the `trigger_epoch` of epoch 3, `self.epoch_num` is 3.
-        If you need use `self.epoch_num` in your callback, you'll need to know this.
-        """
-        if self._epoch_num is not None:
-            # has started training
-            return self._epoch_num
-        else:
-            return self.config.starting_epoch - 1
-
     def register_callback(self, cb):
         """
         Register a callback to the trainer.
@@ -129,9 +168,9 @@ class Trainer(object):
         Setup the trainer and be ready for the main loop.
         """
         self.register_callback(MaintainStepCounter())
-        for cb in self.config.callbacks:
+        for cb in self._config.callbacks:
             self.register_callback(cb)
-        for m in self.config.monitors:
+        for m in self._config.monitors:
             self.register_monitor(m)
         self.monitors = Monitors(self.monitors)
         self.register_callback(self.monitors)
@@ -148,9 +187,9 @@ class Trainer(object):
         if self.is_chief:
             logger.info("Initializing the session ...")
-            self.config.session_init.init(self.sess)
+            self._config.session_init.init(self.sess)
         else:
-            assert isinstance(self.config.session_init, JustCurrentSession), \
+            assert isinstance(self._config.session_init, JustCurrentSession), \
                 "session_init is only valid for chief worker session!"

         self.sess.graph.finalize()
@@ -162,7 +201,7 @@ class Trainer(object):
         and self.hooked_sess (the session with hooks and coordinator)
         """
         hooks = self._callbacks.get_hooks()
-        self.sess = self.config.session_creator.create_session()
+        self.sess = self._config.session_creator.create_session()
         self.hooked_sess = tf.train.MonitoredSession(
             session_creator=ReuseSessionCreator(self.sess), hooks=hooks)
@@ -176,41 +215,29 @@ class Trainer(object):
         """
         pass

-    @property
-    def global_step(self):
-        """
-        The tensorflow global_step, i.e. how many times ``hooked_sess.run`` has been called.
-
-        Note:
-            1. global_step is incremented **after** each ``hooked_sess.run`` returns from TF runtime.
-            2. If you make zero or more than one calls to ``hooked_sess.run`` in one
-               :meth:`run_step`, local_step and global_step may increment at different speed.
-        """
-        return self._global_step
-
     def main_loop(self):
         """
         Run the main training loop.
         """
         with self.sess.as_default():
-            self._global_step = get_global_step_value()
+            self.loop.update_global_step()
             try:
                 self._callbacks.before_train()
                 # refresh global step (might have changed by callbacks) TODO ugly
-                self._global_step = get_global_step_value()
-                for self._epoch_num in range(
-                        self.config.starting_epoch, self.config.max_epoch + 1):
-                    logger.info("Start Epoch {} ...".format(self._epoch_num))
+                self.loop.update_global_step()
+                for self.loop._epoch_num in range(
+                        self.loop.starting_epoch, self.loop.max_epoch + 1):
+                    logger.info("Start Epoch {} ...".format(self.loop.epoch_num))
                     start_time = time.time()
                     self._callbacks.before_epoch()
-                    for self.local_step in range(self.config.steps_per_epoch):
+                    for self.loop._local_step in range(self.loop.steps_per_epoch):
                         if self.hooked_sess.should_stop():
                             return
                         self.run_step()  # implemented by subclass
                         self._callbacks.trigger_step()
                     self._callbacks.after_epoch()
                     logger.info("Epoch {} (global_step {}) finished, time:{:.2f} sec.".format(
-                        self._epoch_num, self.global_step, time.time() - start_time))
+                        self.loop.epoch_num, self.loop.global_step, time.time() - start_time))

                     # trigger epoch outside the timing region.
                     self._callbacks.trigger_epoch()
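One detail of `main_loop` worth calling out: `for self.loop._epoch_num in range(...)` works because a Python `for` statement can bind any assignment target, including an attribute, so the counters on `self.loop` stay current without explicit per-iteration assignments. A standalone illustration:

    class Box(object):
        pass

    b = Box()
    for b.step in range(3):   # same trick as "for self.loop._local_step in ..."
        pass
    print(b.step)             # 2: the last value remains bound on the object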
@@ -256,6 +283,19 @@ class Trainer(object):
         return ""


+def _delegate_attr(name):
+    """
+    Delegate a read-only property to ``self.loop``.
+    """
+    setattr(Trainer, name, property(
+        lambda self: getattr(self.loop, name)))
+
+
+for name in ['global_step', 'local_step', 'steps_per_epoch',
+             'epoch_num', 'starting_epoch', 'max_epoch']:
+    _delegate_attr(name)
+
+
 def launch_train(
         run_step, model=None, callbacks=None, extra_callbacks=None, monitors=None,
         session_creator=None, session_config=None, session_init=None,
...
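Two notes on the delegation above. First, the loop goes through the `_delegate_attr` helper rather than defining the lambda inline because the helper's parameter gives each closure its own `name`; a lambda written directly in the `for` would late-bind the loop variable, and every property would read the last name in the list. Second, `property(...)` with only a getter is read-only, which is why writers such as `JSONWriter` above assign through `trainer.loop` instead. A small demonstration of the late-binding pitfall (illustrative only):

    fns_bad = [lambda: name for name in ['a', 'b']]
    print([f() for f in fns_bad])     # ['b', 'b'] -- all closures see the last name

    def make(name):
        return lambda: name           # each closure gets its own binding

    fns_good = [make(n) for n in ['a', 'b']]
    print([f() for f in fns_good])    # ['a', 'b']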
@@ -88,7 +88,7 @@ class DistributedTrainerReplicated(Trainer):
         # whether something should be global or local. We now assume
         # they should be local.
         cbs = self._input_source.setup(self.model.get_inputs_desc())
-        self.config.callbacks.extend(cbs)
+        self._config.callbacks.extend(cbs)

         self.train_op, initial_sync_op, model_sync_op = self._builder.build(
             self._input_source, self.model.build_graph_get_cost, self.model.get_optimizer)
@@ -110,14 +110,14 @@ class DistributedTrainerReplicated(Trainer):
         self._set_session_creator()

     def _set_session_creator(self):
-        old_sess_creator = self.config.session_creator
+        old_sess_creator = self._config.session_creator
         if not isinstance(old_sess_creator, NewSessionCreator) \
-                or self.config.session_config is not None:
+                or self._config.session_config is not None:
             raise ValueError(
                 "Cannot set session_creator or session_config for distributed training! "
                 "To use a custom session config, pass it with tf.train.Server.")
-        self.config.session_creator = get_distributed_session_creator(self.server)
+        self._config.session_creator = get_distributed_session_creator(self.server)

     @property
     def vs_name_for_predictor(self):
...
@@ -71,10 +71,10 @@ class SyncMultiGPUTrainerParameterServer(Trainer):
         callbacks = self._input_source.setup(self.model.get_inputs_desc())
         self.train_op = SyncMultiGPUParameterServerBuilder(
-            self.config.tower, self._ps_device).build(
+            self._config.tower, self._ps_device).build(
             self._input_source, self.model.build_graph_get_cost, self.model.get_optimizer)
-        self.config.callbacks.extend(callbacks)
+        self._config.callbacks.extend(callbacks)


 def SyncMultiGPUTrainer(config):
@@ -102,13 +102,13 @@ class SyncMultiGPUTrainerReplicated(Trainer):
     def _setup(self):
         callbacks = self._input_source.setup(self.model.get_inputs_desc())

-        self.train_op, post_init_op = SyncMultiGPUReplicatedBuilder(self.config.tower).build(
+        self.train_op, post_init_op = SyncMultiGPUReplicatedBuilder(self._config.tower).build(
             self._input_source, self.model.build_graph_get_cost, self.model.get_optimizer)

         cb = RunOp(
             lambda: post_init_op,
             run_before=True, run_as_trigger=True, verbose=True)
-        self.config.callbacks.extend(callbacks + [cb])
+        self._config.callbacks.extend(callbacks + [cb])


 class AsyncMultiGPUTrainer(Trainer):
@@ -130,7 +130,7 @@ class AsyncMultiGPUTrainer(Trainer):
         callbacks = self._input_source.setup(self.model.get_inputs_desc())
         self.train_op = AsyncMultiGPUBuilder(
-            self.config.tower, self._scale_gradient).build(
+            self._config.tower, self._scale_gradient).build(
             self._input_source, self.model.build_graph_get_cost, self.model.get_optimizer)
-        self.config.callbacks.extend(callbacks)
+        self._config.callbacks.extend(callbacks)
@@ -44,7 +44,7 @@ class SimpleTrainer(Trainer):
         self.train_op = SimpleBuilder().build(
             self._input_source, self.model.build_graph_get_cost, self.model.get_optimizer)
-        self.config.callbacks.extend(cbs)
+        self._config.callbacks.extend(cbs)


 def QueueInputTrainer(config, input_queue=None):
...