Commit 25b31f68 authored by Yuxin Wu

clean-up trainv1

parent 0b2f3c11
 # Faster-RCNN / Mask-RCNN on COCO
-This example provides a minimal (only 1.6k lines) and faithful implementation of the following papers:
+This example provides a minimal (<2k lines) and faithful implementation of the following papers:
 + [Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks](https://arxiv.org/abs/1506.01497)
 + [Feature Pyramid Networks for Object Detection](https://arxiv.org/abs/1612.03144)
 + [Mask R-CNN](https://arxiv.org/abs/1703.06870)
+with the support of:
++ Multi-GPU / distributed training
++ [Cross-GPU BatchNorm](https://arxiv.org/abs/1711.07240)
 ## Dependencies
 + Python 3; TensorFlow >= 1.6 (1.4 or 1.5 can run but may crash due to a TF bug);
 + [pycocotools](https://github.com/pdollar/coco/tree/master/PythonAPI/pycocotools), OpenCV.
......
...@@ -101,11 +101,6 @@ class Trainer(object):
     """
     def __init__(self):
-        """
-        config is only for compatibility reasons in case you're
-        using custom trainers with old-style API.
-        You should never use config.
-        """
         self._callbacks = []
         self.loop = TrainLoop()
...@@ -310,22 +305,13 @@ class Trainer(object):
             session_creator, session_init,
             steps_per_epoch, starting_epoch, max_epoch)
-    # create the old trainer when called with TrainConfig
     def __new__(cls, *args, **kwargs):
         if (len(args) > 0 and isinstance(args[0], TrainConfig)) \
                 or 'config' in kwargs:
-            name = cls.__name__
-            try:
-                import tensorpack.trainv1 as old_train_mod    # noqa
-                old_trainer = getattr(old_train_mod, name)
-            except AttributeError:
-                # custom trainer. has to live with it
-                return super(Trainer, cls).__new__(cls)
-            else:
-                logger.warn("You're calling new trainers with old trainer API!")
-                logger.warn("Now it returns the old trainer for you, please switch to use new trainers soon!")
-                logger.warn("See https://github.com/tensorpack/tensorpack/issues/458 for more information.")
-                return old_trainer(*args, **kwargs)
+            logger.error("You're calling new trainers with old trainer API!")
+            logger.error("See https://github.com/tensorpack/tensorpack/issues/458 for more information.")
+            import sys
+            sys.exit(1)
         else:
             return super(Trainer, cls).__new__(cls)
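# Migration note (illustration only, not part of this diff): with this change the new trainers
# reject the old calling convention `SomeTrainer(TrainConfig(...)).train()`. A config-based
# script is expected to switch to something like
#   launch_train_with_config(TrainConfig(model=..., dataflow=...), SimpleTrainer())
# (exact arguments depend on the setup; see the issue linked above).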
......
# -*- coding: utf-8 -*-
# File: __init__.py
from pkgutil import iter_modules
import os
import os.path
__all__ = []
def global_import(name):
p = __import__(name, globals(), locals(), level=1)
lst = p.__all__ if '__all__' in dir(p) else []
del globals()[name]
for k in lst:
globals()[k] = p.__dict__[k]
__all__.append(k)
_CURR_DIR = os.path.dirname(__file__)
_SKIP = ['utility']
for _, module_name, _ in iter_modules(
[_CURR_DIR]):
srcpath = os.path.join(_CURR_DIR, module_name + '.py')
if not os.path.isfile(srcpath):
continue
if module_name.startswith('_'):
continue
if module_name not in _SKIP:
global_import(module_name)
# -*- coding: utf-8 -*-
# File: base.py
import time
import weakref
import six
from six.moves import range
import tensorflow as tf
from .config import TrainConfig
from ..utils import logger
from ..utils.develop import log_deprecated
from ..callbacks import Callback, Callbacks
from ..callbacks.monitor import Monitors, TrainingMonitor
from ..tfutils.model_utils import describe_trainable_vars
from ..tfutils.sesscreate import ReuseSessionCreator
from ..tfutils.sessinit import JustCurrentSession
from ..tfutils.tower import TowerFuncWrapper
from ..input_source import PlaceholderInput
from ..graph_builder.predict import SimplePredictBuilder
from ..predict.base import OnlinePredictor
from ..callbacks.steps import MaintainStepCounter
from ..train.base import StopTraining, TrainLoop
__all__ = ['Trainer', 'StopTraining']
class Trainer(object):
""" Base class for a trainer.
Attributes:
config (TrainConfig): the config used in this trainer.
model (ModelDesc): alias for ``config.model``.
sess (tf.Session): the current session in use.
hooked_sess (tf.train.MonitoredSession): the session with hooks.
monitors (Monitors): the monitors. Other callbacks can use it for logging.
"""
is_chief = True
"""
Whether this process is the chief worker in distributed training.
Only the chief worker will run some callbacks.
"""
def __init__(self, config):
"""
Args:
config (TrainConfig): the train config.
"""
assert isinstance(config, TrainConfig), type(config)
config._deprecated_parsing()
self._config = config
self.model = config.model
if self.model is not None:
def f(*inputs):
self.model.build_graph(*inputs)
"""
Only to mimic new trainer interface on inference.
"""
self.inputs_desc = self.model.get_inputs_desc()
self.tower_func = TowerFuncWrapper(f, self.inputs_desc)
self._callbacks = []
self._monitors = []
self.loop = TrainLoop()
self.loop.config(config.steps_per_epoch, config.starting_epoch, config.max_epoch)
self._setup() # subclass will setup the graph and InputSource
def register_callback(self, cb):
"""
Register a callback to the trainer.
It can only be called before :meth:`Trainer.train` gets called.
"""
assert isinstance(cb, Callback), cb
assert not isinstance(self._callbacks, Callbacks), \
"Cannot register more callbacks after trainer was setup!"
if not self.is_chief and cb.chief_only:
logger.warn("Callback {} is chief-only, skipped.".format(str(cb)))
else:
self._callbacks.append(cb)
def register_monitor(self, mon):
"""
Register a monitor to the trainer.
It can only be called before :meth:`Trainer.train` gets called.
"""
assert isinstance(mon, TrainingMonitor), mon
assert not isinstance(self._monitors, Monitors), \
"Cannot register more monitors after trainer was setup!"
if not self.is_chief and mon.chief_only:
logger.warn("Monitor {} is chief-only, skipped.".format(str(mon)))
else:
self._monitors.append(mon)
self.register_callback(mon)
@property
def monitors(self):
assert isinstance(self._monitors, Monitors), "Monitors haven't been setup!"
return self._monitors
def train(self):
""" Start training """
self.setup()
self.main_loop()
def run_step(self):
"""
Defines what to do in one iteration. The default is:
``self.hooked_sess.run(self.train_op)``.
The behavior can be changed by either defining what is ``train_op``,
or overriding this method.
"""
if not hasattr(self, 'train_op'):
raise NotImplementedError(
"Please either set `Trainer.train_op` or provide an implementation "
"of Trainer.run_step()!")
self.hooked_sess.run(self.train_op)
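# Illustration (not part of the original file): per the docstring above, a subclass either
# builds a `train_op` attribute (e.g. inside `_setup()`), so this default run_step() runs it,
# or overrides run_step() itself, e.g. to also run a hypothetical extra op each iteration:
#   def run_step(self):
#       self.hooked_sess.run([self.train_op, self.extra_summary_op])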
def setup(self):
"""
Setup the trainer and be ready for the main loop.
"""
self.register_callback(MaintainStepCounter())
for cb in self._config.callbacks:
self.register_callback(cb)
for m in self._config.monitors:
self.register_monitor(m)
self._monitors = Monitors(self._monitors)
self.register_callback(self._monitors)
describe_trainable_vars()
# some final operations that might modify the graph
logger.info("Setup callbacks graph ...")
self._callbacks = Callbacks(self._callbacks)
self._callbacks.setup_graph(weakref.proxy(self))
self._config.session_init._setup_graph()
logger.info("Creating the session ...")
self._create_session()
if self.is_chief:
logger.info("Initializing the session ...")
self._config.session_init._run_init(self.sess)
else:
if not isinstance(self._config.session_init, JustCurrentSession):
logger.warn("This is not a chief worker, 'session_init' was ignored!")
self.sess.graph.finalize()
logger.info("Graph Finalized.")
def _create_session(self):
"""
Setup self.sess (the raw tf.Session)
and self.hooked_sess (the session with hooks and coordinator)
"""
hooks = self._callbacks.get_hooks()
self.sess = self._config.session_creator.create_session()
self.hooked_sess = tf.train.MonitoredSession(
session_creator=ReuseSessionCreator(self.sess), hooks=hooks)
def _setup(self):
"""
Build the entire graph for training.
Also responsible for setting up the InputSource (including registering its callbacks).
Since this method is only called in the constructor,
you can simply leave it empty and build your graph outside the trainer.
"""
pass
def main_loop(self):
"""
Run the main training loop.
"""
with self.sess.as_default():
self.loop.update_global_step()
try:
self._callbacks.before_train()
# refresh global step (might have been changed by callbacks) TODO ugly
self.loop.update_global_step()
for self.loop._epoch_num in range(
self.loop.starting_epoch, self.loop.max_epoch + 1):
logger.info("Start Epoch {} ...".format(self.loop.epoch_num))
start_time = time.time()
self._callbacks.before_epoch()
for self.loop._local_step in range(self.loop.steps_per_epoch):
if self.hooked_sess.should_stop():
return
self.run_step() # implemented by subclass
self._callbacks.trigger_step()
self._callbacks.after_epoch()
logger.info("Epoch {} (global_step {}) finished, time:{:.2f} sec.".format(
self.loop.epoch_num, self.loop.global_step, time.time() - start_time))
# trigger epoch outside the timing region.
self._callbacks.trigger_epoch()
logger.info("Training has finished!")
except (StopTraining, tf.errors.OutOfRangeError):
logger.info("Training was stopped.")
except KeyboardInterrupt:
logger.info("Detected Ctrl-C and exiting main loop.")
raise
finally:
self._callbacks.after_train()
self.hooked_sess.close()
def get_predictor(self, input_names, output_names, tower=0):
"""
Returns a callable predictor built under ``TowerContext(is_training=False)``.
Args:
input_names (list), output_names(list): list of names
tower (int): build the predictor on device '/gpu:{tower}' or use -1 for '/cpu:0'.
Returns:
an :class:`OnlinePredictor`.
"""
device = tower
assert self.tower_func is not None, "Must set tower_func on the trainer to use get_predictor()!"
tower_name = 'tower-pred-{}'.format(device) if device >= 0 else 'tower-pred-cpu'
try:
tower = self.tower_func.towers[tower_name]
except KeyError:
input = PlaceholderInput()
input.setup(self.inputs_desc)
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
SimplePredictBuilder(
ns_name=tower_name, vs_name=self._main_tower_vs_name,
device=device).build(input, self.tower_func)
tower = self.tower_func.towers[tower_name]
input_tensors = tower.get_tensors(input_names)
output_tensors = tower.get_tensors(output_names)
return OnlinePredictor(input_tensors, output_tensors)
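# Usage sketch (illustration only; 'image' and 'prob' are hypothetical tensor names that
# depend on the user's model):
#   pred = trainer.get_predictor(['image'], ['prob'], tower=0)
#   prob = pred(batch_of_images)[0]   # OnlinePredictor returns a list of output arrays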
@property
def _main_tower_vs_name(self):
# The vs name a predictor should be built under.
# for internal use only. Should let graphbuilder return it.
return ""
@property
def config(self):
log_deprecated(
"Trainer.config",
"It is supposed to be private! Most of its attributes can be accessed by other means.",
"2017-12-31")
return self._config
# create new trainer when not called with TrainConfig
def __new__(cls, *args, **kwargs):
if (len(args) > 0 and isinstance(args[0], TrainConfig)) \
or 'config' in kwargs:
return super(Trainer, cls).__new__(cls)
else:
import tensorpack.train as new_train
name = cls.__name__
new_trainer = getattr(new_train, name)
logger.warn("You're calling old trainers with new trainer API!")
logger.warn("Now it returns the new trainer for you, please `export TENSORPACK_TRAIN_API=v2`"
" to import new trainers automatically.")
logger.warn("You can also ignore this warning and wait for new API to become the default.")
return new_trainer(*args, **kwargs)
def _get_property(name):
"""
Delegate property to self.loop
"""
ret = property(
lambda self: getattr(self.loop, name))
if six.PY3: # __doc__ is readonly in Py2
try:
ret.__doc__ = getattr(TrainLoop, name).__doc__
except AttributeError:
pass
return ret
for name in ['global_step', 'local_step', 'steps_per_epoch',
'epoch_num', 'starting_epoch', 'max_epoch']:
setattr(Trainer, name, _get_property(name))
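# Effect of the delegation above (illustration only): e.g. `trainer.epoch_num` now reads
# `trainer.loop.epoch_num`, and likewise for the other names in the list.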
# -*- coding: utf-8 -*-
# File: config.py
__all__ = ['TrainConfig']
from ..train.config import TrainConfig
# -*- coding: utf-8 -*-
# File: distributed.py
import os
from ..utils import logger
from ..callbacks import RunOp
from ..tfutils.sesscreate import NewSessionCreator
from ..tfutils import get_global_step_var
from ..tfutils.distributed import get_distributed_session_creator
from ..graph_builder.distributed import DistributedReplicatedBuilder
from ..graph_builder.utils import override_to_local_variable
from .base import Trainer
__all__ = ['DistributedTrainerReplicated']
class DistributedTrainerReplicated(Trainer):
__doc__ = DistributedReplicatedBuilder.__doc__
def __init__(self, config, server):
"""
Args:
config(TrainConfig): Must contain 'model' and 'data'.
server(tf.train.Server): the server object with ps and workers
"""
assert config.data is not None and config.model is not None
self.server = server
self.job_name = server.server_def.job_name
assert self.job_name in ['ps', 'worker'], self.job_name
if self.job_name == 'worker':
# ps doesn't build any graph
self._builder = DistributedReplicatedBuilder(config.tower, server)
self.is_chief = self._builder.is_chief
else:
self.is_chief = False
logger.info("Distributed training on cluster:\n" + str(server.server_def.cluster))
self._input_source = config.data
super(DistributedTrainerReplicated, self).__init__(config)
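# Usage sketch (illustration only; the cluster spec below is hypothetical): the `server`
# argument documented above is an ordinary tf.train.Server built from a cluster spec, e.g.
#   cluster = tf.train.ClusterSpec({'ps': ['host0:2222'], 'worker': ['host1:2222', 'host2:2222']})
#   server = tf.train.Server(cluster, job_name='worker', task_index=0)
#   DistributedTrainerReplicated(config, server).train()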
def _setup(self):
if self.job_name == 'ps':
logger.info("Running ps {}".format(self.server.server_def.task_index))
logger.info("Kill me with 'kill {}'".format(os.getpid()))
self.server.join() # this will never return tensorflow#4713
return
with override_to_local_variable():
get_global_step_var() # gs should be local
# input source may create variable (queue size summary)
# TODO This is not good because we don't know from here
# whether something should be global or local. We now assume
# they should be local.
cbs = self._input_source.setup(self.model.get_inputs_desc())
self._config.callbacks.extend(cbs)
self.train_op, initial_sync_op, model_sync_op = self._builder.build(
lambda: self.model._build_graph_get_grads(
*self._input_source.get_input_tensors()),
self.model.get_optimizer)
# initial local_vars syncing
cb = RunOp(lambda: initial_sync_op,
run_before=True, run_as_trigger=False, verbose=True)
cb.chief_only = False
self.register_callback(cb)
# model_variables syncing
if model_sync_op:
cb = RunOp(lambda: model_sync_op,
run_before=False, run_as_trigger=True, verbose=True)
logger.warn("For efficiency, local MODEL_VARIABLES are only synced to PS once "
"every epoch. Be careful if you save the model more frequently than this.")
self.register_callback(cb)
self._set_session_creator()
def _set_session_creator(self):
old_sess_creator = self._config.session_creator
if not isinstance(old_sess_creator, NewSessionCreator) \
or old_sess_creator.user_provided_config:
raise ValueError(
"Cannot set session_creator or session_config for distributed training! "
"To use a custom session config, pass it to tf.train.Server.")
self._config.session_creator = get_distributed_session_creator(self.server)
@property
def _main_tower_vs_name(self):
return "tower0"
# -*- coding: utf-8 -*-
# File: interface.py
__all__ = ['launch_train_with_config']
from ..train.interface import launch_train_with_config
# -*- coding: utf-8 -*-
# File: multigpu.py
import tensorflow as tf
from ..callbacks.graph import RunOp
from ..utils.develop import log_deprecated
from ..input_source import QueueInput, StagingInput, DummyConstantInput
from ..graph_builder.training import (
SyncMultiGPUParameterServerBuilder,
SyncMultiGPUReplicatedBuilder,
AsyncMultiGPUBuilder,
DataParallelBuilder)
from .base import Trainer
__all__ = ['MultiGPUTrainerBase',
'SyncMultiGPUTrainerReplicated',
'SyncMultiGPUTrainerParameterServer',
'AsyncMultiGPUTrainer',
'SyncMultiGPUTrainer']
class MultiGPUTrainerBase(Trainer):
"""
For backward compatibility only
"""
def build_on_multi_tower(towers, func, devices=None, use_vs=None):
log_deprecated("MultiGPUTrainerBase.build_on_multitower",
"Please use DataParallelBuilder.build_on_towers", "2018-01-31")
return DataParallelBuilder.build_on_towers(towers, func, devices, use_vs)
def apply_prefetch_policy(config, gpu_prefetch=True):
assert (config.data is not None or config.dataflow is not None) and config.model is not None
if config.data is None and config.dataflow is not None:
# always use Queue prefetch
config.data = QueueInput(config.dataflow)
config.dataflow = None
if len(config.tower) > 1 and gpu_prefetch:
assert tf.test.is_gpu_available()
# seem to only improve on >1 GPUs
if not isinstance(config.data, (StagingInput, DummyConstantInput)):
config.data = StagingInput(config.data)
class SyncMultiGPUTrainerParameterServer(Trainer):
__doc__ = SyncMultiGPUParameterServerBuilder.__doc__
def __init__(self, config, ps_device='gpu', gpu_prefetch=True):
"""
Args:
config(TrainConfig): Must contain 'model' and either one of 'data' or 'dataflow'.
ps_device: either 'gpu' or 'cpu', where variables are stored. Setting to 'cpu' might help when #gpu>=4
gpu_prefetch (bool): whether to prefetch the data to each GPU. Usually improves performance.
"""
apply_prefetch_policy(config, gpu_prefetch)
self._input_source = config.data
assert ps_device in ['gpu', 'cpu'], ps_device
self._ps_device = ps_device
super(SyncMultiGPUTrainerParameterServer, self).__init__(config)
def _setup(self):
callbacks = self._input_source.setup(self.model.get_inputs_desc())
self.train_op = SyncMultiGPUParameterServerBuilder(
self._config.tower, self._ps_device).build(
lambda: self.model._build_graph_get_grads(
*self._input_source.get_input_tensors()),
self.model.get_optimizer)
self._config.callbacks.extend(callbacks)
def SyncMultiGPUTrainer(config):
"""
Alias for ``SyncMultiGPUTrainerParameterServer(config, ps_device='gpu')``,
as this is the most commonly used synchronous multi-GPU trainer (though not
necessarily the most efficient one).
"""
return SyncMultiGPUTrainerParameterServer(config, ps_device='gpu')
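# Usage sketch (illustration only; `MyModel` and `my_dataflow` are hypothetical):
#   config = TrainConfig(model=MyModel(), dataflow=my_dataflow, ...)
#   SyncMultiGPUTrainer(config).train()
# which is the same as SyncMultiGPUTrainerParameterServer(config, ps_device='gpu').train().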
class SyncMultiGPUTrainerReplicated(Trainer):
__doc__ = SyncMultiGPUReplicatedBuilder.__doc__
def __init__(self, config, gpu_prefetch=True):
"""
Args:
config, gpu_prefetch: same as in :class:`SyncMultiGPUTrainerParameterServer`
"""
apply_prefetch_policy(config, gpu_prefetch)
self._input_source = config.data
super(SyncMultiGPUTrainerReplicated, self).__init__(config)
def _setup(self):
callbacks = self._input_source.setup(self.model.get_inputs_desc())
self.train_op, post_init_op = SyncMultiGPUReplicatedBuilder(
self._config.tower).build(
lambda: self.model._build_graph_get_grads(
*self._input_source.get_input_tensors()),
self.model.get_optimizer)
cb = RunOp(
lambda: post_init_op,
run_before=True, run_as_trigger=True, verbose=True)
self._config.callbacks.extend(callbacks + [cb])
class AsyncMultiGPUTrainer(Trainer):
__doc__ = AsyncMultiGPUBuilder.__doc__
def __init__(self, config, scale_gradient=True):
"""
Args:
config(TrainConfig): Must contain 'model' and either one of 'data' or 'dataflow'.
scale_gradient (bool): if True, will scale each gradient by ``1.0/nr_gpu``.
"""
apply_prefetch_policy(config)
self._input_source = config.data
self._scale_gradient = scale_gradient
super(AsyncMultiGPUTrainer, self).__init__(config)
def _setup(self):
callbacks = self._input_source.setup(self.model.get_inputs_desc())
self.train_op = AsyncMultiGPUBuilder(
self._config.tower, self._scale_gradient).build(
lambda: self.model._build_graph_get_grads(
*self._input_source.get_input_tensors()),
self.model.get_optimizer)
self._config.callbacks.extend(callbacks)
# -*- coding: utf-8 -*-
# File: simple.py
from .base import Trainer
from ..tfutils.tower import TowerContext
from ..utils import logger
from ..input_source import FeedInput, QueueInput
__all__ = ['SimpleTrainer', 'QueueInputTrainer']
class SimpleTrainer(Trainer):
""" A naive single-tower single-cost demo trainer.
It simply builds one tower and minimizes `model.cost`.
It supports both InputSource and DataFlow.
When DataFlow is given instead of InputSource, the InputSource to be
used will be ``FeedInput(df)`` (no prefetch).
"""
def __init__(self, config):
"""
Args:
config (TrainConfig): Must contain 'model' and either one of 'data' or 'dataflow'.
"""
assert len(config.tower) == 1, \
"Got nr_tower={}, but doesn't support multigpu!" \
" Use Sync/AsyncMultiGPUTrainer instead.".format(len(config.tower))
assert (config.data is not None or config.dataflow is not None) and config.model is not None
if config.dataflow is None:
self._input_source = config.data
else:
self._input_source = FeedInput(config.dataflow)
logger.warn("FeedInput is slow (and this is the default of SimpleTrainer). "
"Consider QueueInput or other InputSource instead.")
super(SimpleTrainer, self).__init__(config)
def _setup(self):
cbs = self._input_source.setup(self.model.get_inputs_desc())
with TowerContext('', is_training=True):
grads = self.model._build_graph_get_grads(
*self._input_source.get_input_tensors())
opt = self.model.get_optimizer()
self.train_op = opt.apply_gradients(grads, name='min_op')
self._config.callbacks.extend(cbs)
def QueueInputTrainer(config, input_queue=None):
"""
A wrapper trainer which automatically wraps ``config.dataflow`` by a :class:`QueueInput`.
It is equivalent to ``SimpleTrainer(config)`` with ``config.data = QueueInput(dataflow)``.
Args:
config (TrainConfig): Must contain 'model' and 'dataflow'.
input_queue (tf.QueueBase): an input queue. Defaults to the :class:`QueueInput` default.
"""
assert (config.data is not None or config.dataflow is not None) and config.model is not None
if config.data is not None:
assert isinstance(config.data, QueueInput), config.data
else:
config.data = QueueInput(config.dataflow, input_queue)
config.dataflow = None
return SimpleTrainer(config)
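# Usage sketch (illustration only; `MyModel` and `my_dataflow` are hypothetical):
#   config = TrainConfig(model=MyModel(), dataflow=my_dataflow, ...)
#   QueueInputTrainer(config).train()
# which wraps `my_dataflow` in a QueueInput before handing it to SimpleTrainer.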
# -*- coding: utf-8 -*-
# File: utility.py
# for backwards-compatibility
from ..graph_builder.utils import ( # noqa
override_to_local_variable, LeastLoadedDeviceSetter)