Commit a673974c authored by Yuxin Wu

initial commit of new trainer interface (#318)

parent 82187086
@@ -83,17 +83,10 @@ class TrainConfig(object):
         if callbacks is None:
             callbacks = []
         assert_type(callbacks, list)
-        if extra_callbacks is None:
-            extra_callbacks = [
-                MovingAverageSummary(),
-                ProgressBar(),
-                MergeAllSummaries(),
-                RunUpdateOps()]
-        self._callbacks = callbacks + extra_callbacks
-        if monitors is None:
-            monitors = [TFEventWriter(), JSONWriter(), ScalarPrinter()]
-        self.monitors = monitors
+        self._callbacks = callbacks + \
+            (extra_callbacks or TrainConfig.DEFAULT_EXTRA_CALLBACKS())
+        self.monitors = monitors or TrainConfig.DEFAULT_MONITORS()
         if session_init is None:
             session_init = JustCurrentSession()
@@ -155,3 +148,15 @@ class TrainConfig(object):
     @property
     def callbacks(self):    # disable setter
         return self._callbacks
+
+    @staticmethod
+    def DEFAULT_EXTRA_CALLBACKS():
+        return [
+            MovingAverageSummary(),
+            ProgressBar(),
+            MergeAllSummaries(),
+            RunUpdateOps()]
+
+    @staticmethod
+    def DEFAULT_MONITORS():
+        return [TFEventWriter(), JSONWriter(), ScalarPrinter()]
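A minimal sketch of what the refactored defaults mean for users, assuming `model` and `dataflow` are defined elsewhere (placeholders here) and using ModelSaver only as an example of a user-supplied callback:

    from tensorpack.train import TrainConfig
    from tensorpack.callbacks import ModelSaver

    config = TrainConfig(model=model, dataflow=dataflow, callbacks=[ModelSaver()])
    # config.callbacks contains ModelSaver plus the four callbacks from
    # TrainConfig.DEFAULT_EXTRA_CALLBACKS()
    # config.monitors falls back to TrainConfig.DEFAULT_MONITORS()
    # (TFEventWriter, JSONWriter, ScalarPrinter)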
# -*- coding: UTF-8 -*-
# File: __init__.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
from pkgutil import iter_modules
import os
import os.path
__all__ = []
def global_import(name):
p = __import__(name, globals(), locals(), level=1)
lst = p.__all__ if '__all__' in dir(p) else []
del globals()[name]
for k in lst:
globals()[k] = p.__dict__[k]
__all__.append(k)
_CURR_DIR = os.path.dirname(__file__)
_SKIP = []
for _, module_name, _ in iter_modules(
[_CURR_DIR]):
srcpath = os.path.join(_CURR_DIR, module_name + '.py')
if not os.path.isfile(srcpath):
continue
if module_name.startswith('_'):
continue
if module_name not in _SKIP:
global_import(module_name)
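In effect, the loop above makes `from <this package> import *` pick up every name that the submodules list in their `__all__`; for the files in this commit it behaves roughly like the hand-written equivalent below (a simplification, not the actual mechanism):

    # roughly equivalent, given submodules base.py, interface.py and trainers.py
    from .base import *        # noqa
    from .interface import *   # noqa
    from .trainers import *    # noqa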
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: base.py
import tensorflow as tf
import weakref
import time
from six.moves import range
import six
from abc import abstractmethod, ABCMeta
from ..utils import logger
from ..callbacks import Callback, Callbacks
from ..callbacks.monitor import Monitors, TrainingMonitor
from ..tfutils.model_utils import describe_trainable_vars
from ..tfutils.sessinit import JustCurrentSession
from ..tfutils.sesscreate import ReuseSessionCreator
from ..callbacks.steps import MaintainStepCounter
from ..train.base import StopTraining, TrainLoop
__all__ = ['Trainer', 'SingleCostTrainer']
class Trainer(object):
""" Base class for a trainer.
"""
is_chief = True
def __init__(self):
self._callbacks = []
self.loop = TrainLoop()
self._monitors = []  # a list of TrainingMonitor; do not change the type from list to Monitors
def _register_callback(self, cb):
"""
Register a callback to the trainer.
It can only be called before :meth:`Trainer.train` gets called.
"""
assert isinstance(cb, Callback), cb
assert not isinstance(self._callbacks, Callbacks), \
"Cannot register more callbacks after trainer was setup!"
if not self.is_chief and cb.chief_only:
logger.warn("Callback {} is chief-only, skipped.".format(str(cb)))
else:
self._callbacks.append(cb)
def _register_monitor(self, mon):
"""
Register a monitor to the trainer.
It can only be called before :meth:`Trainer.train` gets called.
"""
assert isinstance(mon, TrainingMonitor), mon
assert not isinstance(self._monitors, Monitors), \
"Cannot register more monitors after trainer was setup!"
if not self.is_chief and mon.chief_only:
logger.warn("Monitor {} is chief-only, skipped.".format(str(mon)))
else:
self._register_callback(mon)
def run_step(self):
"""
Defines what to do in one iteration. The default is:
``self.hooked_sess.run(self.train_op)``.
The behavior can be changed by either defining what is ``train_op``,
or overriding this method.
"""
if not hasattr(self, 'train_op'):
raise NotImplementedError(
"Please either set `Trainer.train_op` or provide an implementation "
"of Trainer.run_step()!")
self.hooked_sess.run(self.train_op)
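# A hedged sketch of the two customization routes described in the docstring
# above (MyTrainer and extra_summary_op are illustrative names, not part of
# the API):
#
#   class MyTrainer(Trainer):
#       def run_step(self):
#           # fetch an extra op together with the training op
#           self.hooked_sess.run([self.train_op, self.extra_summary_op])
#
# Alternatively, just assign `trainer.train_op = ...` after building the graph
# and keep the default run_step().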
def setup_callbacks(self, callbacks, monitors):
"""
Set up callbacks and monitors. Must be called after the main graph is built.
"""
describe_trainable_vars() # TODO weird
self._register_callback(MaintainStepCounter())
for cb in callbacks:
self._register_callback(cb)
for m in monitors:
self._register_monitor(m)
self.monitors = Monitors(monitors)
self._register_callback(self.monitors) # monitors is also a callback
# some final operations that might modify the graph
logger.info("Setup callbacks graph ...")
self._callbacks = Callbacks(self._callbacks)
self._callbacks.setup_graph(weakref.proxy(self))
def initialize(self, session_creator, session_init):
"""
Initialize self.sess and self.hooked_sess.
Must be called after callbacks are set up.
"""
logger.info("Creating the session ...")
hooks = self._callbacks.get_hooks()
self.sess = session_creator.create_session()
self.hooked_sess = tf.train.MonitoredSession(
session_creator=ReuseSessionCreator(self.sess), hooks=hooks)
if self.is_chief:
logger.info("Initializing the session ...")
session_init.init(self.sess)
else:
assert isinstance(session_init, JustCurrentSession), \
"session_init is only valid for chief worker session!"
self.sess.graph.finalize()
logger.info("Graph Finalized.")
def _create_session(self):
"""
Setup self.sess (the raw tf.Session)
and self.hooked_sess (the session with hooks and coordinator)
"""
def main_loop(self, steps_per_epoch, starting_epoch=1, max_epoch=99999):
"""
Run the main training loop.
"""
with self.sess.as_default():
self.loop.config(steps_per_epoch, starting_epoch, max_epoch)
self.loop.update_global_step()
try:
self._callbacks.before_train()
# refresh global step (might have been changed by callbacks) TODO ugly
# what if gs is changed later?
self.loop.update_global_step()
for self.loop._epoch_num in range(
self.loop.starting_epoch, self.loop.max_epoch + 1):
logger.info("Start Epoch {} ...".format(self.loop.epoch_num))
start_time = time.time()
self._callbacks.before_epoch()
for self.loop._local_step in range(self.loop.steps_per_epoch):
if self.hooked_sess.should_stop():
return
self.run_step() # implemented by subclass
self._callbacks.trigger_step()
self._callbacks.after_epoch()
logger.info("Epoch {} (global_step {}) finished, time:{:.2f} sec.".format(
self.loop.epoch_num, self.loop.global_step, time.time() - start_time))
# trigger epoch outside the timing region.
self._callbacks.trigger_epoch()
logger.info("Training has finished!")
except (StopTraining, tf.errors.OutOfRangeError):
logger.info("Training was stopped.")
except KeyboardInterrupt:
logger.info("Detected Ctrl-C and exiting main loop.")
except:
raise
finally:
self._callbacks.after_train()
self.hooked_sess.close()
def train(self,
callbacks, monitors,
session_creator, session_init,
steps_per_epoch, starting_epoch, max_epoch):
"""
Implemented by:
.. code-block:: python
self.setup_callbacks(callbacks, monitors)
self.initialize(session_creator, session_init)
self.main_loop(steps_per_epoch, starting_epoch, max_epoch)
You can call those methods by yourself to have finer control over the details if needed.
"""
self.setup_callbacks(callbacks, monitors)
self.initialize(session_creator, session_init)
self.main_loop(steps_per_epoch, starting_epoch, max_epoch)
def _get_property(name):
"""
Delegate property to self.loop
"""
ret = property(
lambda self: getattr(self.loop, name))
if six.PY3: # __doc__ is readonly in Py2
try:
ret.__doc__ = getattr(TrainLoop, name).__doc__
except AttributeError:
pass
return ret
for name in ['global_step', 'local_step', 'steps_per_epoch',
'epoch_num', 'starting_epoch', 'max_epoch']:
setattr(Trainer, name, _get_property(name))
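# A small illustration of the delegation above (using a hypothetical trainer
# instance `t`): reading `t.epoch_num`, `t.global_step`, etc. simply returns
# the corresponding attribute of `t.loop`, e.g.
#
#   t.epoch_num          # same as t.loop.epoch_num
#   t.steps_per_epoch    # same as t.loop.steps_per_epoch
#
# which keeps the familiar attribute names available directly on the trainer.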
@six.add_metaclass(ABCMeta)
class SingleCostTrainer(Trainer):
"""
Base class for single-cost trainer.
Single-cost trainer has a :meth:`setup_graph` method which takes
(inputs_desc, input, get_cost_fn, get_opt_fn), and builds the training operations from them.
To use a SingleCostTrainer object, call `trainer.setup_graph(...); trainer.train(...)`.
"""
def train(self,
callbacks, monitors,
session_creator, session_init,
steps_per_epoch, starting_epoch, max_epoch):
callbacks = callbacks + self._internal_callbacks
Trainer.train(
self,
callbacks, monitors,
session_creator, session_init,
steps_per_epoch, starting_epoch, max_epoch)
def setup_graph(self, inputs_desc, input, get_cost_fn, get_opt_fn):
"""
Build the main training graph: set up the given input source, then build
the training operations from it (the latter is delegated to the subclass's
:meth:`_setup_graph`). You can also build the graph outside the trainer.
Returns:
[Callback]: a (possibly empty) list of callbacks needed for training.
These callbacks will be automatically added when you call `train()`.
So you can usually ignore the return value.
"""
assert not input.setup_done()
input_callbacks = input.setup(inputs_desc)
train_callbacks = self._setup_graph(input, get_cost_fn, get_opt_fn)
self._internal_callbacks = input_callbacks + train_callbacks
return self._internal_callbacks
@abstractmethod
def _setup_graph(self, input, get_cost_fn, get_opt_fn):
pass
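A minimal usage sketch of the SingleCostTrainer API above, assuming a concrete subclass such as SimpleTrainer from trainers.py, plus an InputSource `input` and a ModelDesc-style `model` defined elsewhere (placeholders); it mirrors the call sequence used by launch_train_with_config() in interface.py:

    from tensorpack.tfutils.sesscreate import NewSessionCreator
    from tensorpack.tfutils.sessinit import JustCurrentSession

    trainer = SimpleTrainer()
    trainer.setup_graph(
        model.get_inputs_desc(), input,
        model.build_graph_get_cost, model.get_optimizer)
    trainer.train(
        callbacks=[], monitors=[],
        session_creator=NewSessionCreator(), session_init=JustCurrentSession(),
        steps_per_epoch=100, starting_epoch=1, max_epoch=10)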
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: interface.py
import tensorflow as tf
from ..input_source import (
FeedInput, QueueInput, StagingInputWrapper, DummyConstantInput)
from ..train.config import TrainConfig
from .base import SingleCostTrainer
from .trainers import SimpleTrainer, DistributedTrainerReplicated
__all__ = ['launch_train_with_config', 'TrainConfig']
def _maybe_gpu_prefetch(input, towers, gpu_prefetch):
# seems to help only when more than one GPU is used
if len(towers) > 1 and gpu_prefetch:
assert tf.test.is_gpu_available()
if not isinstance(input, (StagingInputWrapper, DummyConstantInput)):
input = StagingInputWrapper(input, towers)
return input
def launch_train_with_config(config, trainer):
"""
Train with a config and a new-style trainer, mimicking the old training interface.
Args:
config (TrainConfig):
trainer (Trainer): an instance of the new trainer
Examples:
.. code-block:: python
# with the old trainer:
SyncMultiGPUTrainerParameterServer(config, ps_device='gpu').train()
# with the new trainer:
launch_train_with_config(
config, SyncMultiGPUTrainerParameterServer(towers, ps_device='gpu'))
"""
assert isinstance(trainer, SingleCostTrainer), trainer
assert isinstance(config, TrainConfig), config
assert config.model is not None
assert config.dataflow is not None or config.data is not None
model = config.model
inputs_desc = model.get_inputs_desc()
input = config.data
# some checks & input wrappers to mimic the behavior of the old trainer interface
if input is None:
if type(trainer) == SimpleTrainer:
input = FeedInput(config.dataflow)
else:
input = QueueInput(config.dataflow)
if config.nr_tower > 1:
assert not isinstance(trainer, SimpleTrainer)
input = _maybe_gpu_prefetch(input, config.tower, True)
if isinstance(trainer, DistributedTrainerReplicated) and \
config.session_config is not None:
raise ValueError(
"Cannot set session_config for distributed training! "
"To use a custom session config, pass it to tf.train.Server.")
trainer.setup_graph(
inputs_desc, input,
model.build_graph_get_cost, model.get_optimizer)
trainer.train(
config.callbacks,
config.monitors,
config.session_creator,
config.session_init,
config.steps_per_epoch,
config.starting_epoch,
config.max_epoch)
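A short end-to-end sketch of this entry point, assuming `MyModel` (a ModelDesc subclass) and `my_dataflow` are defined elsewhere, and that TrainConfig, launch_train_with_config and the trainers below are imported from this package (the exact import path at this commit may differ):

    config = TrainConfig(
        model=MyModel(),
        dataflow=my_dataflow,
        steps_per_epoch=100,
        max_epoch=10)

    launch_train_with_config(config, SimpleTrainer())
    # or, for multi-GPU training:
    # launch_train_with_config(
    #     config, SyncMultiGPUTrainerParameterServer([0, 1], ps_device='gpu'))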
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: trainers.py
import os
from ..callbacks.graph import RunOp
from ..tfutils.sesscreate import NewSessionCreator
from ..graph_builder.training import (
SimpleBuilder,
SyncMultiGPUParameterServerBuilder,
SyncMultiGPUReplicatedBuilder,
AsyncMultiGPUBuilder,
DistributedReplicatedBuilder)
from ..graph_builder.utils import override_to_local_variable
from ..utils import logger
from ..tfutils import get_global_step_var
from ..tfutils.distributed import get_distributed_session_creator
from ..input_source import QueueInput
from .base import Trainer, SingleCostTrainer
__all__ = ['SimpleTrainer',
'QueueInputTrainer',
'SyncMultiGPUTrainerReplicated',
'SyncMultiGPUTrainerParameterServer',
'AsyncMultiGPUTrainer',
'DistributedTrainerReplicated']
class SimpleTrainer(SingleCostTrainer):
"""
Single-GPU single-cost single-tower trainer.
"""
def _setup_graph(self, input, get_cost_fn, get_opt_fn):
self.train_op = SimpleBuilder().build(
input, get_cost_fn, get_opt_fn)
return []
# Exists only to assert the input type; otherwise identical to SimpleTrainer.
class QueueInputTrainer(SimpleTrainer):
def _setup_graph(self, input, get_cost_fn, get_opt_fn):
assert isinstance(input, QueueInput)
return super(QueueInputTrainer, self)._setup_graph(input, get_cost_fn, get_opt_fn)
class SyncMultiGPUTrainerParameterServer(SingleCostTrainer):
__doc__ = SyncMultiGPUParameterServerBuilder.__doc__
def __init__(self, towers, ps_device='gpu'):
"""
Args:
towers ([int]): list of GPU ids.
ps_device: either 'gpu' or 'cpu', where variables are stored. Setting to 'cpu' might help when #gpu>=4
"""
self._builder = SyncMultiGPUParameterServerBuilder(towers, ps_device)
super(SyncMultiGPUTrainerParameterServer, self).__init__()
def _setup_graph(self, input, get_cost_fn, get_opt_fn):
self.train_op = self._builder.build(input, get_cost_fn, get_opt_fn)
return []
class AsyncMultiGPUTrainer(SingleCostTrainer):
__doc__ = AsyncMultiGPUBuilder.__doc__
def __init__(self, towers, scale_gradient=True):
"""
Args:
towers ([int]): list of GPU ids.
scale_gradient (bool): if True, will scale each gradient by ``1.0/nr_gpu``.
"""
self._builder = AsyncMultiGPUBuilder(towers, scale_gradient)
super(AsyncMultiGPUTrainer, self).__init__()
def _setup_graph(self, input, get_cost_fn, get_opt_fn):
self.train_op = self._builder.build(input, get_cost_fn, get_opt_fn)
return []
class SyncMultiGPUTrainerReplicated(SingleCostTrainer):
__doc__ = SyncMultiGPUReplicatedBuilder.__doc__
def __init__(self, towers):
"""
Args:
towers ([int]): list of GPU ids.
"""
self._builder = SyncMultiGPUReplicatedBuilder(towers)
super(SyncMultiGPUTrainerReplicated, self).__init__()
def _setup_graph(self, input, get_cost_fn, get_opt_fn):
self.train_op, post_init_op = self._builder.build(
input, get_cost_fn, get_opt_fn)
cb = RunOp(
post_init_op,
run_before=True, run_as_trigger=True, verbose=True)
return [cb]
class DistributedTrainerReplicated(SingleCostTrainer):
__doc__ = DistributedReplicatedBuilder.__doc__
def __init__(self, towers, server):
"""
Args:
towers (list[int]): list of GPU ids.
server (tf.train.Server): the server with ps and workers.
The job_name must be 'worker' because 'ps' job doesn't need to
build any graph.
"""
self.server = server
self.job_name = server.server_def.job_name
assert self.job_name in ['ps', 'worker'], self.job_name
if self.job_name == 'worker':
# ps doesn't build any graph
self._builder = DistributedReplicatedBuilder(towers, server)
self.is_chief = self._builder.is_chief
else:
self.is_chief = False
logger.info("Distributed training on cluster:\n" + str(server.server_def.cluster))
def train(self,
inputs_desc, input, get_cost_fn, get_opt_fn,
callbacks, monitors,
session_creator, session_init,
steps_per_epoch, starting_epoch, max_epoch):
if self.job_name == 'ps':
logger.info("Running ps {}".format(self.server.server_def.task_index))
logger.info("Kill me with 'kill {}'".format(os.getpid()))
self.server.join()  # this will never return (see tensorflow#4713)
return
with override_to_local_variable():
get_global_step_var() # gs should be local
# input source may create variable (queue size summary)
# TODO This is not good because we don't know from here
# whether something should be global or local. We now assume
# they should be local.
input_callbacks = input.setup(inputs_desc)
train_callbacks = self.setup_graph(input, get_cost_fn, get_opt_fn)
Trainer.train(
self,
callbacks + input_callbacks + train_callbacks, monitors,
session_creator, session_init,
steps_per_epoch, starting_epoch, max_epoch)
def _setup_graph(self, input, get_cost_fn, get_opt_fn):
self.train_op, initial_sync_op, model_sync_op = self._builder.build(
input, get_cost_fn, get_opt_fn)
callbacks = []
# initial local_vars syncing
cb = RunOp(lambda: initial_sync_op,
run_before=True, run_as_trigger=False, verbose=True)
cb.chief_only = False
callbacks.append(cb)
# model_variables syncing
if model_sync_op:
cb = RunOp(lambda: model_sync_op,
run_before=False, run_as_trigger=True, verbose=True)
logger.warn("For efficiency, local MODEL_VARIABLES are only synced to PS once "
"every epoch. Be careful if you save the model more frequently than this.")
callbacks.append(cb)
return callbacks
def initialize(self, session_creator, session_init):
if not isinstance(session_creator, NewSessionCreator):
raise ValueError(
"Cannot set session_creator for distributed training! "
"To use a custom session config, pass it to tf.train.Server.")
super(DistributedTrainerReplicated, self).initialize(
get_distributed_session_creator(), session_init)
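A rough sketch of how DistributedTrainerReplicated is wired up, using the standard TF 1.x cluster API; the host addresses and the `job_name` / `task_index` values are placeholders that would normally come from the command line or environment:

    import tensorflow as tf

    cluster = tf.train.ClusterSpec({
        'ps': ['host1:2222'],
        'worker': ['host2:2222', 'host3:2222']})
    server = tf.train.Server(
        cluster, job_name=job_name, task_index=task_index)
    trainer = DistributedTrainerReplicated([0, 1], server)
    # On a 'ps' job, trainer.train(...) only calls server.join() and never
    # returns; on a 'worker' job it builds the replicated graph and runs the
    # usual training loop.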