Commit e28d616e authored by Yuxin Wu's avatar Yuxin Wu

statholder move to callbacks

parent 6edb1f0d
...@@ -108,6 +108,14 @@ class Callbacks(Callback): ...@@ -108,6 +108,14 @@ class Callbacks(Callback):
if not isinstance(cb.type, (TrainCallbackType, TestCallbackType)): if not isinstance(cb.type, (TrainCallbackType, TestCallbackType)):
raise ValueError( raise ValueError(
"Unknown callback running graph {}!".format(str(cb.type))) "Unknown callback running graph {}!".format(str(cb.type)))
# move "StatPrinter" to the last
for cb in cbs:
if isinstance(cb, StatPrinter):
sp = cb
cbs.remove(sp)
cbs.append(sp)
break
print(cbs)
self.cbs = cbs self.cbs = cbs
self.test_callback_context = TestCallbackContext() self.test_callback_context = TestCallbackContext()
......
...@@ -103,3 +103,7 @@ class StatPrinter(Callback): ...@@ -103,3 +103,7 @@ class StatPrinter(Callback):
def _before_train(self): def _before_train(self):
self.trainer.stat_holder.set_print_tag(self.print_tag) self.trainer.stat_holder.set_print_tag(self.print_tag)
def _trigger_epoch(self):
self.trainer.stat_holder.add_stat('global_step', self.global_step)
self.trainer.stat_holder.finalize()
...@@ -68,8 +68,6 @@ class Trainer(object): ...@@ -68,8 +68,6 @@ class Trainer(object):
self._trigger_epoch() self._trigger_epoch()
self.config.callbacks.trigger_epoch() self.config.callbacks.trigger_epoch()
self.summary_writer.flush() self.summary_writer.flush()
self.stat_holder.add_stat('global_step', self.global_step)
self.stat_holder.finalize()
@abstractmethod @abstractmethod
def _trigger_epoch(self): def _trigger_epoch(self):
......
...@@ -21,13 +21,12 @@ class TrainConfig(object): ...@@ -21,13 +21,12 @@ class TrainConfig(object):
:param dataset: the dataset to train. a `DataFlow` instance. :param dataset: the dataset to train. a `DataFlow` instance.
:param optimizer: a `tf.train.Optimizer` instance defining the optimizer for trainig. :param optimizer: a `tf.train.Optimizer` instance defining the optimizer for trainig.
:param callbacks: a `callback.Callbacks` instance. Define :param callbacks: a `callback.Callbacks` instance. Define
the callbacks to perform during training. It has to contain a the callbacks to perform during training.
SummaryWriter and a PeriodicSaver
:param session_config: a `tf.ConfigProto` instance to instantiate the :param session_config: a `tf.ConfigProto` instance to instantiate the
session. default to a session running 1 GPU. session. default to a session running 1 GPU.
:param session_init: a `sessinit.SessionInit` instance to :param session_init: a `sessinit.SessionInit` instance to
initialize variables of a session. default to a new session. initialize variables of a session. default to a new session.
:param model: a `ModelDesc` instance.j :param model: a `ModelDesc` instance.
:param starting_epoch: int. default to be 1. :param starting_epoch: int. default to be 1.
:param step_per_epoch: the number of steps (SGD updates) to perform in each epoch. :param step_per_epoch: the number of steps (SGD updates) to perform in each epoch.
:param max_epoch: maximum number of epoch to run training. default to inf :param max_epoch: maximum number of epoch to run training. default to inf
...@@ -63,9 +62,7 @@ class TrainConfig(object): ...@@ -63,9 +62,7 @@ class TrainConfig(object):
self.extra_threads_procs = kwargs.pop('extra_threads_procs', []) self.extra_threads_procs = kwargs.pop('extra_threads_procs', [])
assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys())) assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))
def set_tower(self, **kwargs): def set_tower(self, nr_tower=None, tower=None):
nr_tower = kwargs.pop('nr_tower', None)
tower = kwargs.pop('tower', None)
assert nr_tower is None or tower is None, "Cannot set both nr_tower and tower!" assert nr_tower is None or tower is None, "Cannot set both nr_tower and tower!"
if nr_tower: if nr_tower:
tower = list(range(nr_tower)) tower = list(range(nr_tower))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment