Commit 25b31f68 authored by Yuxin Wu

clean-up trainv1

parent 0b2f3c11
# Faster-RCNN / Mask-RCNN on COCO
This example provides a minimal (<2k lines) and faithful implementation of the following papers:
+ [Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks](https://arxiv.org/abs/1506.01497)
+ [Feature Pyramid Networks for Object Detection](https://arxiv.org/abs/1612.03144)
+ [Mask R-CNN](https://arxiv.org/abs/1703.06870)
with the support of:
+ Multi-GPU / distributed training
+ [Cross-GPU BatchNorm](https://arxiv.org/abs/1711.07240)

## Dependencies
+ Python 3; TensorFlow >= 1.6 (1.4 or 1.5 can run but may crash due to a TF bug);
+ [pycocotools](https://github.com/pdollar/coco/tree/master/PythonAPI/pycocotools), OpenCV.
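A quick, optional way to confirm the environment matches the requirements above; this check is only an illustration (it assumes nothing beyond the packages listed in this section):

```python
# Optional sanity check for the dependencies listed above (illustrative only).
import tensorflow as tf
import cv2                          # OpenCV
from pycocotools.coco import COCO   # noqa, only to confirm pycocotools imports

# Mirror the version note above: 1.4 / 1.5 may run but can crash due to a TF bug.
assert [int(v) for v in tf.__version__.split('.')[:2]] >= [1, 6], \
    "TensorFlow >= 1.6 is expected; found {}".format(tf.__version__)
print("TensorFlow", tf.__version__, "| OpenCV", cv2.__version__)
```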
...
@@ -101,11 +101,6 @@ class Trainer(object):
     """
     def __init__(self):
-        """
-        config is only for compatibility reasons in case you're
-        using custom trainers with old-style API.
-        You should never use config.
-        """
         self._callbacks = []
         self.loop = TrainLoop()

@@ -310,22 +305,13 @@ class Trainer(object):
             session_creator, session_init,
             steps_per_epoch, starting_epoch, max_epoch)

-    # create the old trainer when called with TrainConfig
     def __new__(cls, *args, **kwargs):
         if (len(args) > 0 and isinstance(args[0], TrainConfig)) \
                 or 'config' in kwargs:
-            name = cls.__name__
-            try:
-                import tensorpack.trainv1 as old_train_mod   # noqa
-                old_trainer = getattr(old_train_mod, name)
-            except AttributeError:
-                # custom trainer. has to live with it
-                return super(Trainer, cls).__new__(cls)
-            else:
-                logger.warn("You're calling new trainers with old trainer API!")
-                logger.warn("Now it returns the old trainer for you, please switch to use new trainers soon!")
-                logger.warn("See https://github.com/tensorpack/tensorpack/issues/458 for more information.")
-                return old_trainer(*args, **kwargs)
+            logger.error("You're calling new trainers with old trainer API!")
+            logger.error("See https://github.com/tensorpack/tensorpack/issues/458 for more information.")
+            import sys
+            sys.exit(1)
         else:
             return super(Trainer, cls).__new__(cls)
...
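After this change, passing a TrainConfig to a new-style trainer no longer falls back to the old trainv1 implementation; it logs the error above and exits. A minimal sketch of the migration that issue 458 describes (MyModel and my_dataflow are placeholders for a ModelDesc and a DataFlow, not defined here):

```python
# Migration sketch; MyModel and my_dataflow are placeholders.
from tensorpack import TrainConfig, SimpleTrainer, launch_train_with_config

config = TrainConfig(model=MyModel(), dataflow=my_dataflow, max_epoch=10)

# Old-style call, now rejected by __new__ above:
#   SimpleTrainer(config).train()

# New-style call: the trainer is constructed without a config, and
# launch_train_with_config wires the config and the trainer together.
launch_train_with_config(config, SimpleTrainer())
```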
# -*- coding: utf-8 -*-
# File: __init__.py

from pkgutil import iter_modules
import os
import os.path

__all__ = []


def global_import(name):
    p = __import__(name, globals(), locals(), level=1)
    lst = p.__all__ if '__all__' in dir(p) else []
    del globals()[name]
    for k in lst:
        globals()[k] = p.__dict__[k]
        __all__.append(k)


_CURR_DIR = os.path.dirname(__file__)
_SKIP = ['utility']
for _, module_name, _ in iter_modules(
        [_CURR_DIR]):
    srcpath = os.path.join(_CURR_DIR, module_name + '.py')
    if not os.path.isfile(srcpath):
        continue
    if module_name.startswith('_'):
        continue
    if module_name not in _SKIP:
        global_import(module_name)
This diff is collapsed.
# -*- coding: utf-8 -*-
# File: config.py
__all__ = ['TrainConfig']
from ..train.config import TrainConfig
# -*- coding: utf-8 -*-
# File: distributed.py

import os

from ..utils import logger
from ..callbacks import RunOp
from ..tfutils.sesscreate import NewSessionCreator
from ..tfutils import get_global_step_var
from ..tfutils.distributed import get_distributed_session_creator
from ..graph_builder.distributed import DistributedReplicatedBuilder
from ..graph_builder.utils import override_to_local_variable
from .base import Trainer

__all__ = ['DistributedTrainerReplicated']


class DistributedTrainerReplicated(Trainer):

    __doc__ = DistributedReplicatedBuilder.__doc__

    def __init__(self, config, server):
        """
        Args:
            config(TrainConfig): Must contain 'model' and 'data'.
            server(tf.train.Server): the server object with ps and workers
        """
        assert config.data is not None and config.model is not None

        self.server = server
        self.job_name = server.server_def.job_name
        assert self.job_name in ['ps', 'worker'], self.job_name

        if self.job_name == 'worker':
            # ps doesn't build any graph
            self._builder = DistributedReplicatedBuilder(config.tower, server)
            self.is_chief = self._builder.is_chief
        else:
            self.is_chief = False
        logger.info("Distributed training on cluster:\n" + str(server.server_def.cluster))

        self._input_source = config.data
        super(DistributedTrainerReplicated, self).__init__(config)

    def _setup(self):
        if self.job_name == 'ps':
            logger.info("Running ps {}".format(self.server.server_def.task_index))
            logger.info("Kill me with 'kill {}'".format(os.getpid()))
            self.server.join()  # this will never return tensorflow#4713
            return

        with override_to_local_variable():
            get_global_step_var()  # gs should be local
            # input source may create variable (queue size summary)
            # TODO This is not good because we don't know from here
            # whether something should be global or local. We now assume
            # they should be local.
            cbs = self._input_source.setup(self.model.get_inputs_desc())
            self._config.callbacks.extend(cbs)

        self.train_op, initial_sync_op, model_sync_op = self._builder.build(
            lambda: self.model._build_graph_get_grads(
                *self._input_source.get_input_tensors()),
            self.model.get_optimizer)

        # initial local_vars syncing
        cb = RunOp(lambda: initial_sync_op,
                   run_before=True, run_as_trigger=False, verbose=True)
        cb.chief_only = False
        self.register_callback(cb)

        # model_variables syncing
        if model_sync_op:
            cb = RunOp(lambda: model_sync_op,
                       run_before=False, run_as_trigger=True, verbose=True)
            logger.warn("For efficiency, local MODEL_VARIABLES are only synced to PS once "
                        "every epoch. Be careful if you save the model more frequently than this.")
            self.register_callback(cb)

        self._set_session_creator()

    def _set_session_creator(self):
        old_sess_creator = self._config.session_creator
        if not isinstance(old_sess_creator, NewSessionCreator) \
                or old_sess_creator.user_provided_config:
            raise ValueError(
                "Cannot set session_creator or session_config for distributed training! "
                "To use a custom session config, pass it to tf.train.Server.")
        self._config.session_creator = get_distributed_session_creator(self.server)

    @property
    def _main_tower_vs_name(self):
        return "tower0"
# -*- coding: utf-8 -*-
# File: interface.py
__all__ = ['launch_train_with_config']
from ..train.interface import launch_train_with_config
# -*- coding: utf-8 -*-
# File: multigpu.py

import tensorflow as tf

from ..callbacks.graph import RunOp
from ..utils.develop import log_deprecated
from ..input_source import QueueInput, StagingInput, DummyConstantInput
from ..graph_builder.training import (
    SyncMultiGPUParameterServerBuilder,
    SyncMultiGPUReplicatedBuilder,
    AsyncMultiGPUBuilder,
    DataParallelBuilder)
from .base import Trainer

__all__ = ['MultiGPUTrainerBase',
           'SyncMultiGPUTrainerReplicated',
           'SyncMultiGPUTrainerParameterServer',
           'AsyncMultiGPUTrainer',
           'SyncMultiGPUTrainer']


class MultiGPUTrainerBase(Trainer):
    """
    For backward compatibility only
    """
    def build_on_multi_tower(towers, func, devices=None, use_vs=None):
        log_deprecated("MultiGPUTrainerBase.build_on_multitower",
                       "Please use DataParallelBuilder.build_on_towers", "2018-01-31")
        return DataParallelBuilder.build_on_towers(towers, func, devices, use_vs)


def apply_prefetch_policy(config, gpu_prefetch=True):
    assert (config.data is not None or config.dataflow is not None) and config.model is not None
    if config.data is None and config.dataflow is not None:
        # always use Queue prefetch
        config.data = QueueInput(config.dataflow)
        config.dataflow = None
    if len(config.tower) > 1 and gpu_prefetch:
        assert tf.test.is_gpu_available()

        # seem to only improve on >1 GPUs
        if not isinstance(config.data, (StagingInput, DummyConstantInput)):
            config.data = StagingInput(config.data)


class SyncMultiGPUTrainerParameterServer(Trainer):

    __doc__ = SyncMultiGPUParameterServerBuilder.__doc__

    def __init__(self, config, ps_device='gpu', gpu_prefetch=True):
        """
        Args:
            config(TrainConfig): Must contain 'model' and either one of 'data' or 'dataflow'.
            ps_device: either 'gpu' or 'cpu', where variables are stored. Setting to 'cpu' might help when #gpu>=4
            gpu_prefetch(bool): whether to prefetch the data to each GPU. Usually improve performance.
        """
        apply_prefetch_policy(config, gpu_prefetch)
        self._input_source = config.data
        assert ps_device in ['gpu', 'cpu'], ps_device
        self._ps_device = ps_device
        super(SyncMultiGPUTrainerParameterServer, self).__init__(config)

    def _setup(self):
        callbacks = self._input_source.setup(self.model.get_inputs_desc())

        self.train_op = SyncMultiGPUParameterServerBuilder(
            self._config.tower, self._ps_device).build(
                lambda: self.model._build_graph_get_grads(
                    *self._input_source.get_input_tensors()),
                self.model.get_optimizer)

        self._config.callbacks.extend(callbacks)


def SyncMultiGPUTrainer(config):
    """
    Alias for ``SyncMultiGPUTrainerParameterServer(config, ps_device='gpu')``,
    as this is the most commonly used synchronous multigpu trainer (but may
    not be more efficient than the other).
    """
    return SyncMultiGPUTrainerParameterServer(config, ps_device='gpu')


class SyncMultiGPUTrainerReplicated(Trainer):

    __doc__ = SyncMultiGPUReplicatedBuilder.__doc__

    def __init__(self, config, gpu_prefetch=True):
        """
        Args:
            config, gpu_prefetch: same as in :class:`SyncMultiGPUTrainerParameterServer`
        """
        apply_prefetch_policy(config, gpu_prefetch)
        self._input_source = config.data
        super(SyncMultiGPUTrainerReplicated, self).__init__(config)

    def _setup(self):
        callbacks = self._input_source.setup(self.model.get_inputs_desc())

        self.train_op, post_init_op = SyncMultiGPUReplicatedBuilder(
            self._config.tower).build(
                lambda: self.model._build_graph_get_grads(
                    *self._input_source.get_input_tensors()),
                self.model.get_optimizer)

        cb = RunOp(
            lambda: post_init_op,
            run_before=True, run_as_trigger=True, verbose=True)
        self._config.callbacks.extend(callbacks + [cb])


class AsyncMultiGPUTrainer(Trainer):

    __doc__ = AsyncMultiGPUBuilder.__doc__

    def __init__(self, config, scale_gradient=True):
        """
        Args:
            config(TrainConfig): Must contain 'model' and either one of 'data' or 'dataflow'.
            scale_gradient (bool): if True, will scale each gradient by ``1.0/nr_gpu``.
        """
        apply_prefetch_policy(config)
        self._input_source = config.data
        self._scale_gradient = scale_gradient
        super(AsyncMultiGPUTrainer, self).__init__(config)

    def _setup(self):
        callbacks = self._input_source.setup(self.model.get_inputs_desc())

        self.train_op = AsyncMultiGPUBuilder(
            self._config.tower, self._scale_gradient).build(
                lambda: self.model._build_graph_get_grads(
                    *self._input_source.get_input_tensors()),
                self.model.get_optimizer)

        self._config.callbacks.extend(callbacks)
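These old-style multi-GPU trainers all consume a TrainConfig directly. A brief sketch of how they were invoked; MyModel and my_dataflow are placeholders, and nr_tower is assumed to be the TrainConfig attribute that populates config.tower:

```python
# Legacy usage sketch; MyModel and my_dataflow are placeholders.
from tensorpack import TrainConfig

config = TrainConfig(model=MyModel(), dataflow=my_dataflow)
config.nr_tower = 2   # assumed: old TrainConfig exposes nr_tower to populate config.tower

# The common case (parameter-server mode with variables on GPU):
SyncMultiGPUTrainer(config).train()

# Other variants defined above:
#   SyncMultiGPUTrainerParameterServer(config, ps_device='cpu').train()
#   SyncMultiGPUTrainerReplicated(config).train()
#   AsyncMultiGPUTrainer(config, scale_gradient=True).train()
```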
# -*- coding: utf-8 -*-
# File: simple.py

from .base import Trainer
from ..tfutils.tower import TowerContext
from ..utils import logger
from ..input_source import FeedInput, QueueInput

__all__ = ['SimpleTrainer', 'QueueInputTrainer']


class SimpleTrainer(Trainer):
    """ A naive single-tower single-cost demo trainer.
        It simply builds one tower and minimize `model.cost`.
        It supports both InputSource and DataFlow.

        When DataFlow is given instead of InputSource, the InputSource to be
        used will be ``FeedInput(df)`` (no prefetch).
    """
    def __init__(self, config):
        """
        Args:
            config (TrainConfig): Must contain 'model' and either one of 'data' or 'dataflow'.
        """
        assert len(config.tower) == 1, \
            "Got nr_tower={}, but doesn't support multigpu!" \
            " Use Sync/AsyncMultiGPUTrainer instead.".format(len(config.tower))

        assert (config.data is not None or config.dataflow is not None) and config.model is not None
        if config.dataflow is None:
            self._input_source = config.data
        else:
            self._input_source = FeedInput(config.dataflow)
            logger.warn("FeedInput is slow (and this is the default of SimpleTrainer). "
                        "Consider QueueInput or other InputSource instead.")
        super(SimpleTrainer, self).__init__(config)

    def _setup(self):
        cbs = self._input_source.setup(self.model.get_inputs_desc())
        with TowerContext('', is_training=True):
            grads = self.model._build_graph_get_grads(
                *self._input_source.get_input_tensors())
        opt = self.model.get_optimizer()
        self.train_op = opt.apply_gradients(grads, name='min_op')
        self._config.callbacks.extend(cbs)


def QueueInputTrainer(config, input_queue=None):
    """
    A wrapper trainer which automatically wraps ``config.dataflow`` by a :class:`QueueInput`.
    It is an equivalent of ``SimpleTrainer(config)`` with ``config.data = QueueInput(dataflow)``.

    Args:
        config (TrainConfig): Must contain 'model' and 'dataflow'.
        input_queue (tf.QueueBase): an input queue. Defaults to the :class:`QueueInput` default.
    """
    assert (config.data is not None or config.dataflow is not None) and config.model is not None
    if config.data is not None:
        assert isinstance(config.data, QueueInput), config.data
    else:
        config.data = QueueInput(config.dataflow, input_queue)
        config.dataflow = None
    return SimpleTrainer(config)
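A short sketch of the two entry points above, as they were used with the old API; MyModel and my_dataflow are placeholders:

```python
# Legacy usage sketch; MyModel and my_dataflow are placeholders.
from tensorpack import TrainConfig

config = TrainConfig(model=MyModel(), dataflow=my_dataflow)

# Feeds batches through FeedInput (simple but slow, as the warning above notes):
SimpleTrainer(config).train()

# Same trainer, but with the dataflow wrapped in a QueueInput for prefetching:
#   QueueInputTrainer(TrainConfig(model=MyModel(), dataflow=my_dataflow)).train()
```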
# -*- coding: utf-8 -*-
# File: utility.py

# for backwards-compatibility
from ..graph_builder.utils import (  # noqa
    override_to_local_variable, LeastLoadedDeviceSetter)