Commit e465842d authored by Yuxin Wu's avatar Yuxin Wu

update docs

parent c346e924
...@@ -68,11 +68,19 @@ You can customize the trainer by either using or inheriting the base `Trainer` c ...@@ -68,11 +68,19 @@ You can customize the trainer by either using or inheriting the base `Trainer` c
You will need to do two things for a new Trainer: You will need to do two things for a new Trainer:
1. Define the graph. There are 2 ways you can do this: 1. Define the graph. There are 2 ways you can do this:
1. Create any tensors and ops you like, before creating the trainer. 1. Create any tensors and ops you need, before creating the trainer.
2. Create them inside `Trainer.__init__`. 2. Create them inside `Trainer.__init__`.
2. Define what is the iteration. There are 2 ways to define the iteration: 2. Define what is the iteration. There are 2 ways to define the iteration:
1. Set `Trainer.train_op` to a TensorFlow operation. This op will be run by default. 1. Set `Trainer.train_op` to a TensorFlow operation. This op will be run by default.
2. Subclass `Trainer` and override the `run_step()` method. This way you can do something more than running an op. 2. Subclass `Trainer` and override the `run_step()` method. This way you can
do something more than running an op.
Note that trainer has `self.sess` and `self.hooked_sess`: only the hooked
session will trigger the `before_run`/`after_run` callbacks.
If you need more than one `Session.run` in one steps, special care needs
to be taken to choose which session to use, because many states
(global steps, StagingArea, summaries) are maintained through `before_run`/`after_run`.
There are several different [GAN trainers](../../examples/GAN/GAN.py) for reference. There are several different [GAN trainers](../../examples/GAN/GAN.py) for reference.
...@@ -418,7 +418,7 @@ class ResNetFPNModel(DetectionModel): ...@@ -418,7 +418,7 @@ class ResNetFPNModel(DetectionModel):
mrcnn_loss = 0.0 mrcnn_loss = 0.0
wd_cost = regularize_cost( wd_cost = regularize_cost(
'(?:group1|group2|group3|rpn|fpn|fastrcnn|maskrcnn)/.*W', '(?:group1|group2|group3|rpn|fastrcnn|maskrcnn)/.*W',
l2_regularizer(1e-4), name='wd_cost') l2_regularizer(1e-4), name='wd_cost')
total_cost = tf.add_n(rpn_loss_collection + [ total_cost = tf.add_n(rpn_loss_collection + [
......
...@@ -7,7 +7,6 @@ ...@@ -7,7 +7,6 @@
import tensorflow as tf import tensorflow as tf
from .base import Callback from .base import Callback
__all__ = ['CallbackToHook', 'HookToCallback'] __all__ = ['CallbackToHook', 'HookToCallback']
...@@ -17,8 +16,6 @@ class CallbackToHook(tf.train.SessionRunHook): ...@@ -17,8 +16,6 @@ class CallbackToHook(tf.train.SessionRunHook):
You shouldn't need to use this. You shouldn't need to use this.
""" """
_chief_only = False
def __init__(self, cb): def __init__(self, cb):
self._cb = cb self._cb = cb
......
...@@ -22,18 +22,19 @@ else: ...@@ -22,18 +22,19 @@ else:
def send_dataflow_zmq(df, addr, hwm=50, format=None, bind=False): def send_dataflow_zmq(df, addr, hwm=50, format=None, bind=False):
""" """
Run DataFlow and send data to a ZMQ socket addr. Run DataFlow and send data to a ZMQ socket addr.
It will __connect__ to this addr, It will serialize and send each datapoint to this address with a PUSH socket.
serialize and send each datapoint to this addr with a PUSH socket. This function never returns.
This function never returns unless an error is encountered.
Args: Args:
df (DataFlow): Will infinitely loop over the DataFlow. df (DataFlow): Will infinitely loop over the DataFlow.
addr: a ZMQ socket endpoint. addr: a ZMQ socket endpoint.
hwm (int): ZMQ high-water mark (buffer size) hwm (int): ZMQ high-water mark (buffer size)
format (str): The serialization format. format (str): The serialization format.
Default format would use :mod:`tensorpack.utils.serialize`. Default format uses :mod:`tensorpack.utils.serialize`.
An alternate format is 'zmq_ops', used by https://github.com/tensorpack/zmq_ops. This format works with :class:`dataflow.RemoteDataZMQ`.
bind (bool): whether to bind or connect to the endpoint. An alternate format is 'zmq_ops', used by https://github.com/tensorpack/zmq_ops
and :class:`input_source.ZMQInput`.
bind (bool): whether to bind or connect to the endpoint address.
""" """
assert format in [None, 'zmq_op', 'zmq_ops'] assert format in [None, 'zmq_op', 'zmq_ops']
if format is None: if format is None:
...@@ -82,6 +83,8 @@ def send_dataflow_zmq(df, addr, hwm=50, format=None, bind=False): ...@@ -82,6 +83,8 @@ def send_dataflow_zmq(df, addr, hwm=50, format=None, bind=False):
class RemoteDataZMQ(DataFlow): class RemoteDataZMQ(DataFlow):
""" """
Produce data from ZMQ PULL socket(s). Produce data from ZMQ PULL socket(s).
It is the receiver-side counterpart of :func:`send_dataflow_zmq`, which uses :mod:`tensorpack.utils.serialize`
for serialization.
See http://tensorpack.readthedocs.io/en/latest/tutorial/efficient-dataflow.html#distributed-dataflow See http://tensorpack.readthedocs.io/en/latest/tutorial/efficient-dataflow.html#distributed-dataflow
Attributes: Attributes:
......
...@@ -373,7 +373,7 @@ class DummyConstantInput(TensorInput): ...@@ -373,7 +373,7 @@ class DummyConstantInput(TensorInput):
class ZMQInput(TensorInput): class ZMQInput(TensorInput):
""" """
Receive tensors from a ZMQ endpoint, with ops from https://github.com/tensorpack/zmq_ops. Receive tensors from a ZMQ endpoint, with ops from https://github.com/tensorpack/zmq_ops.
It works with :meth:`dataflow.remote.send_dataflow_zmq(format='zmq_op')`. It works with :meth:`dataflow.remote.send_dataflow_zmq(format='zmq_ops')`.
""" """
def __init__(self, end_point, hwm, bind=True): def __init__(self, end_point, hwm, bind=True):
""" """
......
...@@ -196,10 +196,8 @@ class Trainer(object): ...@@ -196,10 +196,8 @@ class Trainer(object):
logger.info("Creating the session ...") logger.info("Creating the session ...")
hooks = self._callbacks.get_hooks()
self.sess = session_creator.create_session() self.sess = session_creator.create_session()
self.hooked_sess = tf.train.MonitoredSession( self.initialize_hooks()
session_creator=ReuseSessionCreator(self.sess), hooks=hooks)
if self.is_chief: if self.is_chief:
logger.info("Initializing the session ...") logger.info("Initializing the session ...")
...@@ -211,6 +209,18 @@ class Trainer(object): ...@@ -211,6 +209,18 @@ class Trainer(object):
self.sess.graph.finalize() self.sess.graph.finalize()
logger.info("Graph Finalized.") logger.info("Graph Finalized.")
@call_only_once
def initialize_hooks(self):
"""
Create SessionRunHooks for all callbacks, and hook it onto self.sess.
A new trainer may override this method to create multiple groups of hooks,
which can be useful when the training is not done by a single `train_op`.
"""
hooks = self._callbacks.get_hooks()
self.hooked_sess = tf.train.MonitoredSession(
session_creator=ReuseSessionCreator(self.sess), hooks=hooks)
@call_only_once @call_only_once
def main_loop(self, steps_per_epoch, starting_epoch, max_epoch): def main_loop(self, steps_per_epoch, starting_epoch, max_epoch):
""" """
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment