Commit 37ecbd4e authored by Yuxin Wu's avatar Yuxin Wu

docs update

parent ebeaa046
...@@ -77,4 +77,4 @@ These features may not be always useful, but think about how messy the main loop ...@@ -77,4 +77,4 @@ These features may not be always useful, but think about how messy the main loop
were to write the logic together with the loops. were to write the logic together with the loops.
See [Write a callback](http://tensorpack.readthedocs.io/en/latest/tutorial/extend/callback.html) See [Write a callback](http://tensorpack.readthedocs.io/en/latest/tutorial/extend/callback.html)
on how to implement a callback. for details on how callbacks work, what they can do, and how to write them.
## Write a Callback ## Write a Callback
The places where each callback method gets called is demonstrated in this snippet: In the main loop of the trainer,
the callbacks will be called in the order they are given in `TrainConfig`.
The time where each callback method gets called is demonstrated in this snippet:
```python ```python
def main_loop(): def train(self):
# ... create graph for the model ... # ... a predefined trainer may create graph for the model here ...
callbacks.setup_graph() callbacks.setup_graph()
# ... create session, initialize session, finalize graph ... # ... create session, initialize session, finalize graph ...
# start training: # start training:
...@@ -13,7 +15,7 @@ def main_loop(): ...@@ -13,7 +15,7 @@ def main_loop():
for epoch in range(epoch_start, epoch_end): for epoch in range(epoch_start, epoch_end):
callbacks.before_epoch() callbacks.before_epoch()
for step in range(steps_per_epoch): for step in range(steps_per_epoch):
run_one_step() # callbacks.{before,after}_run are hooked with session self.run_step() # callbacks.{before,after}_run are hooked with session
callbacks.trigger_step() callbacks.trigger_step()
callbacks.after_epoch() callbacks.after_epoch()
callbacks.trigger_epoch() callbacks.trigger_epoch()
...@@ -87,6 +89,7 @@ to let this method run every k steps or every k epochs. ...@@ -87,6 +89,7 @@ to let this method run every k steps or every k epochs.
### What you can do in the callback ### What you can do in the callback
* Access tensors / ops in either training / inference mode (need to create them in `_setup_graph`). * Access tensors / ops in either training / inference mode (need to create them in `_setup_graph`).
To create a callable function under inference mode, use `self.trainer.get_predictor`.
* Write stuff to the monitor backend, by `self.trainer.monitors.put_xxx`. * Write stuff to the monitor backend, by `self.trainer.monitors.put_xxx`.
The monitors might direct your events to TensorFlow events file, JSON file, stdout, etc. The monitors might direct your events to TensorFlow events file, JSON file, stdout, etc.
You can get history monitor data as well. See the docs for [Monitors](http://tensorpack.readthedocs.io/en/latest/modules/callbacks.html#tensorpack.callbacks.Monitors) You can get history monitor data as well. See the docs for [Monitors](http://tensorpack.readthedocs.io/en/latest/modules/callbacks.html#tensorpack.callbacks.Monitors)
......
...@@ -261,6 +261,9 @@ class Trainer(object): ...@@ -261,6 +261,9 @@ class Trainer(object):
def get_predictor(self, input_names, output_names, tower=0): def get_predictor(self, input_names, output_names, tower=0):
""" """
Returns a callable predictor built under ``is_training=False`` tower context.
Note that this method is only valid when this trainer has a ``ModelDesc``.
Args: Args:
input_names (list), output_names(list): list of names input_names (list), output_names(list): list of names
tower (int): build the predictor on device '/gpu:{tower}' or use -1 for '/cpu:0'. tower (int): build the predictor on device '/gpu:{tower}' or use -1 for '/cpu:0'.
...@@ -273,6 +276,8 @@ class Trainer(object): ...@@ -273,6 +276,8 @@ class Trainer(object):
@property @property
def predictor_factory(self): def predictor_factory(self):
assert self.model is not None, \
"Predictor can only be built one Trainer has ModelDesc!"
if not hasattr(self, '_predictor_factory'): if not hasattr(self, '_predictor_factory'):
self._predictor_factory = PredictorFactory( self._predictor_factory = PredictorFactory(
self.model, self.vs_name_for_predictor) self.model, self.vs_name_for_predictor)
......
...@@ -42,7 +42,7 @@ def _getlogger(): ...@@ -42,7 +42,7 @@ def _getlogger():
_logger = _getlogger() _logger = _getlogger()
_LOGGING_METHOD = ['info', 'warning', 'error', 'critical', 'warn', 'exception', 'debug'] _LOGGING_METHOD = ['info', 'warning', 'error', 'critical', 'warn', 'exception', 'debug', 'setLevel']
# export logger functions # export logger functions
for func in _LOGGING_METHOD: for func in _LOGGING_METHOD:
locals()[func] = getattr(_logger, func) locals()[func] = getattr(_logger, func)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment