Commit f257f0e0 authored by Yuxin Wu

fix horovod trainer broadcast stage again

parent 8fec1bfb
tensorpack.input_source package
================================

Read the relevant tutorials first for an overview of InputSource: :doc:`../tutorial/input-source`.

.. automodule:: tensorpack.input_source
    :members:

...
@@ -84,24 +84,27 @@ You just need the right interface to connect Python to the graph directly, efficiently.

## InputSource

`InputSource` is an abstract interface used by tensorpack trainers, to describe where the inputs come from and how they enter the graph.
Some choices are:

1. [FeedInput](../modules/input_source.html#tensorpack.input_source.FeedInput):
   Data come from a DataFlow and get fed to the graph (slow).
2. [QueueInput](../modules/input_source.html#tensorpack.input_source.QueueInput):
   Data come from a DataFlow and get buffered on CPU by a TF queue.
3. [StagingInput](../modules/input_source.html#tensorpack.input_source.StagingInput):
   Come from some other `InputSource`, then prefetched on GPU by a TF StagingArea.
4. [TFDatasetInput](http://tensorpack.readthedocs.io/en/latest/modules/input_source.html#tensorpack.input_source.TFDatasetInput):
   Come from a `tf.data.Dataset`.
5. [dataflow_to_dataset](http://tensorpack.readthedocs.io/en/latest/modules/input_source.html#tensorpack.input_source.TFDatasetInput.dataflow_to_dataset):
   Come from a DataFlow, and then further processed by utilities in `tf.data.Dataset`.
6. [TensorInput](../modules/input_source.html#tensorpack.input_source.TensorInput):
   Come from some tensors you define (which can be reading ops, for example).
7. [ZMQInput](http://tensorpack.readthedocs.io/en/latest/modules/input_source.html#tensorpack.input_source.ZMQInput):
   Come from some ZeroMQ pipe, where the reading/preprocessing may happen in a different process or even a different machine.

Typically, we recommend using `DataFlow + QueueInput` as it's good for most use cases.
If your data has to come from a separate process for whatever reason, use `ZMQInput`.
If you need to use TF reading ops directly, either define a `tf.data.Dataset`
and use `TFDatasetInput`, or use `TensorInput`.

Refer to the documentation of these `InputSource` classes for more details.
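As a quick illustration, here is a minimal sketch of the recommended combination (not part of the original tutorial; `FakeData` is only a stand-in for a real DataFlow):

```python
# Minimal sketch: DataFlow + QueueInput, optionally wrapped in StagingInput.
from tensorpack.dataflow import FakeData
from tensorpack.input_source import QueueInput, StagingInput

df = FakeData([[64, 224, 224, 3], [64]], size=1000)  # stand-in for your DataFlow
input_source = QueueInput(df)               # buffer data on CPU with a TF queue
input_source = StagingInput(input_source)   # optional: prefetch onto GPU
```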
@@ -3,14 +3,14 @@

This example is mainly to demonstrate:

1. How to train an RNN with persistent state between iterations.
   Here it simply manages the state inside the graph. `state_saving_rnn` can be used for more complicated use cases.
2. How to use a TF reader pipeline instead of a DataFlow, for both training & inference.

It trains a language model on the PTB dataset, basically an equivalent of the PTB example
in [tensorflow/models](https://github.com/tensorflow/models/tree/master/tutorials/rnn/ptb)
with its "medium" config. It has the same performance & speed as the original example.

Note that the data pipeline is completely copied from the tensorflow example.

To Train:

...
@@ -103,7 +103,7 @@ class ProgressBar(Callback):

class MaintainStepCounter(Callback):
    """
    It maintains the global step in the graph, making sure it's increased by one.
    This callback is used internally by the trainer; you don't need to worry about it.
    """
    _chief_only = False

...
@@ -96,7 +96,8 @@ class FeedInput(InputSource):

            infinite (bool): When set to False, will raise StopIteration when
                ds is exhausted.
        """
        if not isinstance(ds, DataFlow):
            raise ValueError("FeedInput takes a DataFlow! Got {}".format(ds))
        self.ds = ds
        if infinite:
            self._iter_ds = RepeatedData(self.ds, -1)
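For context, a minimal usage sketch of `FeedInput` (not part of this commit; `df` stands for an existing DataFlow):

```python
# Sketch: FeedInput feeds a DataFlow into the graph through feed_dict (slow).
from tensorpack.input_source import FeedInput

train_input = FeedInput(df, infinite=True)   # repeats df forever
eval_input = FeedInput(df, infinite=False)   # raises StopIteration when df is exhausted
```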
@@ -198,7 +199,8 @@ class QueueInput(FeedfreeInput):

                should match the corresponding InputDesc of the model.
                Defaults to a FIFO queue of size 50.
        """
        if not isinstance(ds, DataFlow):
            raise ValueError("QueueInput takes a DataFlow! Got {}".format(ds))
        self.queue = queue
        self.ds = ds
        self._inf_ds = RepeatedData(ds, -1)
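A hedged sketch of the `queue` argument described above (not part of this commit; the queue's dtypes must match the model's InputDesc):

```python
# Sketch: QueueInput defaults to a FIFOQueue of size 50, but accepts
# a custom TF queue, e.g. a RandomShuffleQueue for extra shuffling.
import tensorflow as tf
from tensorpack.input_source import QueueInput

q = tf.RandomShuffleQueue(
    capacity=100, min_after_dequeue=20,
    dtypes=[tf.float32, tf.int32])   # must match the model's InputDesc
input_source = QueueInput(df, queue=q)
```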
@@ -352,6 +354,8 @@ class TensorInput(FeedfreeInput):

                The returned tensors will be evaluated every iteration; it's your job to make sure that's possible.
            size(int): size of this input. Use None to leave it undefined.
        """
        if not callable(get_tensor_fn):
            raise ValueError("get_tensor_fn has to be a function! Got {}".format(get_tensor_fn))
        self.get_tensor_fn = get_tensor_fn
        if size is not None:
            size = int(size)
@@ -369,7 +373,9 @@ class TensorInput(FeedfreeInput):

    def _get_input_tensors(self):
        with self.cached_name_scope():
            ret = self.get_tensor_fn()
        assert isinstance(ret, (list, tuple)), "get_tensor_fn needs to return a list!"
        assert len(ret) == len(self._desc), \
            "get_tensor_fn returns {} tensors but there are {} inputs".format(len(ret), len(self._desc))
        return ret
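A hedged sketch of what a valid `get_tensor_fn` looks like under these checks (hypothetical reader function; not part of this commit):

```python
# Sketch: get_tensor_fn must be callable and return a list/tuple of tensors,
# one per input of the model.
import tensorflow as tf
from tensorpack.input_source import TensorInput

def my_tensors():   # hypothetical reader function
    image = tf.random_uniform([64, 28, 28, 1])
    label = tf.random_uniform([64], maxval=10, dtype=tf.int32)
    return [image, label]

input_source = TensorInput(my_tensors, size=100)   # size: optional input size
```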
@@ -436,7 +442,7 @@ class ZMQInput(TensorInput):

class TFDatasetInput(FeedfreeInput):
    """
    Use a :class:`tf.data.Dataset` instance as input.

    Note:
        In training, the dataset should be infinite (use :func:`repeat()`).

@@ -444,8 +450,10 @@ class TFDatasetInput(FeedfreeInput):

    def __init__(self, dataset):
        """
        Args:
            dataset (tf.data.Dataset):
        """
        if not isinstance(dataset, tf.data.Dataset):
            raise ValueError("TFDatasetInput takes a tf.data.Dataset! Got {}".format(dataset))
        self._dataset = dataset

    def _setup(self, inputs_desc):
@@ -474,7 +482,8 @@ class TFDatasetInput(FeedfreeInput):

    def _get_input_tensors(self):
        desc_shapes = [k.shape for k in self._desc]
        ret = self._iterator.get_next()
        assert len(ret) == len(desc_shapes), \
            "Dataset returns {} tensors but there are {} inputs!".format(len(ret), len(desc_shapes))
        for t, shp in zip(ret, desc_shapes):
            t.set_shape(shp)
        return ret
@@ -491,7 +500,7 @@ class TFDatasetInput(FeedfreeInput):

        Args:
            df (DataFlow): a dataflow which produces lists
            types([tf.DType]): list of types
        Returns:
            (tf.data.Dataset)
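For context, a hedged sketch combining the two entry points above (not part of this commit; `df` is a DataFlow producing `[image, label]` lists):

```python
# Sketch: turn a DataFlow into a tf.data.Dataset, then wrap it as an InputSource.
import tensorflow as tf
from tensorpack.input_source import TFDatasetInput

dataset = TFDatasetInput.dataflow_to_dataset(df, types=[tf.float32, tf.int32])
dataset = dataset.repeat()             # training requires an infinite dataset
input_source = TFDatasetInput(dataset)
```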
@@ -559,13 +568,14 @@ class StagingInput(FeedfreeInput):

        """
        Args:
            input (FeedfreeInput):
            nr_stage (int): number of elements to prefetch into each StagingArea, at the beginning.
                Since enqueue and dequeue are synchronized, prefetching 1 element should be sufficient.
            device (str or None): if not None, place the StagingArea on a specific device, e.g. '/cpu:0'.
                Otherwise, they are placed under where `get_inputs_tensors`
                gets called, which could be unspecified in case of simple trainers.
        """
        if not isinstance(input, FeedfreeInput):
            raise ValueError("StagingInput takes a FeedfreeInput! Got {}".format(input))
        self._input = input
        self._nr_stage = nr_stage

...
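A minimal sketch of how `StagingInput` is typically stacked (not part of this commit; `df` is an existing DataFlow):

```python
# Sketch: prefetch onto GPU by wrapping another FeedfreeInput;
# nr_stage=1 is normally enough since enqueue/dequeue are synchronized.
from tensorpack.input_source import QueueInput, StagingInput

input_source = StagingInput(QueueInput(df), nr_stage=1)
```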
@@ -70,7 +70,11 @@ def get_global_step_var():

def get_global_step_value():
    """
    Returns:
        int: global_step value in current graph and session

    Has to be called under a default session.
    """
    return tf.train.global_step(
        tf.get_default_session(),
        get_global_step_var())

...
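To illustrate the documented requirement, a hedged sketch (assuming these helpers live in `tensorpack.tfutils.common`, as in this repo):

```python
# Sketch: get_global_step_value() must run under a default session,
# and the global_step variable must exist and be initialized first.
import tensorflow as tf
from tensorpack.tfutils.common import get_global_step_var, get_global_step_value

get_global_step_var()                        # create global_step before initializing
sess = tf.Session()
sess.run(tf.global_variables_initializer())
with sess.as_default():                      # the default-session requirement
    print(get_global_step_value())           # -> 0
```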
@@ -214,7 +214,7 @@ class Trainer(object):

            if not isinstance(session_init, JustCurrentSession):
                logger.warn("This is not a chief worker, 'session_init' was ignored!")

        self.sess.graph.finalize()   # possibly already finalized by ChiefSessionCreator
        logger.info("Graph Finalized.")

    @call_only_once

...
@@ -5,7 +5,7 @@ import os

import tensorflow as tf
import multiprocessing as mp

from ..callbacks import RunOp, CallbackFactory
from ..tfutils.sesscreate import NewSessionCreator
from ..utils import logger
@@ -379,15 +379,23 @@ class HorovodTrainer(SingleCostTrainer):

        opt = get_opt_fn()
        self.train_op = opt.apply_gradients(grads, name='min_op')

        def broadcast(self):
            logger.info("Running horovod broadcast ...")
            # the op will be created later in initialize()
            self.trainer._broadcast_op.run()
        cb = CallbackFactory(trigger=broadcast)
        return [cb]

    @HIDE_DOC
    def initialize(self, session_creator, session_init):
        # broadcast_op should be the last step of setup_graph: it needs to be created
        # "right before" the session is initialized,
        # because it needs to capture all the variables (which may be created by callbacks).
        with tf.name_scope('horovod_broadcast'):
            self._broadcast_op = hvd.broadcast_global_variables(0)

        if not isinstance(session_creator, NewSessionCreator):
            raise ValueError(
                "session_creator has to be `NewSessionCreator` for horovod training! ")

...
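To summarize the ordering this commit enforces, a hedged sketch against horovod's public TF API (not code from this repo):

```python
# Sketch: hvd.broadcast_global_variables(0) only captures variables that
# already exist, so it must be created after the whole graph (including
# callback-created variables), and run once right after session init.
import horovod.tensorflow as hvd

hvd.init()
# 1. build model, optimizer, and callbacks (these may create variables)
# 2. create the broadcast op last so it covers every variable:
broadcast_op = hvd.broadcast_global_variables(0)
# 3. after creating and initializing the session:
#      sess.run(broadcast_op)
```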