Commit 4c5cdf9b authored by Yuxin Wu's avatar Yuxin Wu

Fix some uses of loop attributes. These attributes should only be available starting from before_train.

parent 4126a583
...@@ -45,7 +45,6 @@ class Callback(object): ...@@ -45,7 +45,6 @@ class Callback(object):
_chief_only = True _chief_only = True
def setup_graph(self, trainer): def setup_graph(self, trainer):
self._steps_per_epoch = trainer.steps_per_epoch
self.trainer = trainer self.trainer = trainer
self.graph = tf.get_default_graph() self.graph = tf.get_default_graph()
scope_name = type(self).__name__ scope_name = type(self).__name__
......
...@@ -38,7 +38,8 @@ class RunOp(Callback): ...@@ -38,7 +38,8 @@ class RunOp(Callback):
uses this callback to update target network. uses this callback to update target network.
""" """
if not callable(op): if not callable(op):
op = lambda: op # noqa self.setup_func = lambda: op # noqa
else:
self.setup_func = op self.setup_func = op
self.run_before = run_before self.run_before = run_before
self.run_as_trigger = run_as_trigger self.run_as_trigger = run_as_trigger
......
...@@ -239,12 +239,11 @@ class JSONWriter(TrainingMonitor): ...@@ -239,12 +239,11 @@ class JSONWriter(TrainingMonitor):
logger.warn("logger directory was not set. Ignore JSONWriter.") logger.warn("logger directory was not set. Ignore JSONWriter.")
return NoOpMonitor() return NoOpMonitor()
def _setup_graph(self): def _before_train(self):
self._dir = logger.LOG_DIR self._dir = logger.LOG_DIR
self._fname = os.path.join(self._dir, self.FILENAME) self._fname = os.path.join(self._dir, self.FILENAME)
if os.path.isfile(self._fname): if os.path.isfile(self._fname):
# TODO make a backup first?
logger.info("Found existing JSON at {}, will append to it.".format(self._fname)) logger.info("Found existing JSON at {}, will append to it.".format(self._fname))
with open(self._fname) as f: with open(self._fname) as f:
self._stats = json.load(f) self._stats = json.load(f)
...@@ -255,18 +254,18 @@ class JSONWriter(TrainingMonitor): ...@@ -255,18 +254,18 @@ class JSONWriter(TrainingMonitor):
except Exception: except Exception:
pass pass
else: else:
# TODO is this a good idea?
logger.info("Found training history from JSON, now starting from epoch number {}.".format(epoch)) logger.info("Found training history from JSON, now starting from epoch number {}.".format(epoch))
self.trainer.starting_epoch = epoch self.trainer.loop.starting_epoch = epoch
else: else:
self._stats = [] self._stats = []
self._stat_now = {} self._stat_now = {}
self._last_gs = -1 self._last_gs = -1
self._total = self.trainer.steps_per_epoch
def _trigger_step(self): def _trigger_step(self):
# will do this in trigger_epoch # will do this in trigger_epoch
if self.local_step != self._total - 1: if self.local_step != self.trainer.steps_per_epoch - 1:
self._push() self._push()
def _trigger_epoch(self): def _trigger_epoch(self):
...@@ -327,11 +326,10 @@ class ScalarPrinter(TrainingMonitor): ...@@ -327,11 +326,10 @@ class ScalarPrinter(TrainingMonitor):
def _setup_graph(self): def _setup_graph(self):
self._dic = {} self._dic = {}
self._total = self.trainer.steps_per_epoch
def _trigger_step(self): def _trigger_step(self):
if self._enable_step: if self._enable_step:
if self.local_step != self._total - 1: if self.local_step != self.trainer.steps_per_epoch - 1:
# not the last step # not the last step
self._print_stat() self._print_stat()
else: else:
......
...@@ -70,10 +70,9 @@ class MergeAllSummaries_RunWithOp(Callback): ...@@ -70,10 +70,9 @@ class MergeAllSummaries_RunWithOp(Callback):
self._fetches = tf.train.SessionRunArgs(self.summary_op) self._fetches = tf.train.SessionRunArgs(self.summary_op)
else: else:
self._fetches = None self._fetches = None
self._total = self.trainer.steps_per_epoch
def _need_run(self): def _need_run(self):
if self.local_step == self._total - 1: if self.local_step == self.trainer.steps_per_epoch - 1:
return True return True
if self._period > 0 and (self.local_step + 1) % self._period == 0: if self._period > 0 and (self.local_step + 1) % self._period == 0:
return True return True
......
...@@ -108,6 +108,7 @@ class TrainConfig(object): ...@@ -108,6 +108,7 @@ class TrainConfig(object):
else: else:
self.session_creator = session_creator self.session_creator = session_creator
assert session_config is None, "Cannot set both session_creator and session_config!" assert session_config is None, "Cannot set both session_creator and session_config!"
# only used by DistributedTrainer for assertion!
self.session_config = session_config self.session_config = session_config
if steps_per_epoch is None: if steps_per_epoch is None:
......
...@@ -69,7 +69,6 @@ class DistributedTrainerReplicated(Trainer): ...@@ -69,7 +69,6 @@ class DistributedTrainerReplicated(Trainer):
logger.info("Distributed training on cluster:\n" + str(server.server_def.cluster)) logger.info("Distributed training on cluster:\n" + str(server.server_def.cluster))
self._input_source = config.data self._input_source = config.data
self.nr_gpu = config.nr_tower
super(DistributedTrainerReplicated, self).__init__(config) super(DistributedTrainerReplicated, self).__init__(config)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment