Shashank Suhas / seminar-breakout · Commits

Commit cbd698ad authored Sep 01, 2019 by Yuxin Wu
get_model_loader->SmartRestore; Improve horovod integration
parent d71184c4
Showing 11 changed files with 180 additions and 104 deletions (+180, -104)
docs/conf.py                           +1   -0
docs/tutorial/callback.md              +4   -4
docs/tutorial/save-load.md             +17  -17
docs/tutorial/trainer.md               +5   -5
docs/tutorial/training-interface.md    +1   -1
examples/FasterRCNN/config.py          +2   -0
examples/FasterRCNN/train.py           +9   -12
tensorpack/callbacks/prof.py           +42  -22
tensorpack/tfutils/common.py           +1   -1
tensorpack/tfutils/sessinit.py         +39  -19
tensorpack/train/trainers.py           +59  -23
docs/conf.py

@@ -377,6 +377,7 @@ _DEPRECATED_NAMES = set([
     'InputDesc',
     'inputs_desc',
     'Augmentor',
+    "get_model_loader",

     # renamed items that should not appear in docs
     'DumpTensor',
docs/tutorial/callback.md

@@ -21,7 +21,7 @@
 By writing callbacks to implement what to do at each place, tensorpack trainers
 will call the callbacks at the proper time.
 Therefore these features can be reused with one single line, as long as you are using tensorpack trainers.

-For example, these are the callbacks I used when training a ResNet:
+For example, here are some useful callbacks I used during model development:

 ```python
 callbacks=[

@@ -43,7 +43,7 @@ callbacks=[
       -d type=note -d title="validation error" \\
       -d body={val-error-top1} > /dev/null 2>&1',
      'val-error-top1'),
-  # record GPU utilizations during training
+  # record GPU utilization during training
   GPUUtilizationTracker(),
   # touch a file to pause the training and start a debug shell, to observe what's going on
   InjectShell(shell='ipython'),

@@ -69,12 +69,12 @@ monitors=[  # monitors are a special kind of callbacks. these are also ena...
 ]
 ```

-You can see from the above snippet, that callbacks cover every detail of training, ranging
-from graph operations to the progress bar.
+You can see from the above snippet, that callbacks cover every detail of training, from graph operations to the progress bar.
 This means you can customize every part of the training to your preference, e.g. display something
 different in the progress bar, evaluate part of the summaries at a different frequency, etc.

 Similar concepts also exists in other frameworks, such as Keras callbacks, or
 `tf.train.SessionRunHook`. But tensorpack callbacks have more functionalities in
-design, and can achive much more features, as you can see above.
+design, and can achieve much more features, as you can see above.

 These features are not always necessary, but think about how messy the main loop would look like if you
 were to write these logic together with the loops, and how easy your life will be if you could enable
docs/tutorial/save-load.md

@@ -20,10 +20,10 @@ demos how to print all variables and their shapes in a checkpoint.
 Tensorpack includes another tool to save variables to TF checkpoint, see
 [save_chkpt_vars](../modules/tfutils.html#tensorpack.tfutils.varmanip.save_chkpt_vars).

-## Work with npz Files in Model Zoo
+## Work with .npz Files in the Model Zoo

 Most models provided by tensorpack are in npz (dictionary) format,
-because it's easy to manipulate without TF dependency.
+because it's easy to use without TF dependency.
 You can read/write them with `np.load` and `np.savez`.

 [scripts/dump-model-params.py](../scripts/dump-model-params.py) can be used to remove unnecessary variables in a checkpoint

@@ -34,24 +34,24 @@ It dumps the model to a `var-name: value` dict saved in npz format.
 ## Load a Model to a Session

 Model loading (in both training and inference) is through the `session_init` interface.
-For training, use `session_init` in `TrainConfig` or `Trainer.train()`.
-For inference, use `session_init` in `PredictConfig`.
+For training, use `session_init` in `TrainConfig(...)` or `Trainer.train(...)`.
+For inference, use `session_init` in `PredictConfig(...)`.

-There are two ways a session can be initialized:
-[session_init=SaverRestore(...)](../modules/tfutils.html#tensorpack.tfutils.sessinit.SaverRestore)
-which restores a TF checkpoint,
-or [session_init=DictRestore(...)](../modules/tfutils.html#tensorpack.tfutils.sessinit.DictRestore) which restores a dict.
-`DictRestore` is the most general loader because you can make arbitrary changes
-you need (e.g., remove variables, rename variables) to the dict.
-To load multiple models, use [ChainInit](../modules/tfutils.html#tensorpack.tfutils.sessinit.ChainInit).
-
-To load an npz file from tensorpack model zoo to a session, you can use `DictRestore(dict(np.load(filename)))`.
-You can also use
-[get_model_loader(filename)](../modules/tfutils.html#tensorpack.tfutils.sessinit.get_model_loader),
-a small helper which returns either a `SaverRestore` or a `DictRestore` based on the file name.
+There are a few ways a session can be initialized:
+```
+session_init=SmartRestore("path/to/checkpoint")      # load a TF checkpoint
+session_init=SmartRestore("path/to/model_zoo.npz")   # load tensorpack model zoo
+session_init=SmartRestore(dict_of_parameters)        # load a dictionary
+session_init=SmartRestore(["path1", dict2])          # load them sequentially
+```
+
+[SmartRestore](../modules/tfutils.html#tensorpack.tfutils.sessinit.SmartRestore)
+is in fact a small helper which uses some heuristics to return you one of
+[SaverRestore](../modules/tfutils.html#tensorpack.tfutils.sessinit.SaverRestore) or
+[DictRestore](../modules/tfutils.html#tensorpack.tfutils.sessinit.DictRestore).
+They are responsible for the actual initialization work.

-Whatever you use in `session_init`, this is what happen during the loading:
+Whatever you use in `session_init`, this is what happens during the loading:

 * Variable restoring is completely based on __exact name match__ between
   variables in the current graph and variables in the `session_init` initializer.
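To make the new interface concrete, here is a minimal usage sketch (illustrative only, not part of this commit; `MyModel`, `my_dataflow`, the paths and the tensor names are placeholders):

```python
from tensorpack import TrainConfig, PredictConfig
from tensorpack.tfutils.sessinit import SmartRestore

# Resume training from a TF checkpoint (placeholder path): checkpoint -> SaverRestore.
train_cfg = TrainConfig(
    model=MyModel(),                                    # placeholder ModelDesc
    dataflow=my_dataflow,                               # placeholder DataFlow
    session_init=SmartRestore("train_log/checkpoint"),
    callbacks=[],
)

# Run inference from a model-zoo npz file (placeholder path): npz -> DictRestore.
pred_cfg = PredictConfig(
    model=MyModel(),
    session_init=SmartRestore("ImageNet-ResNet50.npz"),
    input_names=["input"],
    output_names=["output"],
)
```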
docs/tutorial/trainer.md

 # Trainers

-Tensorpack follows the "define-and-run" paradigm.
+TensorFlow & Tensorpack follow the "define-and-run" paradigm.
 Therefore a training contains two steps:

 1. __Define__: Build graph for the model.
    Users can call whatever tensorflow functions to setup the graph.
    Users may or may not use tensorpack `InputSource`, `ModelDesc` or other utilities to build the graph.
-   The goal of this step is to define "what to run" in later training steps,
-   and it can happen __either inside or outside__ tensorpack trainer.
+   The goal of this step is to define "what to run" in later training steps.

 2. __Run__: Train the model (the [Trainer.train() method](/modules/train.html#tensorpack.train.Trainer.train)):

@@ -26,7 +25,7 @@ by exploiting some universal patterns.
 In research we do training of various kind.
 Tensorpack trainers avoid making assumptions on what type of training
 you want to do. For example, unlike Keras, tensorpack does not wrongly assume that:
-1. Your training is batched
+1. Your training data is batched
 2. Your training is gradient-based optimization
 3. Your data has `X` (inputs) and `y` (outputs)
 4. You want to evaluate on zero or one validation dataset

@@ -48,7 +47,8 @@ Users or derived trainers should implement __what the iterations are__.
 In fact, the steps per epoch can be any number
 and it only affects the [schedule of callbacks](callback.html).
 In other words, an "epoch" in tensorpack is the __default period to run
-callbacks__ (validation, summary, checkpoint, etc.). It has nothing to do with your dataset.
+callbacks__ (validation, summary, checkpoint, etc.).
+So this assumption effectively puts no extra constraints.

 ### Built-in Trainers
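To illustrate the point about epochs being only a callback period, a small sketch (not from this commit; `MyModel` and `my_dataflow` are placeholders):

```python
from tensorpack import TrainConfig, SimpleTrainer, launch_train_with_config

config = TrainConfig(
    model=MyModel(),        # placeholder ModelDesc
    dataflow=my_dataflow,   # placeholder DataFlow
    callbacks=[],           # callbacks (validation, summary, checkpoint, ...) run once per "epoch"
    steps_per_epoch=500,    # any number: it only sets how often callbacks run, not a dataset pass
    max_epoch=20,
)
launch_train_with_config(config, SimpleTrainer())
```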
docs/tutorial/training-interface.md

@@ -42,7 +42,7 @@ After defining such a model, use it with `TrainConfig` and `launch_train_with_config`:
 config = TrainConfig(
    model=MyModel()
    dataflow=my_dataflow,
-   # data=my_inputsource, # alternatively, use a customized InputSource
+   # data=my_inputsource, # alternatively, use an InputSource
    callbacks=[...],      # some default callbacks are automatically applied
                          # some default monitors are automatically applied
    steps_per_epoch=300,  # default to the size of your InputSource/DataFlow
examples/FasterRCNN/config.py

@@ -284,6 +284,8 @@ def finalize_configs(is_training):
     if _C.TRAINER == 'horovod':
         import horovod.tensorflow as hvd
         ngpu = hvd.size()
+        logger.info("Horovod Rank={}, Size={}, LocalRank={}".format(
+            hvd.rank(), hvd.size(), hvd.local_rank()))
     else:
         assert 'OMPI_COMM_WORLD_SIZE' not in os.environ
         ngpu = get_num_gpu()
examples/FasterRCNN/train.py

@@ -45,18 +45,17 @@ if __name__ == '__main__':
     register_coco(cfg.DATA.BASEDIR)  # add COCO datasets to the registry
     register_balloon(cfg.DATA.BASEDIR)  # add the demo balloon datasets to the registry

-    # Setup logger ...
+    # Setup logging ...
     is_horovod = cfg.TRAINER == 'horovod'
     if is_horovod:
         hvd.init()
-        logger.info("Horovod Rank={}, Size={}".format(hvd.rank(), hvd.size()))

     if not is_horovod or hvd.rank() == 0:
         logger.set_logger_dir(args.logdir, 'd')
     logger.info("Environment Information:\n" + collect_env_info())

     finalize_configs(is_training=True)

+    # Create model
+    MODEL = ResNetFPNModel() if cfg.MODE_FPN else ResNetC4Model()
+
     # Compute the training schedule from the number of GPUs ...
     stepnum = cfg.TRAIN.STEPS_PER_EPOCH
     # warmup is step based, lr is epoch based

@@ -77,9 +76,7 @@ if __name__ == '__main__':
     total_passes = cfg.TRAIN.LR_SCHEDULE[-1] * 8 / train_dataflow.size()
     logger.info("Total passes of the training set is: {:.5g}".format(total_passes))

-    # Create model and callbacks ...
-    MODEL = ResNetFPNModel() if cfg.MODE_FPN else ResNetC4Model()
-
+    # Create callbacks ...
     callbacks = [
         PeriodicCallback(
             ModelSaver(max_to_keep=10, keep_checkpoint_every_n_hours=1),

@@ -93,23 +90,22 @@ if __name__ == '__main__':
         ThroughputTracker(samples_per_step=cfg.TRAIN.NUM_GPUS),
         EstimatedTimeLeft(median=True),
         SessionRunTimeout(60000),   # 1 minute timeout
+        GPUUtilizationTracker()
     ]
     if cfg.TRAIN.EVAL_PERIOD > 0:
         callbacks.extend([
             EvalCallback(dataset, *MODEL.get_inference_tensor_names(), args.logdir)
             for dataset in cfg.DATA.VAL
         ])
-    if not is_horovod:
-        callbacks.append(GPUUtilizationTracker())

     if is_horovod and hvd.rank() > 0:
         session_init = None
     else:
         if args.load:
             # ignore mismatched values, so you can `--load` a model for fine-tuning
-            session_init = get_model_loader(args.load, ignore_mismatch=True)
+            session_init = SmartRestore(args.load, ignore_mismatch=True)
         else:
-            session_init = get_model_loader(cfg.BACKBONE.WEIGHTS) if cfg.BACKBONE.WEIGHTS else None
+            session_init = SmartRestore(cfg.BACKBONE.WEIGHTS)

     traincfg = TrainConfig(
         model=MODEL,

@@ -120,6 +116,7 @@ if __name__ == '__main__':
         session_init=session_init,
         starting_epoch=cfg.TRAIN.STARTING_EPOCH
     )

     if is_horovod:
         trainer = HorovodTrainer(average=False)
     else:
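The simplified `session_init` branch above relies on a behaviour of the new `SmartRestore` (see sessinit.py below): an empty `cfg.BACKBONE.WEIGHTS` now yields a no-op initializer, so the old `... if cfg.BACKBONE.WEIGHTS else None` guard is unnecessary. A tiny illustrative check:

```python
from tensorpack.tfutils.sessinit import SmartRestore, JustCurrentSession

# An empty string (e.g. BACKBONE.WEIGHTS left unset) maps to a no-op initializer.
assert isinstance(SmartRestore(""), JustCurrentSession)
```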
tensorpack/callbacks/prof.py

@@ -37,26 +37,41 @@ class GPUUtilizationTracker(Callback):
     def __init__(self, devices=None):
         """
         Args:
-            devices (list[int]): physical GPU ids. If None, will use CUDA_VISIBLE_DEVICES
+            devices (list[int]): physical GPU ids to monitor. If None, will guess from the environment.
         """
         assert os.name != 'nt', "GPUUtilizationTracker does not support windows!"
+        self._devices = devices
+        self._enabled = True

-        if devices is None:
-            env = os.environ.get('CUDA_VISIBLE_DEVICES')
-            if env is None:
-                self._devices = list(range(get_num_gpu()))
-                if len(self._devices) > 1:
-                    logger.warn("[GPUUtilizationTracker] Both devices and CUDA_VISIBLE_DEVICES are None! "
-                                "Will monitor all {} visible GPUs!".format(len(self._devices)))
-            else:
-                if len(env):
-                    self._devices = list(map(int, env.split(',')))
-                else:
-                    self._devices = []
-        else:
-            self._devices = devices
-        assert len(self._devices), "[GPUUtilizationTracker] No GPU device given!"
+    def _guess_devices(self):
+        env = os.environ.get('CUDA_VISIBLE_DEVICES')
+        if env is None:
+            devices = list(range(get_num_gpu()))
+            if len(devices) > 1:
+                logger.warn("[GPUUtilizationTracker] Both devices and CUDA_VISIBLE_DEVICES are None! "
+                            "Will monitor all {} visible GPUs!".format(len(devices)))
+        else:
+            if len(env):
+                devices = list(map(int, env.split(',')))
+            else:
+                devices = []
+        return devices

     def _setup_graph(self):
+        # special heuristics for Horovod
+        from ..train import HorovodTrainer
+        if isinstance(self.trainer, HorovodTrainer):
+            if self.trainer.mpi_enabled():
+                logger.warn("GPUUtilizationTracker is disabled under MPI.")
+                self._enabled = False
+                return
+            else:
+                self._devices = [self.trainer.hvd.local_rank()]
+
+        if self._devices is None:
+            self._devices = self._guess_devices()
+        assert len(self._devices), "[GPUUtilizationTracker] No GPU device given!"
+
         self._evt = mp.Event()
         self._stop_evt = mp.Event()
         self._queue = mp.Queue()

@@ -69,9 +84,11 @@ class GPUUtilizationTracker(Callback):
         assert gpu_available_in_session(), "[GPUUtilizationTracker] needs GPU!"

     def _before_epoch(self):
-        self._evt.set()
+        if self._enabled:
+            self._evt.set()

     def _after_epoch(self):
-        while self._evt.is_set():   # unlikely, unless the epoch is extremely fast
-            pass
-        self._evt.set()
+        if self._enabled:
+            while self._evt.is_set():   # unlikely, unless the epoch is extremely fast
+                pass
+            self._evt.set()

@@ -79,6 +96,8 @@ class GPUUtilizationTracker(Callback):
     def _trigger_epoch(self):
         # Don't do this in after_epoch because
         # before,after_epoch are supposed to be extremely fast by design.
+        if not self._enabled:
+            return
         try:
             stats = self._queue.get(timeout=60)
         except queue.Empty:

@@ -94,6 +113,7 @@ class GPUUtilizationTracker(Callback):
             self.trainer.monitors.put_scalar('GPUUtil/{}'.format(dev), stats[idx])

     def _after_train(self):
-        self._stop_evt.set()
-        self._evt.set()
-        self._proc.terminate()
+        if self._enabled:
+            self._stop_evt.set()
+            self._evt.set()
+            self._proc.terminate()
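For context, a short sketch of how the reworked callback is meant to be used after this change (illustrative only; the device ids are placeholders):

```python
from tensorpack.callbacks import GPUUtilizationTracker

# Monitor specific physical GPUs:
tracker = GPUUtilizationTracker(devices=[0, 1])

# With no argument, devices are now resolved lazily in _setup_graph():
# from hvd.local_rank() under a HorovodTrainer (disabled entirely under MPI),
# otherwise guessed from CUDA_VISIBLE_DEVICES / the number of visible GPUs.
auto_tracker = GPUUtilizationTracker()
```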
tensorpack/tfutils/common.py

@@ -221,7 +221,7 @@ def collect_env_info():
     # Other important dependencies:
     try:
         import horovod
-        data.append(("horovod", horovod.__version__))
+        data.append(("Horovod", horovod.__version__))
     except ImportError:
         pass
tensorpack/tfutils/sessinit.py

@@ -12,7 +12,7 @@ from .varmanip import SessionUpdate, get_checkpoint_path, get_savename_from_varname
 __all__ = ['SessionInit', 'ChainInit',
            'SaverRestore', 'SaverRestoreRelaxed', 'DictRestore',
-           'JustCurrentSession', 'get_model_loader']
+           'JustCurrentSession', 'get_model_loader', 'SmartRestore']


 class SessionInit(object):

@@ -260,32 +260,52 @@ class ChainInit(SessionInit):
         i._run_init(sess)


-def get_model_loader(filename, ignore_mismatch=False):
+def SmartRestore(obj, ignore_mismatch=False):
     """
-    Get a corresponding model loader by looking at the file name.
+    Create a :class:`SessionInit` to be loaded to a session,
+    automatically from any supported objects, with some smart heuristics.
+    The object can be:
+
+    + A TF checkpoint
+    + A dict of numpy arrays
+    + A npz file
+    + An empty string or None
+    + A list of supported objects

     Args:
-        filename (str): either a tensorflow checkpoint, or a npz file.
+        obj: a supported object
         ignore_mismatch (bool): ignore failures when the value and the
             variable does not match in their shapes.
             If False, it will throw exception on such errors.
             If True, it will only print a warning.

     Returns:
-        SessInit: either a :class:`DictRestore` (if name ends with 'npy/npz') or
-        :class:`SaverRestore` (otherwise).
+        SessionInit:
     """
-    assert isinstance(filename, six.string_types), filename
-    filename = os.path.expanduser(filename)
-    if filename.endswith('.npy'):
-        assert tf.gfile.Exists(filename), filename
-        return DictRestore(np.load(filename, encoding='latin1').item(), ignore_mismatch=ignore_mismatch)
-    elif filename.endswith('.npz'):
-        assert tf.gfile.Exists(filename), filename
-        obj = np.load(filename)
-        return DictRestore(dict(obj), ignore_mismatch=ignore_mismatch)
-    else:
-        if ignore_mismatch:
-            return SaverRestoreRelaxed(filename)
-        else:
-            return SaverRestore(filename)
+    if not obj:
+        return JustCurrentSession()
+    if isinstance(obj, list):
+        return ChainInit([SmartRestore(x, ignore_mismatch=ignore_mismatch) for x in obj])
+    if isinstance(obj, six.string_types):
+        obj = os.path.expanduser(obj)
+        if obj.endswith(".npy") or obj.endswith(".npz"):
+            assert tf.gfile.Exists(obj), "File {} does not exist!".format(obj)
+            filename = obj
+            logger.info("Loading dictionary from {} ...".format(filename))
+            if filename.endswith('.npy'):
+                obj = np.load(filename, encoding='latin1').item()
+            elif filename.endswith('.npz'):
+                obj = dict(np.load(filename))
+        elif len(tf.gfile.Glob(obj + "*")):
+            # Assume to be a TF checkpoint.
+            # A TF checkpoint must be a prefix of an actual file.
+            return (SaverRestoreRelaxed if ignore_mismatch else SaverRestore)(obj)
+        else:
+            raise ValueError("Invalid argument to SmartRestore: " + obj)
+
+    if isinstance(obj, dict):
+        return DictRestore(obj, ignore_mismatch=ignore_mismatch)
+    raise ValueError("Invalid argument to SmartRestore: " + type(obj))
+
+
+get_model_loader = SmartRestore
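A brief sketch of the dispatch behaviour implemented above (illustrative only; the array values are placeholders):

```python
import numpy as np
from tensorpack.tfutils.sessinit import (
    ChainInit, DictRestore, JustCurrentSession, SmartRestore, get_model_loader)

assert get_model_loader is SmartRestore          # the old name remains a usable alias

weights = {"conv0/W": np.zeros((3, 3, 3, 64), dtype="float32")}    # placeholder values
assert isinstance(SmartRestore(weights), DictRestore)              # dict  -> DictRestore
assert isinstance(SmartRestore([weights, weights]), ChainInit)     # list  -> ChainInit, loaded in order
assert isinstance(SmartRestore(None), JustCurrentSession)          # falsy -> no-op initializer
```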
tensorpack/train/trainers.py

@@ -74,11 +74,11 @@ class QueueInputTrainer(SimpleTrainer):
 class SyncMultiGPUTrainerParameterServer(SingleCostTrainer):

-    __doc__ = SyncMultiGPUParameterServerBuilder.__doc__
+    __doc__ = SyncMultiGPUParameterServerBuilder.__doc__ + """
+    Attributes:
+        devices (list[int]): List of GPU ids.
+    """

     devices = None
-    """
-    List of GPU ids.
-    """

     @map_arg(gpus=_int_to_range)

@@ -117,11 +117,11 @@ def SyncMultiGPUTrainer(gpus):
 class AsyncMultiGPUTrainer(SingleCostTrainer):

-    __doc__ = AsyncMultiGPUBuilder.__doc__
+    __doc__ = AsyncMultiGPUBuilder.__doc__ + """
+    Attributes:
+        devices (list[int]): List of GPU ids.
+    """

     devices = None
-    """
-    List of GPU ids.
-    """

     @map_arg(gpus=_int_to_range)

@@ -146,15 +146,12 @@ class AsyncMultiGPUTrainer(SingleCostTrainer):
 class SyncMultiGPUTrainerReplicated(SingleCostTrainer):

-    __doc__ = SyncMultiGPUReplicatedBuilder.__doc__
-
-    devices = None
-    """
-    List of GPU ids.
-    """
-
-    BROADCAST_EVERY_EPOCH = True
-    """
+    __doc__ = SyncMultiGPUReplicatedBuilder.__doc__ + """
+    Attributes:
+        devices (list[int]): List of GPU ids.
+
+        BROADCAST_EVERY_EPOCH (bool):
             Whether to broadcast the variables every epoch.
             Theoretically this is a no-op (because the variables
             are supposed to be in-sync).

@@ -162,6 +159,8 @@ class SyncMultiGPUTrainerReplicated(SingleCostTrainer):
             certain numerical issues in practice.
     """

+    BROADCAST_EVERY_EPOCH = True
+
     @map_arg(gpus=_int_to_range)
     def __init__(self, gpus, average=True, mode=None):
         """

@@ -338,6 +337,10 @@ class HorovodTrainer(SingleCostTrainer):
         # If using all GPUs, you can always skip the `CUDA_VISIBLE_DEVICES` option.
         # There are other MPI options that can potentially improve performance especially on special hardwares.

+    Horovod can also be launched without MPI. See
+    `its documentation <https://github.com/horovod/horovod#running-horovod>`_
+    for more details.
+
     Note:
         1. To reach the maximum speed in your system, there are many options to tune
            for Horovod installation and in the MPI command line.

@@ -348,9 +351,10 @@ class HorovodTrainer(SingleCostTrainer):
            must be avoided.
           You can, however, use `tf.config.experimental.list_physical_devices('GPU')`, introduced in TF 1.14.
-        2. MPI does not like `fork()`. If your dataflow contains multiprocessing, it may cause problems.
-        3. MPI sometimes fails to kill all processes in the end. Be sure to check it afterwards.
+        3. Horovod supports both MPI and gloo. There are a few drawbacks of the MPI backend:
+
+           + MPI does not like `fork()`. If your code (e.g. dataflow) contains multiprocessing, it may cause problems.
+           + MPI sometimes fails to kill all processes in the end. Be sure to check it afterwards.

         4. Keep in mind that there is one process running the script per GPU, therefore:

@@ -364,7 +368,8 @@ class HorovodTrainer(SingleCostTrainer):
        + Callbacks have an option to be run only in the chief process, or in all processes.
          See :meth:`Callback.set_chief_only()`. Most callbacks have a reasonable
-         default already, but certain callbacks may not behave properly by default. Report an issue if you find any.
+         default already, but certain callbacks may need your customization.
+         Report an issue if you find any bad defaults.
        + You can use Horovod API such as `hvd.rank()` to know which process you are and choose
          different code path. Chief process has rank 0.

@@ -373,7 +378,18 @@ class HorovodTrainer(SingleCostTrainer):
     `ResNet-Horovod <https://github.com/tensorpack/benchmarks/tree/master/ResNet-Horovod>`_
     for a full example which has handled these common issues.
     This example can train ImageNet in roughly an hour following the paper's setup.
+
+    Attributes:
+        BROADCAST_EVERY_EPOCH (bool):
+            Whether to broadcast the variables every epoch.
+            Theoretically this is a no-op (because the variables
+            are supposed to be in-sync).
+            But this cheap operation may help prevent
+            certain numerical issues in practice.
     """

+    BROADCAST_EVERY_EPOCH = True
+
     def __init__(self, average=True, compression=None):
         """
         Args:

@@ -399,6 +415,16 @@ class HorovodTrainer(SingleCostTrainer):
         logger.info("[HorovodTrainer] local rank={}".format(self._local_rank))
         super(HorovodTrainer, self).__init__()

+    def mpi_enabled(self):
+        """
+        Returns:
+            bool: whether hvd is currently running under MPI
+        """
+        try:
+            return self.hvd.mpi_enabled()
+        except AttributeError:
+            return False
+
     def allreduce(self, grads):
         if self.hvd.size() == 1:
             return grads

@@ -424,7 +450,10 @@ class HorovodTrainer(SingleCostTrainer):
         opt = get_opt_fn()
         self.train_op = opt.apply_gradients(grads, name='train_op')

-        cb = CallbackFactory(
-            before_train=self.broadcast, trigger=self.broadcast).set_chief_only(False)
+        cb = CallbackFactory(
+            before_train=self.broadcast,
+            trigger=self.broadcast if self.BROADCAST_EVERY_EPOCH else None
+        ).set_chief_only(False)
         return [cb]

     def broadcast(self, _):

@@ -502,3 +531,10 @@ class BytePSTrainer(HorovodTrainer):
         self._has_compression = False
         logger.info("[BytePSTrainer] local rank={}".format(self._local_rank))
         SingleCostTrainer.__init__(self)
+
+    def mpi_enabled(self):
+        """
+        Returns:
+            bool: whether hvd is currently running under MPI
+        """
+        return False
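For context, a minimal sketch of driving `HorovodTrainer`, following the class docstring referenced above (illustrative only; `MyModel`, `my_dataflow`, the script name and the process count are placeholders):

```python
# train_hvd.py (placeholder script name)
from tensorpack import TrainConfig, launch_train_with_config
from tensorpack.train import HorovodTrainer

config = TrainConfig(
    model=MyModel(),        # placeholder ModelDesc
    dataflow=my_dataflow,   # placeholder DataFlow
    callbacks=[],
)
launch_train_with_config(config, HorovodTrainer(average=True))
```

Such a script would typically be launched with `horovodrun -np 4 python train_hvd.py` (gloo or MPI) or an equivalent `mpirun` command line; the `mpi_enabled()` hook added in this commit lets callbacks such as `GPUUtilizationTracker` detect which launcher is in use.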