Commit e465842d authored by Yuxin Wu's avatar Yuxin Wu

update docs

parent c346e924
......@@ -68,11 +68,19 @@ You can customize the trainer by either using or inheriting the base `Trainer` c
You will need to do two things for a new Trainer:
1. Define the graph. There are 2 ways you can do this:
1. Create any tensors and ops you like, before creating the trainer.
1. Create any tensors and ops you need, before creating the trainer.
2. Create them inside `Trainer.__init__`.
2. Define what is the iteration. There are 2 ways to define the iteration:
1. Set `Trainer.train_op` to a TensorFlow operation. This op will be run by default.
2. Subclass `Trainer` and override the `run_step()` method. This way you can do something more than running an op.
2. Subclass `Trainer` and override the `run_step()` method. This way you can
do something more than running an op.
Note that trainer has `self.sess` and `self.hooked_sess`: only the hooked
session will trigger the `before_run`/`after_run` callbacks.
If you need more than one `Session.run` in one step, special care needs
to be taken to choose which session to use, because many states
(global steps, StagingArea, summaries) are maintained through `before_run`/`after_run`.
There are several different [GAN trainers](../../examples/GAN/GAN.py) for reference.
......@@ -418,7 +418,7 @@ class ResNetFPNModel(DetectionModel):
mrcnn_loss = 0.0
wd_cost = regularize_cost(
'(?:group1|group2|group3|rpn|fpn|fastrcnn|maskrcnn)/.*W',
'(?:group1|group2|group3|rpn|fastrcnn|maskrcnn)/.*W',
l2_regularizer(1e-4), name='wd_cost')
total_cost = tf.add_n(rpn_loss_collection + [
......
......@@ -7,7 +7,6 @@
import tensorflow as tf
from .base import Callback
__all__ = ['CallbackToHook', 'HookToCallback']
......@@ -17,8 +16,6 @@ class CallbackToHook(tf.train.SessionRunHook):
You shouldn't need to use this.
"""
_chief_only = False
def __init__(self, cb):
    """
    Store the callback to be driven by this hook.

    Args:
        cb: the callback instance this hook wraps.
    """
    self._cb = cb
......
......@@ -22,18 +22,19 @@ else:
def send_dataflow_zmq(df, addr, hwm=50, format=None, bind=False):
"""
Run DataFlow and send data to a ZMQ socket addr.
It will __connect__ to this addr,
serialize and send each datapoint to this addr with a PUSH socket.
This function never returns unless an error is encountered.
It will serialize and send each datapoint to this address with a PUSH socket.
This function never returns.
Args:
df (DataFlow): Will infinitely loop over the DataFlow.
addr: a ZMQ socket endpoint.
hwm (int): ZMQ high-water mark (buffer size)
format (str): The serialization format.
Default format would use :mod:`tensorpack.utils.serialize`.
An alternate format is 'zmq_ops', used by https://github.com/tensorpack/zmq_ops.
bind (bool): whether to bind or connect to the endpoint.
Default format uses :mod:`tensorpack.utils.serialize`.
This format works with :class:`dataflow.RemoteDataZMQ`.
An alternate format is 'zmq_ops', used by https://github.com/tensorpack/zmq_ops
and :class:`input_source.ZMQInput`.
bind (bool): whether to bind or connect to the endpoint address.
"""
assert format in [None, 'zmq_op', 'zmq_ops']
if format is None:
......@@ -82,6 +83,8 @@ def send_dataflow_zmq(df, addr, hwm=50, format=None, bind=False):
class RemoteDataZMQ(DataFlow):
"""
Produce data from ZMQ PULL socket(s).
It is the receiver-side counterpart of :func:`send_dataflow_zmq`, which uses :mod:`tensorpack.utils.serialize`
for serialization.
See http://tensorpack.readthedocs.io/en/latest/tutorial/efficient-dataflow.html#distributed-dataflow
Attributes:
......
......@@ -373,7 +373,7 @@ class DummyConstantInput(TensorInput):
class ZMQInput(TensorInput):
"""
Receive tensors from a ZMQ endpoint, with ops from https://github.com/tensorpack/zmq_ops.
It works with :meth:`dataflow.remote.send_dataflow_zmq(format='zmq_op')`.
It works with :meth:`dataflow.remote.send_dataflow_zmq(format='zmq_ops')`.
"""
def __init__(self, end_point, hwm, bind=True):
"""
......
......@@ -196,10 +196,8 @@ class Trainer(object):
logger.info("Creating the session ...")
hooks = self._callbacks.get_hooks()
self.sess = session_creator.create_session()
self.hooked_sess = tf.train.MonitoredSession(
session_creator=ReuseSessionCreator(self.sess), hooks=hooks)
self.initialize_hooks()
if self.is_chief:
logger.info("Initializing the session ...")
......@@ -211,6 +209,18 @@ class Trainer(object):
self.sess.graph.finalize()
logger.info("Graph Finalized.")
@call_only_once
def initialize_hooks(self):
    """
    Build ``SessionRunHook``s from every registered callback and attach
    them to ``self.sess`` via a new ``tf.train.MonitoredSession``
    (stored as ``self.hooked_sess``).

    A new trainer may override this method to create multiple groups of
    hooks, which can be useful when the training is not done by a single
    `train_op`.
    """
    # Reuse the already-created session so hooks fire on the same graph/session.
    self.hooked_sess = tf.train.MonitoredSession(
        session_creator=ReuseSessionCreator(self.sess),
        hooks=self._callbacks.get_hooks())
@call_only_once
def main_loop(self, steps_per_epoch, starting_epoch, max_epoch):
"""
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment