Commit 3b7e7c55 authored by Yuxin Wu's avatar Yuxin Wu

Call Trainer._setup() in constructor, to logically allow custom graph built outside trainer. (#318)

parent ec1bea93
...@@ -64,6 +64,8 @@ class Trainer(object): ...@@ -64,6 +64,8 @@ class Trainer(object):
self.monitors = [] self.monitors = []
self._epoch_num = None self._epoch_num = None
self._setup() # subclass will setup the graph
@property @property
def epoch_num(self): def epoch_num(self):
if self._epoch_num is not None: if self._epoch_num is not None:
...@@ -76,9 +78,6 @@ class Trainer(object): ...@@ -76,9 +78,6 @@ class Trainer(object):
""" """
Use this method before :meth:`Trainer._setup` finishes, Use this method before :meth:`Trainer._setup` finishes,
to register a callback to the trainer. to register a callback to the trainer.
The hooks of the registered callback will be bind to the
`self.hooked_sess` session.
""" """
assert isinstance(cb, Callback), cb assert isinstance(cb, Callback), cb
assert not isinstance(self._callbacks, Callbacks), \ assert not isinstance(self._callbacks, Callbacks), \
...@@ -120,8 +119,6 @@ class Trainer(object): ...@@ -120,8 +119,6 @@ class Trainer(object):
""" """
Setup the trainer and be ready for the main loop. Setup the trainer and be ready for the main loop.
""" """
self._setup() # subclass will setup the graph
self.register_callback(MaintainStepCounter()) self.register_callback(MaintainStepCounter())
for cb in self.config.callbacks: for cb in self.config.callbacks:
self.register_callback(cb) self.register_callback(cb)
...@@ -160,9 +157,15 @@ class Trainer(object): ...@@ -160,9 +157,15 @@ class Trainer(object):
self.hooked_sess = tf.train.MonitoredSession( self.hooked_sess = tf.train.MonitoredSession(
session_creator=ReuseSessionCreator(self.sess), hooks=hooks) session_creator=ReuseSessionCreator(self.sess), hooks=hooks)
@abstractmethod
def _setup(self): def _setup(self):
""" setup Trainer-specific stuff for training""" """
Build the entire graph for training.
Responsible for setup InputSource as well (including registering InputSource callbacks)
Since this method will get called in constructor only,
you can simply leave it empty and build your graph outside the trainer.
"""
pass
@property @property
def global_step(self): def global_step(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment