Commit 88b99f38 authored by Yuxin Wu's avatar Yuxin Wu

public register_callback

parent b4085f77
...@@ -86,7 +86,7 @@ class GANTrainer(TowerTrainer): ...@@ -86,7 +86,7 @@ class GANTrainer(TowerTrainer):
self.set_tower_func(tower_func) self.set_tower_func(tower_func)
for cb in cbs: for cb in cbs:
self._register_callback(cb) self.register_callback(cb)
class SeparateGANTrainer(TowerTrainer): class SeparateGANTrainer(TowerTrainer):
...@@ -116,7 +116,7 @@ class SeparateGANTrainer(TowerTrainer): ...@@ -116,7 +116,7 @@ class SeparateGANTrainer(TowerTrainer):
self.set_tower_func(tower_func) self.set_tower_func(tower_func)
for cb in cbs: for cb in cbs:
self._register_callback(cb) self.register_callback(cb)
def run_step(self): def run_step(self):
if self.global_step % (self._d_period) == 0: if self.global_step % (self._d_period) == 0:
...@@ -162,7 +162,7 @@ class MultiGPUGANTrainer(TowerTrainer): ...@@ -162,7 +162,7 @@ class MultiGPUGANTrainer(TowerTrainer):
self.train_op = d_min self.train_op = d_min
self.set_tower_func(tower_func) self.set_tower_func(tower_func)
for cb in cbs: for cb in cbs:
self._register_callback(cb) self.register_callback(cb)
class RandomZData(DataFlow): class RandomZData(DataFlow):
......
...@@ -107,7 +107,7 @@ class Trainer(object): ...@@ -107,7 +107,7 @@ class Trainer(object):
def _register_callback(self, cb): def _register_callback(self, cb):
""" """
Register a callback to the trainer. Register a callback to the trainer.
It can only be called before :meth:`Trainer.train` gets called. It can only be called before :meth:`Trainer.train()`.
""" """
assert isinstance(cb, Callback), cb assert isinstance(cb, Callback), cb
assert not isinstance(self._callbacks, Callbacks), \ assert not isinstance(self._callbacks, Callbacks), \
...@@ -117,6 +117,8 @@ class Trainer(object): ...@@ -117,6 +117,8 @@ class Trainer(object):
else: else:
self._callbacks.append(cb) self._callbacks.append(cb)
register_callback = _register_callback
def run_step(self): def run_step(self):
""" """
Defines what to do in one iteration. The default is: Defines what to do in one iteration. The default is:
...@@ -138,15 +140,15 @@ class Trainer(object): ...@@ -138,15 +140,15 @@ class Trainer(object):
""" """
describe_trainable_vars() # TODO weird describe_trainable_vars() # TODO weird
self._register_callback(MaintainStepCounter()) self.register_callback(MaintainStepCounter())
for cb in callbacks: for cb in callbacks:
self._register_callback(cb) self.register_callback(cb)
for cb in self._callbacks: for cb in self._callbacks:
assert not isinstance(cb, TrainingMonitor), "Monitor cannot be pre-registered for now!" assert not isinstance(cb, TrainingMonitor), "Monitor cannot be pre-registered for now!"
for m in monitors: for m in monitors:
self._register_callback(m) self.register_callback(m)
self.monitors = Monitors(monitors) self.monitors = Monitors(monitors)
self._register_callback(self.monitors) # monitors is also a callback self.register_callback(self.monitors) # monitors is also a callback
# some final operations that might modify the graph # some final operations that might modify the graph
logger.info("Setup callbacks graph ...") logger.info("Setup callbacks graph ...")
......
...@@ -128,7 +128,7 @@ class SingleCostTrainer(TowerTrainer): ...@@ -128,7 +128,7 @@ class SingleCostTrainer(TowerTrainer):
train_callbacks = self._setup_graph(input, get_cost_fn, get_opt_fn) train_callbacks = self._setup_graph(input, get_cost_fn, get_opt_fn)
internal_callbacks = input_callbacks + train_callbacks internal_callbacks = input_callbacks + train_callbacks
for cb in internal_callbacks: for cb in internal_callbacks:
self._register_callback(cb) self.register_callback(cb)
@abstractmethod @abstractmethod
def _setup_graph(self, input, get_cost_fn, get_opt_fn): def _setup_graph(self, input, get_cost_fn, get_opt_fn):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment