Commit 9e995a8d authored by Yuxin Wu

Use `TrainLoop` to manage the loop, and delegate properties. Hide `trainer.config`

parent 7efe4939
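
The change, in short: the loop counters move out of `Trainer` into a new `TrainLoop` object, `trainer.config` becomes the private `trainer._config`, and callbacks read the loop settings directly off the trainer through delegated read-only properties. A minimal sketch of the access pattern this enables (`MyCallback` is a hypothetical callback, not part of this commit):

# Hypothetical callback showing the new delegated properties.
class MyCallback(Callback):
    def _setup_graph(self):
        # Before: self.trainer.config.steps_per_epoch
        # After: a read-only property on the trainer, forwarded to
        # self.trainer.loop by the _delegate_attr helper (see the diff below).
        self._total = self.trainer.steps_per_epoch
        self._last_epoch = self.trainer.max_epoch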
@@ -223,9 +223,9 @@ class EvalCallback(Callback):
         self.df = PrefetchDataZMQ(get_eval_dataflow(), 1)
         EVAL_TIMES = 5  # eval 5 times during training
-        interval = self.trainer.config.max_epoch // (EVAL_TIMES + 1)
+        interval = self.trainer.max_epoch // (EVAL_TIMES + 1)
         self.epochs_to_eval = set([interval * k for k in range(1, EVAL_TIMES)])
-        self.epochs_to_eval.add(self.trainer.config.max_epoch)
+        self.epochs_to_eval.add(self.trainer.max_epoch)
         get_tf_nms()   # just to make sure the nms part of graph is created

     def _eval(self):
...
@@ -45,7 +45,7 @@ class Callback(object):
     _chief_only = True

     def setup_graph(self, trainer):
-        self._steps_per_epoch = trainer.config.steps_per_epoch
+        self._steps_per_epoch = trainer.steps_per_epoch
         self.trainer = trainer
         self.graph = tf.get_default_graph()
         scope_name = type(self).__name__
...
@@ -124,7 +124,7 @@ class InferenceRunner(InferenceRunnerBase):
     def _setup_graph(self):
         assert self.trainer.model is not None
         # Use predict_tower in train config. either gpuid or -1
-        tower_id = self.trainer.config.predict_tower[0]
+        tower_id = self.trainer._config.predict_tower[0]
         device = '/gpu:{}'.format(tower_id) if tower_id >= 0 else '/cpu:0'
         input_callbacks = self._input_source.setup(self.trainer.model.get_inputs_desc())
...
@@ -256,13 +256,13 @@ class JSONWriter(TrainingMonitor):
                 pass
             else:
                 logger.info("Found training history from JSON, now starting from epoch number {}.".format(epoch))
-                self.trainer.config.starting_epoch = epoch
+                self.trainer.starting_epoch = epoch
         else:
             self._stats = []
             self._stat_now = {}

         self._last_gs = -1
-        self._total = self.trainer.config.steps_per_epoch
+        self._total = self.trainer.steps_per_epoch

     def _trigger_step(self):
         # will do this in trigger_epoch
@@ -327,7 +327,7 @@ class ScalarPrinter(TrainingMonitor):
     def _setup_graph(self):
         self._dic = {}
-        self._total = self.trainer.config.steps_per_epoch
+        self._total = self.trainer.steps_per_epoch

     def _trigger_step(self):
         if self._enable_step:
...
@@ -67,7 +67,7 @@ class ProgressBar(Callback):
     def _before_train(self):
         self._last_updated = self.local_step
-        self._total = self.trainer.config.steps_per_epoch
+        self._total = self.trainer.steps_per_epoch
         self._tqdm_args = get_tqdm_kwargs(leave=True)
         self._fetches = get_op_or_tensor_by_name(self._names) or None
@@ -133,4 +133,4 @@ class MaintainStepCounter(Callback):
     def _after_run(self, _, __):
         # Keep python-side global_step in agreement with TF-side
-        self.trainer._global_step += 1
+        self.trainer.loop._global_step += 1
@@ -70,7 +70,7 @@ class MergeAllSummaries_RunWithOp(Callback):
             self._fetches = tf.train.SessionRunArgs(self.summary_op)
         else:
             self._fetches = None
-        self._total = self.trainer.config.steps_per_epoch
+        self._total = self.trainer.steps_per_epoch

     def _need_run(self):
         if self.local_step == self._total - 1:
...
@@ -30,6 +30,63 @@ class StopTraining(BaseException):
     pass


+class TrainLoop(object):
+    """
+    Manage the double for loop.
+    """
+
+    def __init__(self):
+        self._epoch_num = 0
+        self._global_step = 0
+        self._local_step = -1
+
+    def config(self, steps_per_epoch, starting_epoch, max_epoch):
+        """
+        Configure the loop given the settings.
+        """
+        self.starting_epoch = starting_epoch
+        self.max_epoch = max_epoch
+        self.steps_per_epoch = steps_per_epoch
+        self._epoch_num = starting_epoch - 1
+
+    def update_global_step(self):
+        """
+        Update the Python-side global_step from TF.
+        This must be called under an initialized default session.
+        """
+        self._global_step = get_global_step_value()
+
+    @property
+    def epoch_num(self):
+        """
+        The number of the currently ongoing epoch.
+
+        An epoch is defined to cover the moment before calling `before_epoch`
+        until after calling `trigger_epoch`, i.e., in the `trigger_epoch` of
+        epoch 3, `self.epoch_num` is 3. If you need to use `self.epoch_num`
+        in your callback, you'll need to know this.
+        """
+        return self._epoch_num
+
+    @property
+    def global_step(self):
+        """
+        The tensorflow global_step, i.e. how many times ``hooked_sess.run``
+        has been called.
+
+        Note:
+            1. global_step is incremented **after** each ``hooked_sess.run``
+               returns from the TF runtime.
+            2. If you make zero or more than one call to ``hooked_sess.run``
+               in one :meth:`run_step`, local_step and global_step may
+               increment at different speeds.
+        """
+        return self._global_step
+
+    @property
+    def local_step(self):
+        """
+        The number of (tensorpack) steps that have finished in the current epoch.
+        """
+        return self._local_step
+
+
 class Trainer(object):
     """ Base class for a trainer.
@@ -39,7 +96,6 @@ class Trainer(object):
         sess (tf.Session): the current session in use.
         hooked_sess (tf.train.MonitoredSession): the session with hooks.
         monitors (Monitors): the monitors. Other callbacks can use it for logging.
-        local_step (int): the number of (tensorpack) steps that have finished in the current epoch.
     """
     # step attr only available after before_train?
@@ -51,33 +107,16 @@ class Trainer(object):
             config (TrainConfig): the train config.
         """
         assert isinstance(config, TrainConfig), type(config)
-        self.config = config
+        self._config = config
         self.model = config.model
-        self.local_step = -1

         self._callbacks = []
         self.monitors = []
-        self._epoch_num = None
-        self._global_step = 0
+        self.loop = TrainLoop()
+        self.loop.config(config.steps_per_epoch, config.starting_epoch, config.max_epoch)

         self._setup()   # subclass will setup the graph and InputSource

-    @property
-    def epoch_num(self):
-        """
-        The number of the currently ongoing epoch.
-
-        An epoch is defined to cover the moment before calling `before_epoch`
-        until after calling `trigger_epoch`, i.e., in the `trigger_epoch` of
-        epoch 3, `self.epoch_num` is 3. If you need to use `self.epoch_num`
-        in your callback, you'll need to know this.
-        """
-        if self._epoch_num is not None:
-            # has started training
-            return self._epoch_num
-        else:
-            return self.config.starting_epoch - 1
-
     def register_callback(self, cb):
         """
         Register a callback to the trainer.
@@ -129,9 +168,9 @@ class Trainer(object):
         Setup the trainer and be ready for the main loop.
         """
         self.register_callback(MaintainStepCounter())
-        for cb in self.config.callbacks:
+        for cb in self._config.callbacks:
             self.register_callback(cb)
-        for m in self.config.monitors:
+        for m in self._config.monitors:
             self.register_monitor(m)
         self.monitors = Monitors(self.monitors)
         self.register_callback(self.monitors)
@@ -148,9 +187,9 @@ class Trainer(object):
         if self.is_chief:
             logger.info("Initializing the session ...")
-            self.config.session_init.init(self.sess)
+            self._config.session_init.init(self.sess)
         else:
-            assert isinstance(self.config.session_init, JustCurrentSession), \
+            assert isinstance(self._config.session_init, JustCurrentSession), \
                 "session_init is only valid for chief worker session!"

         self.sess.graph.finalize()
@@ -162,7 +201,7 @@ class Trainer(object):
         and self.hooked_sess (the session with hooks and coordinator)
         """
         hooks = self._callbacks.get_hooks()
-        self.sess = self.config.session_creator.create_session()
+        self.sess = self._config.session_creator.create_session()
         self.hooked_sess = tf.train.MonitoredSession(
             session_creator=ReuseSessionCreator(self.sess), hooks=hooks)
@@ -176,41 +215,29 @@ class Trainer(object):
         """
         pass

-    @property
-    def global_step(self):
-        """
-        The tensorflow global_step, i.e. how many times ``hooked_sess.run``
-        has been called.
-
-        Note:
-            1. global_step is incremented **after** each ``hooked_sess.run``
-               returns from the TF runtime.
-            2. If you make zero or more than one call to ``hooked_sess.run``
-               in one :meth:`run_step`, local_step and global_step may
-               increment at different speeds.
-        """
-        return self._global_step
-
     def main_loop(self):
         """
         Run the main training loop.
         """
         with self.sess.as_default():
-            self._global_step = get_global_step_value()
+            self.loop.update_global_step()
             try:
                 self._callbacks.before_train()
                 # refresh global step (might have changed by callbacks) TODO ugly
-                self._global_step = get_global_step_value()
-                for self._epoch_num in range(
-                        self.config.starting_epoch, self.config.max_epoch + 1):
-                    logger.info("Start Epoch {} ...".format(self._epoch_num))
+                self.loop.update_global_step()
+                for self.loop._epoch_num in range(
+                        self.loop.starting_epoch, self.loop.max_epoch + 1):
+                    logger.info("Start Epoch {} ...".format(self.loop.epoch_num))
                     start_time = time.time()
                     self._callbacks.before_epoch()
-                    for self.local_step in range(self.config.steps_per_epoch):
+                    for self.loop._local_step in range(self.loop.steps_per_epoch):
                         if self.hooked_sess.should_stop():
                             return
                         self.run_step()  # implemented by subclass
                         self._callbacks.trigger_step()
                     self._callbacks.after_epoch()
                     logger.info("Epoch {} (global_step {}) finished, time:{:.2f} sec.".format(
-                        self._epoch_num, self.global_step, time.time() - start_time))
+                        self.loop.epoch_num, self.loop.global_step, time.time() - start_time))
                     # trigger epoch outside the timing region.
                     self._callbacks.trigger_epoch()
@@ -256,6 +283,19 @@ class Trainer(object):
         return ""


+def _delegate_attr(name):
+    """
+    Delegate a property of the same name to self.loop.
+    """
+    setattr(Trainer, name, property(
+        lambda self: getattr(self.loop, name)))
+
+
+for name in ['global_step', 'local_step', 'steps_per_epoch',
+             'epoch_num', 'starting_epoch', 'max_epoch']:
+    _delegate_attr(name)
+
+
 def launch_train(
         run_step, model=None, callbacks=None, extra_callbacks=None, monitors=None,
         session_creator=None, session_config=None, session_init=None,
...
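
For concreteness, a small sketch (illustration only, not part of the commit) of the counter semantics documented above, right after `config()` and before any epoch starts:

# TrainLoop counters before training begins.
loop = TrainLoop()
loop.config(steps_per_epoch=100, starting_epoch=1, max_epoch=10)
assert loop.epoch_num == 0      # starting_epoch - 1; no epoch has begun yet
assert loop.local_step == -1    # no step has finished in the current epoch
assert loop.global_step == 0    # refreshed from TF only by update_global_step()

`main_loop` then advances `_epoch_num` and `_local_step` directly through its two `for` loops, so `epoch_num` stays at 3 for all of epoch 3, including inside `trigger_epoch`.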
@@ -88,7 +88,7 @@ class DistributedTrainerReplicated(Trainer):
         # whether something should be global or local. We now assume
         # they should be local.
         cbs = self._input_source.setup(self.model.get_inputs_desc())
-        self.config.callbacks.extend(cbs)
+        self._config.callbacks.extend(cbs)

         self.train_op, initial_sync_op, model_sync_op = self._builder.build(
             self._input_source, self.model.build_graph_get_cost, self.model.get_optimizer)
@@ -110,14 +110,14 @@ class DistributedTrainerReplicated(Trainer):
         self._set_session_creator()

     def _set_session_creator(self):
-        old_sess_creator = self.config.session_creator
+        old_sess_creator = self._config.session_creator
         if not isinstance(old_sess_creator, NewSessionCreator) \
-                or self.config.session_config is not None:
+                or self._config.session_config is not None:
             raise ValueError(
                 "Cannot set session_creator or session_config for distributed training! "
                 "To use a custom session config, pass it with tf.train.Server.")
-        self.config.session_creator = get_distributed_session_creator(self.server)
+        self._config.session_creator = get_distributed_session_creator(self.server)

     @property
     def vs_name_for_predictor(self):
...
@@ -71,10 +71,10 @@ class SyncMultiGPUTrainerParameterServer(Trainer):
         callbacks = self._input_source.setup(self.model.get_inputs_desc())

         self.train_op = SyncMultiGPUParameterServerBuilder(
-            self.config.tower, self._ps_device).build(
+            self._config.tower, self._ps_device).build(
             self._input_source, self.model.build_graph_get_cost, self.model.get_optimizer)
-        self.config.callbacks.extend(callbacks)
+        self._config.callbacks.extend(callbacks)


 def SyncMultiGPUTrainer(config):
@@ -102,13 +102,13 @@ class SyncMultiGPUTrainerReplicated(Trainer):
     def _setup(self):
         callbacks = self._input_source.setup(self.model.get_inputs_desc())

-        self.train_op, post_init_op = SyncMultiGPUReplicatedBuilder(self.config.tower).build(
+        self.train_op, post_init_op = SyncMultiGPUReplicatedBuilder(self._config.tower).build(
             self._input_source, self.model.build_graph_get_cost, self.model.get_optimizer)

         cb = RunOp(
             lambda: post_init_op,
             run_before=True, run_as_trigger=True, verbose=True)
-        self.config.callbacks.extend(callbacks + [cb])
+        self._config.callbacks.extend(callbacks + [cb])


 class AsyncMultiGPUTrainer(Trainer):
@@ -130,7 +130,7 @@ class AsyncMultiGPUTrainer(Trainer):
         callbacks = self._input_source.setup(self.model.get_inputs_desc())
         self.train_op = AsyncMultiGPUBuilder(
-            self.config.tower, self._scale_gradient).build(
+            self._config.tower, self._scale_gradient).build(
             self._input_source, self.model.build_graph_get_cost, self.model.get_optimizer)
-        self.config.callbacks.extend(callbacks)
+        self._config.callbacks.extend(callbacks)
@@ -44,7 +44,7 @@ class SimpleTrainer(Trainer):
         self.train_op = SimpleBuilder().build(
             self._input_source, self.model.build_graph_get_cost, self.model.get_optimizer)
-        self.config.callbacks.extend(cbs)
+        self._config.callbacks.extend(cbs)


 def QueueInputTrainer(config, input_queue=None):
...
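
A note on the `_delegate_attr` helper added in the trainer base: since `property()` receives only a getter, the delegated names are read-only on `Trainer`, and wrapping the lambda in a function gives each property its own binding of `name`, avoiding the late-binding pitfall of defining the lambda directly inside the `for` loop. A self-contained sketch of the same pattern (all names here are illustrative, not tensorpack classes):

# Illustrative stand-ins for TrainLoop and Trainer.
class _Loop(object):
    def __init__(self):
        self.max_epoch = 10

class _Trainer(object):
    def __init__(self):
        self.loop = _Loop()

def _delegate(name):
    # Getter-only property: reads forward to self.loop;
    # assigning to t.max_epoch would raise AttributeError.
    setattr(_Trainer, name, property(lambda self: getattr(self.loop, name)))

for name in ['max_epoch']:
    _delegate(name)

t = _Trainer()
assert t.max_epoch == 10    # forwarded to t.loop.max_epoch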