Commit 88b99f38 authored by Yuxin Wu's avatar Yuxin Wu

public register_callback

parent b4085f77
......@@ -86,7 +86,7 @@ class GANTrainer(TowerTrainer):
self.set_tower_func(tower_func)
for cb in cbs:
self._register_callback(cb)
self.register_callback(cb)
class SeparateGANTrainer(TowerTrainer):
......@@ -116,7 +116,7 @@ class SeparateGANTrainer(TowerTrainer):
self.set_tower_func(tower_func)
for cb in cbs:
self._register_callback(cb)
self.register_callback(cb)
def run_step(self):
if self.global_step % (self._d_period) == 0:
......@@ -162,7 +162,7 @@ class MultiGPUGANTrainer(TowerTrainer):
self.train_op = d_min
self.set_tower_func(tower_func)
for cb in cbs:
self._register_callback(cb)
self.register_callback(cb)
class RandomZData(DataFlow):
......
......@@ -107,7 +107,7 @@ class Trainer(object):
def _register_callback(self, cb):
"""
Register a callback to the trainer.
It can only be called before :meth:`Trainer.train` gets called.
It can only be called before :meth:`Trainer.train()`.
"""
assert isinstance(cb, Callback), cb
assert not isinstance(self._callbacks, Callbacks), \
......@@ -117,6 +117,8 @@ class Trainer(object):
else:
self._callbacks.append(cb)
register_callback = _register_callback
def run_step(self):
"""
Defines what to do in one iteration. The default is:
......@@ -138,15 +140,15 @@ class Trainer(object):
"""
describe_trainable_vars() # TODO weird
self._register_callback(MaintainStepCounter())
self.register_callback(MaintainStepCounter())
for cb in callbacks:
self._register_callback(cb)
self.register_callback(cb)
for cb in self._callbacks:
assert not isinstance(cb, TrainingMonitor), "Monitor cannot be pre-registered for now!"
for m in monitors:
self._register_callback(m)
self.register_callback(m)
self.monitors = Monitors(monitors)
self._register_callback(self.monitors) # monitors is also a callback
self.register_callback(self.monitors) # monitors is also a callback
# some final operations that might modify the graph
logger.info("Setup callbacks graph ...")
......
......@@ -128,7 +128,7 @@ class SingleCostTrainer(TowerTrainer):
train_callbacks = self._setup_graph(input, get_cost_fn, get_opt_fn)
internal_callbacks = input_callbacks + train_callbacks
for cb in internal_callbacks:
self._register_callback(cb)
self.register_callback(cb)
@abstractmethod
def _setup_graph(self, input, get_cost_fn, get_opt_fn):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment