Commit 4c5cdf9b authored by Yuxin Wu's avatar Yuxin Wu

Fix some uses of loop attributes. These attributes should be available only from before_train onwards.

parent 4126a583
......@@ -45,7 +45,6 @@ class Callback(object):
_chief_only = True
def setup_graph(self, trainer):
self._steps_per_epoch = trainer.steps_per_epoch
self.trainer = trainer
self.graph = tf.get_default_graph()
scope_name = type(self).__name__
......
......@@ -38,7 +38,8 @@ class RunOp(Callback):
uses this callback to update target network.
"""
if not callable(op):
op = lambda: op # noqa
self.setup_func = lambda: op # noqa
else:
self.setup_func = op
self.run_before = run_before
self.run_as_trigger = run_as_trigger
......
......@@ -239,12 +239,11 @@ class JSONWriter(TrainingMonitor):
logger.warn("logger directory was not set. Ignore JSONWriter.")
return NoOpMonitor()
def _setup_graph(self):
def _before_train(self):
self._dir = logger.LOG_DIR
self._fname = os.path.join(self._dir, self.FILENAME)
if os.path.isfile(self._fname):
# TODO make a backup first?
logger.info("Found existing JSON at {}, will append to it.".format(self._fname))
with open(self._fname) as f:
self._stats = json.load(f)
......@@ -255,18 +254,18 @@ class JSONWriter(TrainingMonitor):
except Exception:
pass
else:
# TODO is this a good idea?
logger.info("Found training history from JSON, now starting from epoch number {}.".format(epoch))
self.trainer.starting_epoch = epoch
self.trainer.loop.starting_epoch = epoch
else:
self._stats = []
self._stat_now = {}
self._last_gs = -1
self._total = self.trainer.steps_per_epoch
def _trigger_step(self):
# will do this in trigger_epoch
if self.local_step != self._total - 1:
if self.local_step != self.trainer.steps_per_epoch - 1:
self._push()
def _trigger_epoch(self):
......@@ -327,11 +326,10 @@ class ScalarPrinter(TrainingMonitor):
def _setup_graph(self):
self._dic = {}
self._total = self.trainer.steps_per_epoch
def _trigger_step(self):
if self._enable_step:
if self.local_step != self._total - 1:
if self.local_step != self.trainer.steps_per_epoch - 1:
# not the last step
self._print_stat()
else:
......
......@@ -70,10 +70,9 @@ class MergeAllSummaries_RunWithOp(Callback):
self._fetches = tf.train.SessionRunArgs(self.summary_op)
else:
self._fetches = None
self._total = self.trainer.steps_per_epoch
def _need_run(self):
if self.local_step == self._total - 1:
if self.local_step == self.trainer.steps_per_epoch - 1:
return True
if self._period > 0 and (self.local_step + 1) % self._period == 0:
return True
......
......@@ -108,6 +108,7 @@ class TrainConfig(object):
else:
self.session_creator = session_creator
assert session_config is None, "Cannot set both session_creator and session_config!"
# only used by DistributedTrainer for assertion!
self.session_config = session_config
if steps_per_epoch is None:
......
......@@ -69,7 +69,6 @@ class DistributedTrainerReplicated(Trainer):
logger.info("Distributed training on cluster:\n" + str(server.server_def.cluster))
self._input_source = config.data
self.nr_gpu = config.nr_tower
super(DistributedTrainerReplicated, self).__init__(config)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment