Commit f257f0e0 authored by Yuxin Wu

fix horovod trainer broadcast stage again

parent 8fec1bfb
tensorpack.input_source package
================================
Relevant tutorials: :doc:`../tutorial/input-source`.
Read the relevant tutorials first for an overview of InputSource: :doc:`../tutorial/input-source`.
.. automodule:: tensorpack.input_source
:members:
......
......@@ -84,24 +84,27 @@ You just need the right interface to connect Python to the graph directly, effic
## InputSource
`InputSource` is an abstract interface in tensorpack, to describe where the inputs come from and how they enter the graph.
For example,
`InputSource` is an abstract interface used by tensorpack trainers, to describe where the inputs come from and how they enter the graph.
Some choices are:
1. [FeedInput](../modules/input_source.html#tensorpack.input_source.FeedInput):
Come from a DataFlow and get fed to the graph (slow).
Data come from a DataFlow and get fed to the graph (slow).
2. [QueueInput](../modules/input_source.html#tensorpack.input_source.QueueInput):
Come from a DataFlow and get buffered on CPU by a TF queue.
Data come from a DataFlow and get buffered on CPU by a TF queue.
3. [StagingInput](../modules/input_source.html#tensorpack.input_source.StagingInput):
Come from some `InputSource`, then prefetched on GPU by a TF StagingArea.
Come from some other `InputSource`, then prefetched on GPU by a TF StagingArea.
4. [TFDatasetInput](http://tensorpack.readthedocs.io/en/latest/modules/input_source.html#tensorpack.input_source.TFDatasetInput)
Come from a `tf.data.Dataset`.
5. [dataflow_to_dataset](http://tensorpack.readthedocs.io/en/latest/modules/input_source.html#tensorpack.input_source.TFDatasetInput.dataflow_to_dataset)
Come from a DataFlow, and further processed by `tf.data.Dataset`.
Come from a DataFlow, and then further processed by utilities in `tf.data.Dataset`.
6. [TensorInput](../modules/input_source.html#tensorpack.input_source.TensorInput):
Come from some tensors you define (can be reading ops, for example).
7. [ZMQInput](http://tensorpack.readthedocs.io/en/latest/modules/input_source.html#tensorpack.input_source.ZMQInput)
Come from some ZeroMQ pipe, where the reading/preprocessing may happen in a different process or even a different machine.
Typically, we recommend `QueueInput + StagingInput` as it's good for most use cases.
Typically, we recommend using `DataFlow + QueueInput` as it's good for most use cases.
If your data has to come from a separate process for whatever reasons, use `ZMQInput`.
If you still like to use TF reading ops, define a `tf.data.Dataset` and use `TFDatasetInput`.
If you need to use TF reading ops directly, either define a `tf.data.Dataset`
and use `TFDatasetInput`, or use `TensorInput`.
Refer to the documentation of these `InputSource` classes for more details.
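For illustration, here is a minimal sketch of the recommended setup, assuming a plain single-GPU trainer; `MyModel` and `get_my_dataflow` are placeholders for your own `ModelDesc` and DataFlow:

```python
from tensorpack import (TrainConfig, QueueInput, StagingInput,
                        SimpleTrainer, launch_train_with_config)

df = get_my_dataflow()        # any DataFlow producing datapoints
data = QueueInput(df)         # buffer datapoints in a CPU queue
# data = StagingInput(data)   # optionally also prefetch onto GPU

config = TrainConfig(model=MyModel(), data=data)
launch_train_with_config(config, SimpleTrainer())
```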
......@@ -3,14 +3,14 @@
This example is mainly to demonstrate:
1. How to train an RNN with persistent state between iterations.
Here it simply manages the state inside the graph. `state_saving_rnn` can be used for more complicated use cases.
1. How to train an RNN with persistent state between iterations. Here it simply manages the state inside the graph.
2. How to use a TF reader pipeline instead of a DataFlow, for both training & inference.
It trains a language model on the PTB dataset, basically an equivalent of the PTB example
in [tensorflow/models](https://github.com/tensorflow/models/tree/master/tutorials/rnn/ptb)
with its "medium" config.
It has the same performance & speed as the original example as well.
Note that the data pipeline is completely copied from the tensorflow example.
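For the first point above, the idea can be sketched roughly as follows (a simplified, hypothetical snippet with made-up shapes, not the actual code of this example): keep the LSTM state in non-trainable variables so it persists across iterations, and write the final state back after each step.

```python
import tensorflow as tf

batch, steps, hidden = 20, 35, 650   # sizes similar to the "medium" config
cell = tf.nn.rnn_cell.BasicLSTMCell(hidden)
state_c = tf.get_variable('state_c', [batch, hidden], trainable=False,
                          initializer=tf.zeros_initializer())
state_h = tf.get_variable('state_h', [batch, hidden], trainable=False,
                          initializer=tf.zeros_initializer())
state = tf.nn.rnn_cell.LSTMStateTuple(state_c, state_h)

inputs = tf.placeholder(tf.float32, [batch, steps, hidden])
outputs, last_state = tf.nn.dynamic_rnn(cell, inputs, initial_state=state)
# write the final state back, so the next iteration resumes from it
update_state = tf.group(tf.assign(state_c, last_state.c),
                        tf.assign(state_h, last_state.h))
```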
To Train:
......
......@@ -103,7 +103,7 @@ class ProgressBar(Callback):
class MaintainStepCounter(Callback):
"""
It maintains the global step in the graph, making sure it's increased by one every step.
This callback is used by the trainer, you don't need to worry about it.
This callback is used internally by the trainer; you don't need to worry about it.
"""
_chief_only = False
......
......@@ -96,7 +96,8 @@ class FeedInput(InputSource):
infinite (bool): When set to False, will raise StopIteration when
ds is exhausted.
"""
assert isinstance(ds, DataFlow), ds
if not isinstance(ds, DataFlow):
raise ValueError("FeedInput takes a DataFlow! Got {}".format(ds))
self.ds = ds
if infinite:
self._iter_ds = RepeatedData(self.ds, -1)
......@@ -198,7 +199,8 @@ class QueueInput(FeedfreeInput):
should match the corresponding InputDesc of the model.
Defaults to a FIFO queue of size 50.
"""
assert isinstance(ds, DataFlow), ds
if not isinstance(ds, DataFlow):
raise ValueError("QueueInput takes a DataFlow! Got {}".format(ds))
self.queue = queue
self.ds = ds
self._inf_ds = RepeatedData(ds, -1)
......@@ -352,6 +354,8 @@ class TensorInput(FeedfreeInput):
The returned tensors will be evaluated every iteration; it's your job to make sure that's possible.
size(int): size of this input. Use None to leave it undefined.
"""
if not callable(get_tensor_fn):
raise ValueError("get_tensor_fn has to be a function! Got {}".format(get_tensor_fn))
self.get_tensor_fn = get_tensor_fn
if size is not None:
size = int(size)
......@@ -369,7 +373,9 @@ class TensorInput(FeedfreeInput):
def _get_input_tensors(self):
with self.cached_name_scope():
ret = self.get_tensor_fn()
assert len(ret) == len(self._desc), "{} != {}".format(len(ret), len(self._desc))
assert isinstance(ret, (list, tuple)), "get_tensor_fn needs to return a list!"
assert len(ret) == len(self._desc), \
"get_tensor_fn returns {} tensors but there are {} inputs".format(len(ret), len(self._desc))
return ret
......@@ -436,7 +442,7 @@ class ZMQInput(TensorInput):
class TFDatasetInput(FeedfreeInput):
"""
Use a :class:`tf.contrib.data.Dataset` instance as input.
Use a :class:`tf.data.Dataset` instance as input.
Note:
In training, the dataset should be infinite (use :func:`repeat()`).
......@@ -444,8 +450,10 @@ class TFDatasetInput(FeedfreeInput):
def __init__(self, dataset):
"""
Args:
dataset (tf.contrib.data.Dataset):
dataset (tf.data.Dataset):
"""
if not isinstance(dataset, tf.data.Dataset):
raise ValueError("TFDatasetInput takes a tf.data.Dataset! Got {}".format(dataset))
self._dataset = dataset
def _setup(self, inputs_desc):
......@@ -474,7 +482,8 @@ class TFDatasetInput(FeedfreeInput):
def _get_input_tensors(self):
desc_shapes = [k.shape for k in self._desc]
ret = self._iterator.get_next()
assert len(ret) == len(desc_shapes)
assert len(ret) == len(desc_shapes), \
"Dataset returns {} tensors but there are {} inputs!".format(len(ret), len(desc_shapes))
for t, shp in zip(ret, desc_shapes):
t.set_shape(shp)
return ret
......@@ -491,7 +500,7 @@ class TFDatasetInput(FeedfreeInput):
Args:
df (DataFlow): a dataflow which produces lists
types([tf.DType])
types([tf.DType]): list of types
Returns:
(tf.data.Dataset)
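Example:
# a hypothetical sketch: df yields [image, label] datapoints
dataset = TFDatasetInput.dataflow_to_dataset(df, [tf.float32, tf.int64])
data = TFDatasetInput(dataset.repeat())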
......@@ -559,13 +568,14 @@ class StagingInput(FeedfreeInput):
"""
Args:
input (FeedfreeInput):
nr_stage: number of elements to prefetch into each StagingArea, at the beginning.
nr_stage (int): number of elements to prefetch into each StagingArea, at the beginning.
Since enqueue and dequeue are synchronized, prefetching 1 element should be sufficient.
device (str or None): if not None, place the StagingArea on a specific device, e.g. '/cpu:0'.
Otherwise, they are placed wherever `get_input_tensors`
gets called, which could be unspecified in case of simple trainers.
"""
assert isinstance(input, FeedfreeInput), input
if not isinstance(input, FeedfreeInput):
raise ValueError("StagingInput takes a FeedfreeInput! Got {}".format(input))
self._input = input
self._nr_stage = nr_stage
......
......@@ -70,7 +70,11 @@ def get_global_step_var():
def get_global_step_value():
"""
Returns:
int: global_step value in current graph and session"""
int: global_step value in current graph and session
Has to be called under a default session.
"""
return tf.train.global_step(
tf.get_default_session(),
get_global_step_var())
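# usage sketch (hypothetical): must run under a default session, e.g.
#   with sess.as_default():
#       step = get_global_step_value()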
......
......@@ -214,7 +214,7 @@ class Trainer(object):
if not isinstance(session_init, JustCurrentSession):
logger.warn("This is not a chief worker, 'session_init' was ignored!")
self.sess.graph.finalize()
self.sess.graph.finalize() # possibly already finalized by ChiefSessionCreator
logger.info("Graph Finalized.")
@call_only_once
......
......@@ -5,7 +5,7 @@ import os
import tensorflow as tf
import multiprocessing as mp
from ..callbacks import RunOp
from ..callbacks import RunOp, CallbackFactory
from ..tfutils.sesscreate import NewSessionCreator
from ..utils import logger
......@@ -379,15 +379,23 @@ class HorovodTrainer(SingleCostTrainer):
opt = get_opt_fn()
self.train_op = opt.apply_gradients(grads, name='min_op')
with tf.name_scope('horovod_broadcast'):
self._broadcast_op = hvd.broadcast_global_variables(0)
cb = RunOp(
self._broadcast_op, run_before=False,
run_as_trigger=True, verbose=True)
def broadcast(self):
logger.info("Running horovod broadcast ...")
# the op will be created later in initialize()
self.trainer._broadcast_op.run()
cb = CallbackFactory(trigger=broadcast)
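# note: a CallbackFactory trigger runs once per epoch by default,
# broadcasting from rank 0 again to keep all workers in sync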
return [cb]
@HIDE_DOC
def initialize(self, session_creator, session_init):
# broadcast_op should be the last setup_graph: it needs to be created
# "right before" the session is initialized,
# because it needs to capture all the variables (which may be created by callbacks).
with tf.name_scope('horovod_broadcast'):
self._broadcast_op = hvd.broadcast_global_variables(0)
if not isinstance(session_creator, NewSessionCreator):
raise ValueError(
"session_creator has to be `NewSessionCreator` for horovod training! ")
......