Commit cbd698ad authored by Yuxin Wu

get_model_loader->SmartRestore; Improve horovod integration

parent d71184c4
......@@ -377,6 +377,7 @@ _DEPRECATED_NAMES = set([
'InputDesc',
'inputs_desc',
'Augmentor',
"get_model_loader",
# renamed items that should not appear in docs
'DumpTensor',
......
......@@ -21,7 +21,7 @@ By writing callbacks to implement what to do at each place, tensorpack trainers
will call the callbacks at the proper time.
Therefore these features can be reused with one single line, as long as you are using tensorpack trainers.
For example, these are the callbacks I used when training a ResNet:
For example, here are some useful callbacks I used during model development:
```python
callbacks=[
......@@ -43,7 +43,7 @@ callbacks=[
-d type=note -d title="validation error" \\
-d body={val-error-top1} > /dev/null 2>&1',
'val-error-top1'),
# record GPU utilizations during training
# record GPU utilization during training
GPUUtilizationTracker(),
# touch a file to pause the training and start a debug shell, to observe what's going on
InjectShell(shell='ipython'),
......@@ -69,12 +69,12 @@ monitors=[ # monitors are a special kind of callbacks. these are also ena
]
```
You can see from the above snippet, that callbacks cover every detail of training, ranging from graph operations to the progress bar.
You can see from the above snippet that callbacks cover every detail of training, from graph operations to the progress bar.
This means you can customize every part of the training to your preference, e.g. display something
different in the progress bar, evaluate part of the summaries at a different frequency, etc.
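For instance, writing a small callback of your own only takes a few lines. Here is a rough sketch (the class and metric name below are made up, not part of tensorpack):
```python
from tensorpack.callbacks import Callback

class EpochCounter(Callback):
    """A toy callback: write one custom scalar to the monitors after every epoch."""
    def _trigger_epoch(self):
        # the same monitor backend used by the built-in callbacks above
        self.trainer.monitors.put_scalar('custom/epoch', self.epoch_num)
```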
Similar concepts also exist in other frameworks, such as Keras callbacks, or
`tf.train.SessionRunHook`. But tensorpack callbacks have more functionalities in
design, and can achive much more features, as you can see above.
design, and can achieve much more, as you can see above.
These features are not always necessary, but think about how messy the main loop would look if you
were to write this logic together with the loops, and how easy your life would be if you could enable
......
......@@ -20,10 +20,10 @@ demos how to print all variables and their shapes in a checkpoint.
Tensorpack includes another tool to save variables to a TF checkpoint, see
[save_chkpt_vars](../modules/tfutils.html#tensorpack.tfutils.varmanip.save_chkpt_vars).
## Work with npz Files in Model Zoo
## Work with .npz Files in the Model Zoo
Most models provided by tensorpack are in npz (dictionary) format,
because it's easy to manipulate without TF dependency.
because it's easy to use without TF dependency.
You can read/write them with `np.load` and `np.savez`.
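For example, a quick sketch of inspecting and trimming such a file (the file and variable names below are made up):
```python
import numpy as np

params = dict(np.load("ResNet50.npz"))       # hypothetical model-zoo file
for name, value in params.items():
    print(name, value.shape)
params.pop("linear/W", None)                 # e.g. drop the classifier weights before fine-tuning
np.savez("ResNet50-trimmed.npz", **params)
```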
[scripts/dump-model-params.py](../scripts/dump-model-params.py) can be used to remove unnecessary variables in a checkpoint
......@@ -34,24 +34,24 @@ It dumps the model to a `var-name: value` dict saved in npz format.
## Load a Model to a Session
Model loading (in both training and inference) is through the `session_init` interface.
For training, use `session_init` in `TrainConfig` or `Trainer.train()`.
For inference, use `session_init` in `PredictConfig`.
For training, use `session_init` in `TrainConfig(...)` or `Trainer.train(...)`.
For inference, use `session_init` in `PredictConfig(...)`.
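For example, a minimal inference-side sketch (the model class, file path and tensor names are placeholders):
```python
from tensorpack import PredictConfig, OfflinePredictor
from tensorpack.tfutils.sessinit import SmartRestore

pred_config = PredictConfig(
    model=MyModel(),                                   # your ModelDesc subclass (placeholder)
    session_init=SmartRestore("/path/to/model.npz"),   # any supported object, see below
    input_names=["input"],                             # placeholder tensor names
    output_names=["output"],
)
predictor = OfflinePredictor(pred_config)
```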
There are two ways a session can be initialized:
[session_init=SaverRestore(...)](../modules/tfutils.html#tensorpack.tfutils.sessinit.SaverRestore)
which restores a TF checkpoint,
or [session_init=DictRestore(...)](../modules/tfutils.html#tensorpack.tfutils.sessinit.DictRestore) which restores a dict.
`DictRestore` is the most general loader because you can make arbitrary changes
you need (e.g., remove variables, rename variables) to the dict.
To load multiple models, use [ChainInit](../modules/tfutils.html#tensorpack.tfutils.sessinit.ChainInit).
There are a few ways a session can be initialized:
```
session_init=SmartRestore("path/to/checkpoint") # load a TF checkpoint
session_init=SmartRestore("path/to/model_zoo.npz") # load tensorpack model zoo
session_init=SmartRestore(dict_of_parameters) # load a dictionary
session_init=SmartRestore(["path1", dict2]) # load them sequentially
```
To load an npz file from tensorpack model zoo to a session, you can use `DictRestore(dict(np.load(filename)))`.
You can also use
[get_model_loader(filename)](../modules/tfutils.html#tensorpack.tfutils.sessinit.get_model_loader),
a small helper which returns either a `SaverRestore` or a `DictRestore` based on the file name.
[SmartRestore](../modules/tfutils.html#tensorpack.tfutils.sessinit.SmartRestore)
is in fact a small helper which uses some heuristics to return one of
[SaverRestore](../modules/tfutils.html#tensorpack.tfutils.sessinit.SaverRestore) or
[DictRestore](../modules/tfutils.html#tensorpack.tfutils.sessinit.DictRestore).
They are responsible for the actual initialization work.
Whatever you use in `session_init`, this is what happen during the loading:
Whatever you use in `session_init`, this is what happens during the loading:
* Variable restoring is completely based on __exact name match__ between
variables in the current graph and variables in the `session_init` initializer.
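Because restoring is based on exact name match, a common trick is to rename keys in a dict before handing it to `session_init`. A rough sketch (the file and scope names are made up):
```python
import numpy as np
from tensorpack.tfutils.sessinit import SmartRestore

params = dict(np.load("pretrained.npz"))   # hypothetical file
# map names saved under an old scope to the names used by the current graph
params = {name.replace("old_tower/", "tower0/"): value for name, value in params.items()}
session_init = SmartRestore(params)
```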
......
# Trainers
Tensorpack follows the "define-and-run" paradigm.
TensorFlow & Tensorpack follow the "define-and-run" paradigm.
Therefore training consists of two steps:
1. __Define__: Build the graph for the model.
Users can call any TensorFlow functions to set up the graph.
Users may or may not use tensorpack `InputSource`, `ModelDesc` or other utilities to build the graph.
The goal of this step is to define "what to run" in later training steps,
and it can happen __either inside or outside__ tensorpack trainer.
The goal of this step is to define "what to run" in later training steps.
2. __Run__: Train the model (the [Trainer.train() method](/modules/train.html#tensorpack.train.Trainer.train)):
......@@ -26,7 +25,7 @@ by exploiting some universal patterns.
In research we do training of various kinds.
Tensorpack trainers avoid making assumptions on what type of training
you want to do. For example, unlike Keras, tensorpack does not wrongly assume that:
1. Your training is batched
1. Your training data is batched
2. Your training is gradient-based optimization
3. Your data has `X`(inputs) and `y`(outputs)
4. You want to evaluate on zero or one validation dataset
......@@ -48,7 +47,8 @@ Users or derived trainers should implement __what the iterations are__.
In fact, the steps per epoch can be any number
and it only affects the [schedule of callbacks](callback.html).
In other words, an "epoch" in tensorpack is the __default period to run
callbacks__ (validation, summary, checkpoint, etc.). It has nothing to do with your dataset.
callbacks__ (validation, summary, checkpoint, etc.).
So this assumption effectively imposes no extra constraints.
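For instance, here is a rough sketch where one "epoch" is simply defined as 5000 steps (`my_model` and `my_dataflow` are placeholders):
```python
from tensorpack import TrainConfig, ModelSaver

config = TrainConfig(
    model=my_model,            # placeholder ModelDesc
    dataflow=my_dataflow,      # placeholder DataFlow
    callbacks=[ModelSaver()],  # runs once per "epoch", i.e. every 5000 steps here
    steps_per_epoch=5000,      # any number; unrelated to the dataset size
    max_epoch=20,
)
```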
### Built-in Trainers
......
......@@ -42,7 +42,7 @@ After defining such a model, use it with `TrainConfig` and `launch_train_with_co
config = TrainConfig(
model=MyModel(),
dataflow=my_dataflow,
# data=my_inputsource, # alternatively, use a customized InputSource
# data=my_inputsource, # alternatively, use an InputSource
callbacks=[...], # some default callbacks are automatically applied
# some default monitors are automatically applied
steps_per_epoch=300, # default to the size of your InputSource/DataFlow
......
......@@ -284,6 +284,8 @@ def finalize_configs(is_training):
if _C.TRAINER == 'horovod':
import horovod.tensorflow as hvd
ngpu = hvd.size()
logger.info("Horovod Rank={}, Size={}, LocalRank={}".format(
hvd.rank(), hvd.size(), hvd.local_rank()))
else:
assert 'OMPI_COMM_WORLD_SIZE' not in os.environ
ngpu = get_num_gpu()
......
......@@ -45,18 +45,17 @@ if __name__ == '__main__':
register_coco(cfg.DATA.BASEDIR) # add COCO datasets to the registry
register_balloon(cfg.DATA.BASEDIR) # add the demo balloon datasets to the registry
# Setup logger ...
# Setup logging ...
is_horovod = cfg.TRAINER == 'horovod'
if is_horovod:
hvd.init()
logger.info("Horovod Rank={}, Size={}".format(hvd.rank(), hvd.size()))
if not is_horovod or hvd.rank() == 0:
logger.set_logger_dir(args.logdir, 'd')
logger.info("Environment Information:\n" + collect_env_info())
finalize_configs(is_training=True)
# Create model
MODEL = ResNetFPNModel() if cfg.MODE_FPN else ResNetC4Model()
# Compute the training schedule from the number of GPUs ...
stepnum = cfg.TRAIN.STEPS_PER_EPOCH
# warmup is step based, lr is epoch based
......@@ -77,9 +76,7 @@ if __name__ == '__main__':
total_passes = cfg.TRAIN.LR_SCHEDULE[-1] * 8 / train_dataflow.size()
logger.info("Total passes of the training set is: {:.5g}".format(total_passes))
# Create model and callbacks ...
MODEL = ResNetFPNModel() if cfg.MODE_FPN else ResNetC4Model()
# Create callbacks ...
callbacks = [
PeriodicCallback(
ModelSaver(max_to_keep=10, keep_checkpoint_every_n_hours=1),
......@@ -93,23 +90,22 @@ if __name__ == '__main__':
ThroughputTracker(samples_per_step=cfg.TRAIN.NUM_GPUS),
EstimatedTimeLeft(median=True),
SessionRunTimeout(60000), # 1 minute timeout
GPUUtilizationTracker()
]
if cfg.TRAIN.EVAL_PERIOD > 0:
callbacks.extend([
EvalCallback(dataset, *MODEL.get_inference_tensor_names(), args.logdir)
for dataset in cfg.DATA.VAL
])
if not is_horovod:
callbacks.append(GPUUtilizationTracker())
if is_horovod and hvd.rank() > 0:
session_init = None
else:
if args.load:
# ignore mismatched values, so you can `--load` a model for fine-tuning
session_init = get_model_loader(args.load, ignore_mismatch=True)
session_init = SmartRestore(args.load, ignore_mismatch=True)
else:
session_init = get_model_loader(cfg.BACKBONE.WEIGHTS) if cfg.BACKBONE.WEIGHTS else None
session_init = SmartRestore(cfg.BACKBONE.WEIGHTS)
traincfg = TrainConfig(
model=MODEL,
......@@ -120,6 +116,7 @@ if __name__ == '__main__':
session_init=session_init,
starting_epoch=cfg.TRAIN.STARTING_EPOCH
)
if is_horovod:
trainer = HorovodTrainer(average=False)
else:
......
......@@ -37,26 +37,41 @@ class GPUUtilizationTracker(Callback):
def __init__(self, devices=None):
"""
Args:
devices (list[int]): physical GPU ids. If None, will use CUDA_VISIBLE_DEVICES
devices (list[int]): physical GPU ids to monitor. If None, will guess from the environment.
"""
assert os.name != 'nt', "GPUUtilizationTracker does not support windows!"
if devices is None:
self._devices = devices
self._enabled = True
def _guess_devices(self):
env = os.environ.get('CUDA_VISIBLE_DEVICES')
if env is None:
self._devices = list(range(get_num_gpu()))
if len(self._devices) > 1:
devices = list(range(get_num_gpu()))
if len(devices) > 1:
logger.warn("[GPUUtilizationTracker] Both devices and CUDA_VISIBLE_DEVICES are None! "
"Will monitor all {} visible GPUs!".format(len(self._devices)))
"Will monitor all {} visible GPUs!".format(len(devices)))
else:
if len(env):
self._devices = list(map(int, env.split(',')))
devices = list(map(int, env.split(',')))
else:
self._devices = []
devices = []
return devices
def _setup_graph(self):
# special heuristics for Horovod
from ..train import HorovodTrainer
if isinstance(self.trainer, HorovodTrainer):
if self.trainer.mpi_enabled():
logger.warn("GPUUtilizationTracker is disabled under MPI.")
self._enabled = False
return
else:
self._devices = devices
self._devices = [self.trainer.hvd.local_rank()]
if self._devices is None:
self._devices = self._guess_devices()
assert len(self._devices), "[GPUUtilizationTracker] No GPU device given!"
def _setup_graph(self):
self._evt = mp.Event()
self._stop_evt = mp.Event()
self._queue = mp.Queue()
......@@ -69,9 +84,11 @@ class GPUUtilizationTracker(Callback):
assert gpu_available_in_session(), "[GPUUtilizationTracker] needs GPU!"
def _before_epoch(self):
if self._enabled:
self._evt.set()
def _after_epoch(self):
if self._enabled:
while self._evt.is_set(): # unlikely, unless the epoch is extremely fast
pass
self._evt.set()
......@@ -79,6 +96,8 @@ class GPUUtilizationTracker(Callback):
def _trigger_epoch(self):
# Don't do this in after_epoch because
# before,after_epoch are supposed to be extremely fast by design.
if not self._enabled:
return
try:
stats = self._queue.get(timeout=60)
except queue.Empty:
......@@ -94,6 +113,7 @@ class GPUUtilizationTracker(Callback):
self.trainer.monitors.put_scalar('GPUUtil/{}'.format(dev), stats[idx])
def _after_train(self):
if self._enabled:
self._stop_evt.set()
self._evt.set()
self._proc.terminate()
......
......@@ -221,7 +221,7 @@ def collect_env_info():
# Other important dependencies:
try:
import horovod
data.append(("horovod", horovod.__version__))
data.append(("Horovod", horovod.__version__))
except ImportError:
pass
......
......@@ -12,7 +12,7 @@ from .varmanip import SessionUpdate, get_checkpoint_path, get_savename_from_varn
__all__ = ['SessionInit', 'ChainInit',
'SaverRestore', 'SaverRestoreRelaxed', 'DictRestore',
'JustCurrentSession', 'get_model_loader']
'JustCurrentSession', 'get_model_loader', 'SmartRestore']
class SessionInit(object):
......@@ -260,32 +260,52 @@ class ChainInit(SessionInit):
i._run_init(sess)
def get_model_loader(filename, ignore_mismatch=False):
def SmartRestore(obj, ignore_mismatch=False):
"""
Get a corresponding model loader by looking at the file name.
Create a :class:`SessionInit` to be loaded into a session,
automatically from any supported object, with some smart heuristics.
The object can be:
+ A TF checkpoint
+ A dict of numpy arrays
+ An npz file
+ An empty string or None
+ A list of supported objects
Args:
filename (str): either a tensorflow checkpoint, or a npz file.
obj: a supported object
ignore_mismatch (bool): ignore failures when the value and the
variable do not match in their shapes.
If False, it will throw an exception on such errors.
If True, it will only print a warning.
Returns:
SessInit: either a :class:`DictRestore` (if name ends with 'npy/npz') or
:class:`SaverRestore` (otherwise).
SessionInit:
"""
assert isinstance(filename, six.string_types), filename
filename = os.path.expanduser(filename)
if not obj:
return JustCurrentSession()
if isinstance(obj, list):
return ChainInit([SmartRestore(x, ignore_mismatch=ignore_mismatch) for x in obj])
if isinstance(obj, six.string_types):
obj = os.path.expanduser(obj)
if obj.endswith(".npy") or obj.endswith(".npz"):
assert tf.gfile.Exists(obj), "File {} does not exist!".format(obj)
filename = obj
logger.info("Loading dictionary from {} ...".format(filename))
if filename.endswith('.npy'):
assert tf.gfile.Exists(filename), filename
return DictRestore(np.load(filename, encoding='latin1').item(), ignore_mismatch=ignore_mismatch)
obj = np.load(filename, encoding='latin1').item()
elif filename.endswith('.npz'):
assert tf.gfile.Exists(filename), filename
obj = np.load(filename)
return DictRestore(dict(obj), ignore_mismatch=ignore_mismatch)
else:
if ignore_mismatch:
return SaverRestoreRelaxed(filename)
obj = dict(np.load(filename))
elif len(tf.gfile.Glob(obj + "*")):
# Assume to be a TF checkpoint.
# A TF checkpoint must be a prefix of an actual file.
return (SaverRestoreRelaxed if ignore_mismatch else SaverRestore)(obj)
else:
return SaverRestore(filename)
raise ValueError("Invalid argument to SmartRestore: " + obj)
if isinstance(obj, dict):
return DictRestore(obj, ignore_mismatch=ignore_mismatch)
raise ValueError("Invalid argument to SmartRestore: " + type(obj))
get_model_loader = SmartRestore
......@@ -74,11 +74,11 @@ class QueueInputTrainer(SimpleTrainer):
class SyncMultiGPUTrainerParameterServer(SingleCostTrainer):
__doc__ = SyncMultiGPUParameterServerBuilder.__doc__
__doc__ = SyncMultiGPUParameterServerBuilder.__doc__ + """
Attributes:
devices (list[int]): List of GPU ids.
devices = None
"""
List of GPU ids.
"""
@map_arg(gpus=_int_to_range)
......@@ -117,11 +117,11 @@ def SyncMultiGPUTrainer(gpus):
class AsyncMultiGPUTrainer(SingleCostTrainer):
__doc__ = AsyncMultiGPUBuilder.__doc__
__doc__ = AsyncMultiGPUBuilder.__doc__ + """
Attributes:
devices (list[int]): List of GPU ids.
devices = None
"""
List of GPU ids.
"""
@map_arg(gpus=_int_to_range)
......@@ -146,15 +146,12 @@ class AsyncMultiGPUTrainer(SingleCostTrainer):
class SyncMultiGPUTrainerReplicated(SingleCostTrainer):
__doc__ = SyncMultiGPUReplicatedBuilder.__doc__
__doc__ = SyncMultiGPUReplicatedBuilder.__doc__ + """
devices = None
"""
List of GPU ids.
"""
Attributes:
devices (list[int]): List of GPU ids.
BROADCAST_EVERY_EPOCH = True
"""
BROADCAST_EVERY_EPOCH (bool):
Whether to broadcast the variables every epoch.
Theoretically this is a no-op (because the variables
are supposed to be in-sync).
......@@ -162,6 +159,8 @@ class SyncMultiGPUTrainerReplicated(SingleCostTrainer):
certain numerical issues in practice.
"""
BROADCAST_EVERY_EPOCH = True
@map_arg(gpus=_int_to_range)
def __init__(self, gpus, average=True, mode=None):
"""
......@@ -338,6 +337,10 @@ class HorovodTrainer(SingleCostTrainer):
# If using all GPUs, you can always skip the `CUDA_VISIBLE_DEVICES` option.
# There are other MPI options that can potentially improve performance, especially on special hardware.
Horovod can also be launched without MPI. See
`its documentation <https://github.com/horovod/horovod#running-horovod>`_
for more details.
Note:
1. To reach the maximum speed in your system, there are many options to tune
for Horovod installation and in the MPI command line.
......@@ -348,9 +351,10 @@ class HorovodTrainer(SingleCostTrainer):
must be avoided.
You can, however, use `tf.config.experimental.list_physical_devices('GPU')`, introduced in TF 1.14.
2. MPI does not like `fork()`. If your dataflow contains multiprocessing, it may cause problems.
3. Horovod supports both MPI and gloo. There are a few drawbacks of the MPI backend:
3. MPI sometimes fails to kill all processes in the end. Be sure to check it afterwards.
+ MPI does not like `fork()`. If your code (e.g. dataflow) contains multiprocessing, it may cause problems.
+ MPI sometimes fails to kill all processes in the end. Be sure to check it afterwards.
4. Keep in mind that there is one process running the script per GPU, therefore:
......@@ -364,7 +368,8 @@ class HorovodTrainer(SingleCostTrainer):
+ Callbacks have an option to be run only in the chief process, or in all processes.
See :meth:`Callback.set_chief_only()`. Most callbacks have a reasonable
default already, but certain callbacks may not behave properly by default. Report an issue if you find any.
default already, but certain callbacks may need your customization.
Report an issue if you find any bad defaults.
+ You can use Horovod APIs such as `hvd.rank()` to know which process you are and choose
a different code path. The chief process has rank 0.
......@@ -373,7 +378,18 @@ class HorovodTrainer(SingleCostTrainer):
`ResNet-Horovod <https://github.com/tensorpack/benchmarks/tree/master/ResNet-Horovod>`_
for a full example which has handled these common issues.
This example can train ImageNet in roughly an hour following the paper's setup.
Attributes:
BROADCAST_EVERY_EPOCH (bool):
Whether to broadcast the variables every epoch.
Theoretically this is a no-op (because the variables
are supposed to be in-sync).
But this cheap operation may help prevent
certain numerical issues in practice.
"""
BROADCAST_EVERY_EPOCH = True
def __init__(self, average=True, compression=None):
"""
Args:
......@@ -399,6 +415,16 @@ class HorovodTrainer(SingleCostTrainer):
logger.info("[HorovodTrainer] local rank={}".format(self._local_rank))
super(HorovodTrainer, self).__init__()
def mpi_enabled(self):
"""
Returns:
bool: whether hvd is currently running under MPI
"""
try:
return self.hvd.mpi_enabled()
except AttributeError:
return False
def allreduce(self, grads):
if self.hvd.size() == 1:
return grads
......@@ -424,7 +450,10 @@ class HorovodTrainer(SingleCostTrainer):
opt = get_opt_fn()
self.train_op = opt.apply_gradients(grads, name='train_op')
cb = CallbackFactory(before_train=self.broadcast, trigger=self.broadcast).set_chief_only(False)
cb = CallbackFactory(
before_train=self.broadcast,
trigger=self.broadcast if self.BROADCAST_EVERY_EPOCH else None
).set_chief_only(False)
return [cb]
def broadcast(self, _):
......@@ -502,3 +531,10 @@ class BytePSTrainer(HorovodTrainer):
self._has_compression = False
logger.info("[BytePSTrainer] local rank={}".format(self._local_rank))
SingleCostTrainer.__init__(self)
def mpi_enabled(self):
"""
Returns:
bool: whether hvd is currently running under MPI
"""
return False