Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
f257f0e0
Commit
f257f0e0
authored
Aug 30, 2018
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
fix horovod trainer broadcast stage again
parent
8fec1bfb
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
54 additions
and
29 deletions
+54
-29
docs/modules/input_source.rst
docs/modules/input_source.rst
+1
-1
docs/tutorial/input-source.md
docs/tutorial/input-source.md
+11
-8
examples/PennTreebank/README.md
examples/PennTreebank/README.md
+2
-2
tensorpack/callbacks/steps.py
tensorpack/callbacks/steps.py
+1
-1
tensorpack/input_source/input_source.py
tensorpack/input_source/input_source.py
+19
-9
tensorpack/tfutils/common.py
tensorpack/tfutils/common.py
+5
-1
tensorpack/train/base.py
tensorpack/train/base.py
+1
-1
tensorpack/train/trainers.py
tensorpack/train/trainers.py
+14
-6
No files found.
docs/modules/input_source.rst
View file @
f257f0e0
tensorpack.input_source package
================================
Re
levant tutorials
: :doc:`../tutorial/input-source`.
Re
ad the relevant tutorials first for an overview of InputSource
: :doc:`../tutorial/input-source`.
.. automodule:: tensorpack.input_source
:members:
...
...
docs/tutorial/input-source.md
View file @
f257f0e0
...
...
@@ -84,24 +84,27 @@ You just need the right interface to connect Python to the graph directly, effic
## InputSource
`InputSource`
is an abstract interface
in tensorpack
, to describe where the inputs come from and how they enter the graph.
For example,
`InputSource`
is an abstract interface
used by tensorpack trainers
, to describe where the inputs come from and how they enter the graph.
Some choices are:
1.
[
FeedInput
](
../modules/input_source.html#tensorpack.input_source.FeedInput
)
:
C
ome from a DataFlow and get fed to the graph (slow).
Data c
ome from a DataFlow and get fed to the graph (slow).
2.
[
QueueInput
](
../modules/input_source.html#tensorpack.input_source.QueueInput
)
:
C
ome from a DataFlow and get buffered on CPU by a TF queue.
Data c
ome from a DataFlow and get buffered on CPU by a TF queue.
3.
[
StagingInput
](
../modules/input_source.html#tensorpack.input_source.StagingInput
)
:
Come from some
`InputSource`
, then prefetched on GPU by a TF StagingArea.
Come from some
other
`InputSource`
, then prefetched on GPU by a TF StagingArea.
4.
[
TFDatasetInput
](
http://tensorpack.readthedocs.io/en/latest/modules/input_source.html#tensorpack.input_source.TFDatasetInput
)
Come from a
`tf.data.Dataset`
.
5.
[
dataflow_to_dataset
](
http://tensorpack.readthedocs.io/en/latest/modules/input_source.html#tensorpack.input_source.TFDatasetInput.dataflow_to_dataset
)
Come from a DataFlow, and
further processed by
`tf.data.Dataset`
.
Come from a DataFlow, and
then lfurther processed by utilities in
`tf.data.Dataset`
.
6.
[
TensorInput
](
../modules/input_source.html#tensorpack.input_source.TensorInput
)
:
Come from some tensors you define (can be reading ops, for example).
7.
[
ZMQInput
](
http://tensorpack.readthedocs.io/en/latest/modules/input_source.html#tensorpack.input_source.ZMQInput
)
Come from some ZeroMQ pipe, where the reading/preprocessing may happen in a different process or even a different machine.
Typically, we recommend
`QueueInput + Staging
Input`
as it's good for most use cases.
Typically, we recommend
using
`DataFlow + Queue
Input`
as it's good for most use cases.
If your data has to come from a separate process for whatever reasons, use
`ZMQInput`
.
If you still like to use TF reading ops, define a
`tf.data.Dataset`
and use
`TFDatasetInput`
.
If you need to use TF reading ops directly, either define a
`tf.data.Dataset`
and use
`TFDatasetInput`
, or use
`TensorInput`
.
Refer to the documentation of these
`InputSource`
for more details.
examples/PennTreebank/README.md
View file @
f257f0e0
...
...
@@ -3,14 +3,14 @@
This example is mainly to demonstrate:
1.
How to train an RNN with persistent state between iterations.
Here it simply manages the state inside the graph.
`state_saving_rnn`
can be used for more complicated use case.
1.
How to train an RNN with persistent state between iterations. Here it simply manages the state inside the graph.
2.
How to use a TF reader pipeline instead of a DataFlow, for both training & inference.
It trains an language model on PTB dataset, basically an equivalent of the PTB example
in
[
tensorflow/models
](
https://github.com/tensorflow/models/tree/master/tutorials/rnn/ptb
)
with its "medium" config.
It has the same performance & speed as the original example as well.
Note that the data pipeline is completely copied from the tensorflow example.
To Train:
...
...
tensorpack/callbacks/steps.py
View file @
f257f0e0
...
...
@@ -103,7 +103,7 @@ class ProgressBar(Callback):
class
MaintainStepCounter
(
Callback
):
"""
It maintains the global step in the graph, making sure it's increased by one.
This callback is used by the trainer, you don't need to worry about it.
This callback is used
internally
by the trainer, you don't need to worry about it.
"""
_chief_only
=
False
...
...
tensorpack/input_source/input_source.py
View file @
f257f0e0
...
...
@@ -96,7 +96,8 @@ class FeedInput(InputSource):
infinite (bool): When set to False, will raise StopIteration when
ds is exhausted.
"""
assert
isinstance
(
ds
,
DataFlow
),
ds
if
not
isinstance
(
ds
,
DataFlow
):
raise
ValueError
(
"FeedInput takes a DataFlow! Got {}"
.
format
(
ds
))
self
.
ds
=
ds
if
infinite
:
self
.
_iter_ds
=
RepeatedData
(
self
.
ds
,
-
1
)
...
...
@@ -198,7 +199,8 @@ class QueueInput(FeedfreeInput):
should match the corresponding InputDesc of the model.
Defaults to a FIFO queue of size 50.
"""
assert
isinstance
(
ds
,
DataFlow
),
ds
if
not
isinstance
(
ds
,
DataFlow
):
raise
ValueError
(
"QueueInput takes a DataFlow! Got {}"
.
format
(
ds
))
self
.
queue
=
queue
self
.
ds
=
ds
self
.
_inf_ds
=
RepeatedData
(
ds
,
-
1
)
...
...
@@ -352,6 +354,8 @@ class TensorInput(FeedfreeInput):
The returned tensors will be evaluated every iteration, it's your job to make sure it's possible.
size(int): size of this input. Use None to leave it undefined.
"""
if
not
callable
(
get_tensor_fn
):
raise
ValueError
(
"get_tensor_fn has to be a function! Got {}"
.
format
(
get_tensor_fn
))
self
.
get_tensor_fn
=
get_tensor_fn
if
size
is
not
None
:
size
=
int
(
size
)
...
...
@@ -369,7 +373,9 @@ class TensorInput(FeedfreeInput):
def
_get_input_tensors
(
self
):
with
self
.
cached_name_scope
():
ret
=
self
.
get_tensor_fn
()
assert
len
(
ret
)
==
len
(
self
.
_desc
),
"{} != {}"
.
format
(
len
(
ret
),
len
(
self
.
_desc
))
assert
isinstance
(
ret
,
(
list
,
tuple
)),
"get_tensor_fn needs to return a list!"
assert
len
(
ret
)
==
len
(
self
.
_desc
),
\
"get_tensor_fn returns {} tensors but there are {} inputs"
.
format
(
len
(
ret
),
len
(
self
.
_desc
))
return
ret
...
...
@@ -436,7 +442,7 @@ class ZMQInput(TensorInput):
class
TFDatasetInput
(
FeedfreeInput
):
"""
Use a :class:`tf.
contrib.
data.Dataset` instance as input.
Use a :class:`tf.data.Dataset` instance as input.
Note:
In training, the dataset should be infinite (use :func:`repeat()`).
...
...
@@ -444,8 +450,10 @@ class TFDatasetInput(FeedfreeInput):
def
__init__
(
self
,
dataset
):
"""
Args:
dataset (tf.
contrib.
data.Dataset):
dataset (tf.data.Dataset):
"""
if
not
isinstance
(
dataset
,
tf
.
data
.
Dataset
):
raise
ValueError
(
"TFDatasetInput takes a tf.data.Dataset! Got {}"
.
format
(
dataset
))
self
.
_dataset
=
dataset
def
_setup
(
self
,
inputs_desc
):
...
...
@@ -474,7 +482,8 @@ class TFDatasetInput(FeedfreeInput):
def
_get_input_tensors
(
self
):
desc_shapes
=
[
k
.
shape
for
k
in
self
.
_desc
]
ret
=
self
.
_iterator
.
get_next
()
assert
len
(
ret
)
==
len
(
desc_shapes
)
assert
len
(
ret
)
==
len
(
desc_shapes
),
\
"Dataset returns {} tensors but there are {} inputs!"
.
format
(
len
(
ret
),
len
(
desc_shapes
))
for
t
,
shp
in
zip
(
ret
,
desc_shapes
):
t
.
set_shape
(
shp
)
return
ret
...
...
@@ -491,7 +500,7 @@ class TFDatasetInput(FeedfreeInput):
Args:
df (DataFlow): a dataflow which produces lists
types([tf.DType])
types([tf.DType])
: list of types
Returns:
(tf.data.Dataset)
...
...
@@ -559,13 +568,14 @@ class StagingInput(FeedfreeInput):
"""
Args:
input (FeedfreeInput):
nr_stage: number of elements to prefetch into each StagingArea, at the beginning.
nr_stage
(int)
: number of elements to prefetch into each StagingArea, at the beginning.
Since enqueue and dequeue are synchronized, prefetching 1 element should be sufficient.
device (str or None): if not None, place the StagingArea on a specific device. e.g., '/cpu:0'.
Otherwise, they are placed under where `get_inputs_tensors`
gets called, which could be unspecified in case of simple trainers.
"""
assert
isinstance
(
input
,
FeedfreeInput
),
input
if
not
isinstance
(
input
,
FeedfreeInput
):
raise
ValueError
(
"StagingInput takes a FeedfreeInput! Got {}"
.
format
(
input
))
self
.
_input
=
input
self
.
_nr_stage
=
nr_stage
...
...
tensorpack/tfutils/common.py
View file @
f257f0e0
...
...
@@ -70,7 +70,11 @@ def get_global_step_var():
def
get_global_step_value
():
"""
Returns:
int: global_step value in current graph and session"""
int: global_step value in current graph and session
Has to be called under a default session.
"""
return
tf
.
train
.
global_step
(
tf
.
get_default_session
(),
get_global_step_var
())
...
...
tensorpack/train/base.py
View file @
f257f0e0
...
...
@@ -214,7 +214,7 @@ class Trainer(object):
if
not
isinstance
(
session_init
,
JustCurrentSession
):
logger
.
warn
(
"This is not a chief worker, 'session_init' was ignored!"
)
self
.
sess
.
graph
.
finalize
()
self
.
sess
.
graph
.
finalize
()
# possibly already finalized by ChiefSessionCreator
logger
.
info
(
"Graph Finalized."
)
@
call_only_once
...
...
tensorpack/train/trainers.py
View file @
f257f0e0
...
...
@@ -5,7 +5,7 @@ import os
import
tensorflow
as
tf
import
multiprocessing
as
mp
from
..callbacks
import
RunOp
from
..callbacks
import
RunOp
,
CallbackFactory
from
..tfutils.sesscreate
import
NewSessionCreator
from
..utils
import
logger
...
...
@@ -379,15 +379,23 @@ class HorovodTrainer(SingleCostTrainer):
opt
=
get_opt_fn
()
self
.
train_op
=
opt
.
apply_gradients
(
grads
,
name
=
'min_op'
)
with
tf
.
name_scope
(
'horovod_broadcast'
):
self
.
_broadcast_op
=
hvd
.
broadcast_global_variables
(
0
)
cb
=
RunOp
(
self
.
_broadcast_op
,
run_before
=
False
,
run_as_trigger
=
True
,
verbose
=
True
)
def
broadcast
(
self
):
logger
.
info
(
"Running horovod broadcast ..."
)
# the op will be created later in initialize()
self
.
trainer
.
_broadcast_op
.
run
()
cb
=
CallbackFactory
(
trigger
=
broadcast
)
return
[
cb
]
@
HIDE_DOC
def
initialize
(
self
,
session_creator
,
session_init
):
# broadcast_op should be the last setup_graph: it needs to be created
# "right before" the session is initialized,
# because it needs to capture all the variables (which may be created by callbacks).
with
tf
.
name_scope
(
'horovod_broadcast'
):
self
.
_broadcast_op
=
hvd
.
broadcast_global_variables
(
0
)
if
not
isinstance
(
session_creator
,
NewSessionCreator
):
raise
ValueError
(
"session_creator has to be `NewSessionCreator` for horovod training! "
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment