Commit ba4e3178 authored by Yuxin Wu

[Trainerv2] Swap trainer directory. change two examples.

parent 9268bc8c
@@ -8,6 +8,14 @@ so you won't need to look at here very often.
Here is a list of things that were changed, starting from an early version.
TensorFlow itself also changed APIs before 1.0, and those are not listed here.
+ [2017/10/21]
tensorpack is gradually switching to a new Trainer API.
Compatibility is kept in most cases, but not guaranteed.
To switch to the new API, the easiest way is to:
1. `export TENSORPACK_TRAIN_API=v2` (will become the default in the future).
2. Replace `SomeTrainer(config, ...).train()` with `launch_train_with_config(config, SomeTrainer(...))`.
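
As an illustrative sketch (not part of the diff), the migration typically looks like this; `MyModel` and `my_dataflow` are hypothetical placeholders, not tensorpack symbols:

```python
import os
os.environ['TENSORPACK_TRAIN_API'] = 'v2'   # step 1: opt into the new API
from tensorpack import *

# `MyModel` and `my_dataflow` are assumed to be defined elsewhere.
config = TrainConfig(model=MyModel(), dataflow=my_dataflow, max_epoch=100)

# Old API:
#   QueueInputTrainer(config).train()
# New API (step 2):
launch_train_with_config(config, QueueInputTrainer())
```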
+ [2017/10/18]
`TrainConfig(predict_tower)` was deprecated. You can set the inference device directly when creating the `InferenceRunner` callback.
+ [2017/10/12](https://github.com/ppwwyyxx/tensorpack/commit/7e963996f615b85f7459455596b4ee9bbd0bce8e).
......
@@ -22,9 +22,11 @@ In other words, an "epoch" in tensorpack is the __default period to run callback
### Common Trainers
-Most neural network training tasks are single-cost optimization.
-Tensorpack provides some trainer implementations for such tasks.
-These trainers will build the graph based on the given `ModelDesc`, and minimizes `ModelDesc.cost`.
+<!--
+-Most neural network training tasks are single-cost optimization.
+-Tensorpack provides some trainer implementations for such tasks.
+-These trainers will build the graph based on the given `ModelDesc`, and minimizes `ModelDesc.cost`.
+-->
<!--
-To use trainers, pass a `TrainConfig` to configure them:
@@ -49,7 +51,7 @@ These trainers will build the graph based on the given `ModelDesc`, and minimize
-in the [Input Pipeline](input-source.html) tutorial.
-You can set the InputSource instead, to customize this behavior.
-->
-Trainers are being redesigned, so the recommended API will likely be changed soon.
+Trainers are being redesigned, so this page will be updated soon.
Existing multi-GPU trainers include the logic of data-parallel training.
You can enable them with just one line, and all the necessary logic to achieve the best performance is already baked into the trainers.
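
For instance, a hedged sketch of choosing a data-parallel trainer under the new API, mirroring the `cifar-convnet.py` change in this commit (`config` and `nr_gpu` are assumed to be defined as in that script):

```python
# Sketch only: assumes `config` (a TrainConfig) and `nr_gpu` already exist.
trainer = QueueInputTrainer() if nr_gpu <= 1 \
    else SyncMultiGPUTrainerParameterServer(list(range(nr_gpu)))
launch_train_with_config(config, trainer)
```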
......
@@ -2,12 +2,13 @@
# -*- coding: UTF-8 -*-
# File: cifar-convnet.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
-from tensorpack import *
import tensorflow as tf
import argparse
import numpy as np
import os
+os.environ['TENSORPACK_TRAIN_API'] = 'v2'   # will become default soon
+from tensorpack import *
import tensorpack.tfutils.symbolic_functions as symbf
from tensorpack.tfutils.summary import *
from tensorpack.dataflow import dataset
@@ -151,8 +152,7 @@ if __name__ == '__main__':
if args.load:
    config.session_init = SaverRestore(args.load)
-config.nr_tower = max(len(args.gpu.split(',')), 1)
-if config.nr_tower <= 1:
-    QueueInputTrainer(config).train()
-else:
-    SyncMultiGPUTrainerParameterServer(config).train()
+nr_gpu = len(args.gpu.split(','))
+trainer = QueueInputTrainer() if nr_gpu <= 1 \
+    else SyncMultiGPUTrainerParameterServer(list(range(nr_gpu)))
+launch_train_with_config(config, trainer)
@@ -12,6 +12,7 @@ MNIST ConvNet example.
about 0.6% validation error after 30 epochs.
"""
+os.environ['TENSORPACK_TRAIN_API'] = 'v2'  # will become default soon
# Just import everything into current namespace
from tensorpack import *
from tensorpack.tfutils import summary
@@ -142,4 +143,4 @@ if __name__ == '__main__':
config.session_init = SaverRestore(args.load)
# SimpleTrainer is slow, this is just a demo.
# You can use QueueInputTrainer instead
-SimpleTrainer(config).train()
+launch_train_with_config(config, SimpleTrainer())
[flake8]
max-line-length = 120
-ignore = F403,F401,F405,F841,E401
+ignore = F403,F401,F405,F841,E401,E402
exclude = private,
    FasterRCNN/utils
@@ -18,9 +18,9 @@ if _HAS_TF:
# In development. Default to v1
if _os.environ.get('TENSORPACK_TRAIN_API', 'v1') == 'v2':
-    from tensorpack.trainv2 import *
-else:
    from tensorpack.train import *
+else:
+    from tensorpack.trainv1 import *
from tensorpack.graph_builder import InputDesc, ModelDesc, ModelDescBase
from tensorpack.input_source import *
from tensorpack.predict import *
#!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: base.py # File: base.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import time import tensorflow as tf
import weakref import weakref
import six import time
from six.moves import range from six.moves import range
import six
from abc import abstractmethod, ABCMeta
import tensorflow as tf
from .config import TrainConfig
from ..utils import logger from ..utils import logger
from ..utils.develop import log_deprecated from ..utils.argtools import call_only_once, memoized
from ..callbacks import Callback, Callbacks from ..callbacks import Callback, Callbacks
from ..callbacks.monitor import Monitors, TrainingMonitor from ..callbacks.monitor import Monitors, TrainingMonitor
from ..tfutils import get_global_step_value
from ..tfutils.model_utils import describe_trainable_vars from ..tfutils.model_utils import describe_trainable_vars
from ..tfutils.sesscreate import ReuseSessionCreator
from ..tfutils.sessinit import JustCurrentSession from ..tfutils.sessinit import JustCurrentSession
from ..tfutils.tower import TowerFuncWrapper from ..tfutils.sesscreate import ReuseSessionCreator
from ..tfutils.tower import TowerFuncWrapper, get_current_tower_context
from ..tfutils.gradproc import FilterNoneGrad
from ..callbacks.steps import MaintainStepCounter
from ..input_source import PlaceholderInput
from ..graph_builder.predictor_factory import SimplePredictBuilder from ..graph_builder.predictor_factory import SimplePredictBuilder
from ..input_source import PlaceholderInput
from ..predict.base import OnlinePredictor from ..predict.base import OnlinePredictor
from ..callbacks.steps import MaintainStepCounter
__all__ = ['Trainer', 'StopTraining']
class StopTraining(BaseException):
"""
An exception thrown to stop training.
"""
pass
class TrainLoop(object): import tensorpack.trainv1 as old_train # noqa
""" from ..trainv1.base import StopTraining, TrainLoop
Manage the double for loop. from ..trainv1.config import TrainConfig
"""
def __init__(self):
self._epoch_num = 0
self._global_step = 0
self._local_step = -1
def config(self, steps_per_epoch, starting_epoch, max_epoch):
"""
Configure the loop given the settings.
"""
self.starting_epoch = starting_epoch
self.max_epoch = max_epoch
self.steps_per_epoch = steps_per_epoch
self._epoch_num = starting_epoch - 1 __all__ = ['TrainConfig', 'Trainer', 'SingleCostTrainer', 'TowerTrainer']
def update_global_step(self):
"""
Update the Python-side global_step from TF.
This must be called under initialized default session.
"""
self._global_step = get_global_step_value()
@property
def epoch_num(self):
"""
The number of the currently ongoing epoch.
An epoch is defined to cover the moment before calling `before_epoch` until after calling `trigger_epoch`.
i.e., in the `trigger_epoch` of epoch 3, `self.epoch_num` is 3.
If you need to use `self.epoch_num` in your callback, you'll need to know this.
"""
return self._epoch_num
@property
def global_step(self):
"""
The tensorflow global_step, i.e. how many times ``hooked_sess.run`` has been called.
Note:
1. global_step is incremented **after** each ``hooked_sess.run`` returns from TF runtime.
2. If you make zero or more than one calls to ``hooked_sess.run`` in one
:meth:`run_step`, local_step and global_step may increment at different speed.
"""
return self._global_step
@property
def local_step(self):
"""
The number of steps that have finished in the current epoch.
"""
return self._local_step
class Trainer(object): class Trainer(object):
""" Base class for a trainer. """ Base class for a trainer.
Attributes:
config (TrainConfig): the config used in this trainer.
model (ModelDesc): alias for ``config.model``.
sess (tf.Session): the current session in use.
hooked_sess (tf.train.MonitoredSession): the session with hooks.
monitors (Monitors): the monitors. Other callbacks can use it for logging.
""" """
_API_VERSION = 1 _API_VERSION = 2
is_chief = True is_chief = True
"""
Whether this process is the chief worker in distributed training.
Only chief worker will run some callbacks.
"""
def __init__(self, config): def __init__(self, config=None):
""" """
Args: config is only for compatibility reasons in case you're
config (TrainConfig): the train config. using custom trainers with old-style API.
You should never use config.
""" """
assert isinstance(config, TrainConfig), type(config)
self._config = config
self.model = config.model
if self.model is not None:
def f(*inputs):
self.model.build_graph(inputs)
"""
Only to mimic the new trainer interface for inference.
"""
self.inputs_desc = self.model.get_inputs_desc()
self.tower_func = TowerFuncWrapper(f, self.inputs_desc)
self._callbacks = [] self._callbacks = []
self._monitors = []
self.loop = TrainLoop() self.loop = TrainLoop()
self.loop.config(config.steps_per_epoch, config.starting_epoch, config.max_epoch) self._monitors = [] # Clarify the type. Don't change from list to monitors.
self._setup() # subclass will setup the graph and InputSource # Hacks!
if config is not None:
def register_callback(self, cb): logger.warn("You're initializing new trainer with old trainer API!")
logger.warn("This could happen if you wrote a custom trainer before.")
logger.warn("It may work now through some hacks, but please switch to the new API!")
self._config = config
self.inputs_desc = config.model.get_inputs_desc()
self.tower_func = TowerFuncWrapper(
lambda *inputs: config.model.build_graph(inputs),
self.inputs_desc)
self._main_tower_vs_name = ""
def gp(input_names, output_names, tower=0):
return TowerTrainer.get_predictor(self, input_names, output_names, device=tower)
self.get_predictor = gp
old_train = self.train
def train():
return old_train(
config.callbacks, config.monitors,
config.session_creator, config.session_init,
config.steps_per_epoch, config.starting_epoch, config.max_epoch)
self.train = train
def _register_callback(self, cb):
""" """
Register a callback to the trainer. Register a callback to the trainer.
It can only be called before :meth:`Trainer.train` gets called. It can only be called before :meth:`Trainer.train` gets called.
@@ -151,7 +87,7 @@ class Trainer(object):
else: else:
self._callbacks.append(cb) self._callbacks.append(cb)
def register_monitor(self, mon): def _register_monitor(self, mon):
""" """
Register a monitor to the trainer. Register a monitor to the trainer.
It can only be called before :meth:`Trainer.train` gets called. It can only be called before :meth:`Trainer.train` gets called.
@@ -162,18 +98,7 @@ class Trainer(object):
if not self.is_chief and mon.chief_only: if not self.is_chief and mon.chief_only:
logger.warn("Monitor {} is chief-only, skipped.".format(str(mon))) logger.warn("Monitor {} is chief-only, skipped.".format(str(mon)))
else: else:
self._monitors.append(mon) self._register_callback(mon)
self.register_callback(mon)
@property
def monitors(self):
assert isinstance(self._monitors, Monitors), "Monitors haven't been setup!"
return self._monitors
def train(self):
""" Start training """
self.setup()
self.main_loop()
def run_step(self): def run_step(self):
""" """
@@ -189,32 +114,44 @@ class Trainer(object):
"of Trainer.run_step()!") "of Trainer.run_step()!")
self.hooked_sess.run(self.train_op) self.hooked_sess.run(self.train_op)
def setup(self): @call_only_once
def setup_callbacks(self, callbacks, monitors):
""" """
Setup the trainer and be ready for the main loop. Setup callbacks and monitors. Must be called after the main graph is built.
""" """
self.register_callback(MaintainStepCounter()) describe_trainable_vars() # TODO weird
for cb in self._config.callbacks:
self.register_callback(cb)
for m in self._config.monitors:
self.register_monitor(m)
self._monitors = Monitors(self._monitors)
self.register_callback(self._monitors)
describe_trainable_vars() self._register_callback(MaintainStepCounter())
for cb in callbacks:
self._register_callback(cb)
for m in monitors:
self._register_monitor(m)
self.monitors = Monitors(monitors)
self._register_callback(self.monitors) # monitors is also a callback
# some final operations that might modify the graph # some final operations that might modify the graph
logger.info("Setup callbacks graph ...") logger.info("Setup callbacks graph ...")
self._callbacks = Callbacks(self._callbacks) self._callbacks = Callbacks(self._callbacks)
self._callbacks.setup_graph(weakref.proxy(self)) self._callbacks.setup_graph(weakref.proxy(self))
self._config.session_init._setup_graph()
@call_only_once
def initialize(self, session_creator, session_init):
"""
Initialize self.sess and self.hooked_sess.
Must be called after callbacks are setup.
"""
session_init._setup_graph()
logger.info("Creating the session ...") logger.info("Creating the session ...")
self._create_session()
hooks = self._callbacks.get_hooks()
self.sess = session_creator.create_session()
self.hooked_sess = tf.train.MonitoredSession(
session_creator=ReuseSessionCreator(self.sess), hooks=hooks)
if self.is_chief: if self.is_chief:
logger.info("Initializing the session ...") logger.info("Initializing the session ...")
self._config.session_init._run_init(self.sess) session_init._run_init(self.sess)
else: else:
if not isinstance(self._config.session_init, JustCurrentSession): if not isinstance(self._config.session_init, JustCurrentSession):
logger.warn("This is not a chief worker, 'session_init' was ignored!") logger.warn("This is not a chief worker, 'session_init' was ignored!")
@@ -222,35 +159,18 @@ class Trainer(object):
self.sess.graph.finalize() self.sess.graph.finalize()
logger.info("Graph Finalized.") logger.info("Graph Finalized.")
def _create_session(self): @call_only_once
""" def main_loop(self, steps_per_epoch, starting_epoch=1, max_epoch=99999):
Setup self.sess (the raw tf.Session)
and self.hooked_sess (the session with hooks and coordinator)
"""
hooks = self._callbacks.get_hooks()
self.sess = self._config.session_creator.create_session()
self.hooked_sess = tf.train.MonitoredSession(
session_creator=ReuseSessionCreator(self.sess), hooks=hooks)
def _setup(self):
"""
Build the entire graph for training.
Responsible for setup InputSource as well (including registering InputSource callbacks)
Since this method will get called in constructor only,
you can simply leave it empty and build your graph outside the trainer.
"""
pass
def main_loop(self):
""" """
Run the main training loop. Run the main training loop.
""" """
with self.sess.as_default(): with self.sess.as_default():
self.loop.config(steps_per_epoch, starting_epoch, max_epoch)
self.loop.update_global_step() self.loop.update_global_step()
try: try:
self._callbacks.before_train() self._callbacks.before_train()
# refresh global step (might have changed by callbacks) TODO ugly # refresh global step (might have changed by callbacks) TODO ugly
# what if gs is changed later?
self.loop.update_global_step() self.loop.update_global_step()
for self.loop._epoch_num in range( for self.loop._epoch_num in range(
self.loop.starting_epoch, self.loop.max_epoch + 1): self.loop.starting_epoch, self.loop.max_epoch + 1):
@@ -279,18 +199,106 @@ class Trainer(object):
self._callbacks.after_train() self._callbacks.after_train()
self.hooked_sess.close() self.hooked_sess.close()
def get_predictor(self, input_names, output_names, tower=0): def train(self,
callbacks, monitors,
session_creator, session_init,
steps_per_epoch, starting_epoch, max_epoch):
"""
Implemented by:
.. code-block:: python
self.setup_callbacks(callbacks, monitors)
self.initialize(session_creator, session_init)
self.main_loop(steps_per_epoch, starting_epoch, max_epoch)
You can call those methods by yourself to have better control on details if needed.
"""
self.setup_callbacks(callbacks, monitors)
self.initialize(session_creator, session_init)
self.main_loop(steps_per_epoch, starting_epoch, max_epoch)
# create the old trainer when called with TrainConfig
def __new__(cls, *args, **kwargs):
if (len(args) > 0 and isinstance(args[0], TrainConfig)) \
or 'config' in kwargs:
name = cls.__name__
try:
old_trainer = getattr(old_train, name)
except AttributeError:
# custom trainer. has to live with it
return super(Trainer, cls).__new__(cls)
else:
logger.warn("You're calling new trainers with old trainer API!")
logger.warn("Now it returns the old trainer for you, please switch to use new trainers correctly!")
logger.warn("'SomeTrainer(config, ...).train()' should be equivalent to "
"'launch_train_with_config(config, SomeTrainer(...))' in the new API.")
return old_trainer(*args, **kwargs)
else:
return super(Trainer, cls).__new__(cls)
def _get_property(name):
"""
Delegate property to self.loop
"""
ret = property(
lambda self: getattr(self.loop, name))
if six.PY3: # __doc__ is readonly in Py2
try:
ret.__doc__ = getattr(TrainLoop, name).__doc__
except AttributeError:
pass
return ret
for name in ['global_step', 'local_step', 'steps_per_epoch',
'epoch_num', 'starting_epoch', 'max_epoch']:
setattr(Trainer, name, _get_property(name))
class TowerTrainer(Trainer):
"""
Base trainers for models that can be built by calling a tower function under a :class:`TowerContext`.
This is required by some features that replicate the model
automatically, e.g. creating a predictor.
"""
tower_func = None
"""
A :class:`TowerFuncWrapper` instance.
A callable which takes some input tensors and builds one replicate of the model.
"""
@call_only_once
def set_tower_func(self, tower_func):
"""
Args:
tower_func (TowerFuncWrapper)
"""
assert isinstance(tower_func, TowerFuncWrapper), tower_func
self.tower_func = tower_func
@property
def inputs_desc(self):
"""
Returns:
list[InputDesc]: metainfo about the inputs to the tower.
"""
return self.tower_func.inputs_desc
def get_predictor(self, input_names, output_names, device=0):
""" """
Returns a callable predictor built under ``TowerContext(is_training=False)``. Returns a callable predictor built under ``TowerContext(is_training=False)``.
Args: Args:
input_names (list), output_names(list): list of names input_names (list), output_names(list): list of names
tower (int): build the predictor on device '/gpu:{tower}' or use -1 for '/cpu:0'. device (int): build the predictor on device '/gpu:{device}' or use -1 for '/cpu:0'.
Returns: Returns:
an :class:`OnlinePredictor`. an :class:`OnlinePredictor`.
""" """
device = tower
assert self.tower_func is not None, "Must set tower_func on the trainer to use get_predictor()!" assert self.tower_func is not None, "Must set tower_func on the trainer to use get_predictor()!"
tower_name = 'tower-pred-{}'.format(device) if device >= 0 else 'tower-pred-cpu' tower_name = 'tower-pred-{}'.format(device) if device >= 0 else 'tower-pred-cpu'
@@ -311,33 +319,91 @@ class Trainer(object):
@property @property
def _main_tower_vs_name(self): def _main_tower_vs_name(self):
# The vs name a predictor should be built under. """
# for internal use only. Should let graphbuilder return it. The vs name for the "main" copy of the model,
to be used to build predictors.
"""
return "" return ""
@property
def config(self):
log_deprecated(
"Trainer.config",
"It is supposed to be private! Most of its attributes can be accessed by other means.",
"2017-12-31")
return self._config
def _get_property(name): @six.add_metaclass(ABCMeta)
class SingleCostTrainer(TowerTrainer):
""" """
Delegate property to self.loop Base class for single-cost trainer.
Single-cost trainer has a :meth:`setup_graph` method which takes
(inputs_desc, input, get_cost_fn, get_opt_fn), and build the training operations from them.
To use a SingleCostTrainer object, call `trainer.setup_graph(...); trainer.train(...)`.
""" """
ret = property(
lambda self: getattr(self.loop, name))
if six.PY3: # __doc__ is readonly in Py2
try:
ret.__doc__ = getattr(TrainLoop, name).__doc__
except AttributeError:
pass
return ret
def train(self,
callbacks, monitors,
session_creator, session_init,
steps_per_epoch, starting_epoch, max_epoch):
"""
Same as :meth:`Trainer.train()`, except that the callbacks this
trainer needs are automatically added.
"""
callbacks = callbacks + self._internal_callbacks
super(SingleCostTrainer, self).train(
callbacks, monitors,
session_creator, session_init,
steps_per_epoch, starting_epoch, max_epoch)
@call_only_once
def setup_graph(self, inputs_desc, input, get_cost_fn, get_opt_fn):
"""
Responsible for building the main training graph for single-cost training.
for name in ['global_step', 'local_step', 'steps_per_epoch', Args:
'epoch_num', 'starting_epoch', 'max_epoch']: inputs_desc ([InputDesc]):
setattr(Trainer, name, _get_property(name)) input (InputSource):
get_cost_fn ([tf.Tensor] -> tf.Tensor): callable, takes some input tensors and returns a cost tensor.
Might get called multiple times for data-parallel training or inference.
get_opt_fn (-> tf.train.Optimizer): callable which returns an
optimizer. Will only be called once.
Returns:
[Callback]: a (possibly empty) list of callbacks needed for training.
These callbacks will be automatically added when you call `train()`.
So you can usually ignore the return value.
"""
get_cost_fn = TowerFuncWrapper(get_cost_fn, inputs_desc)
get_opt_fn = memoized(get_opt_fn)
self.set_tower_func(get_cost_fn)
input_callbacks = self._setup_input(inputs_desc, input)
train_callbacks = self._setup_graph(input, get_cost_fn, get_opt_fn)
self._internal_callbacks = input_callbacks + train_callbacks
return self._internal_callbacks
@abstractmethod
def _setup_graph(self, input, get_cost_fn, get_opt_fn):
pass
def _setup_input(self, inputs_desc, input):
assert not input.setup_done()
return input.setup(inputs_desc)
def _make_get_grad_fn(self, input, get_cost_fn, get_opt_fn):
"""
Returns:
a get_grad_fn for GraphBuilder to use.
"""
# internal use only
assert input.setup_done()
def get_grad_fn():
ctx = get_current_tower_context()
cost = get_cost_fn(*input.get_input_tensors())
varlist = ctx.filter_vars_by_vs_name(tf.trainable_variables())
opt = get_opt_fn()
grads = opt.compute_gradients(
cost, var_list=varlist,
gate_gradients=False, colocate_gradients_with_ops=True)
grads = FilterNoneGrad().process(grads)
return grads
return get_grad_fn
@@ -7,11 +7,11 @@ import tensorflow as tf
from ..input_source import (
    InputSource, FeedInput, QueueInput, StagingInputWrapper, DummyConstantInput)
-from ..train.config import TrainConfig
+from ..trainv1.config import TrainConfig
from .base import SingleCostTrainer
from .trainers import SimpleTrainer, DistributedTrainerReplicated
-__all__ = ['launch_train_with_config', 'TrainConfig', 'apply_default_prefetch']
+__all__ = ['launch_train_with_config', 'apply_default_prefetch']
def apply_default_prefetch(input_source_or_dataflow, trainer, towers):
......
@@ -24,6 +24,7 @@ from .base import SingleCostTrainer
__all__ = ['SimpleTrainer',
           'QueueInputTrainer',
+          'SyncMultiGPUTrainer',
           'SyncMultiGPUTrainerReplicated',
           'SyncMultiGPUTrainerParameterServer',
           'AsyncMultiGPUTrainer',
@@ -68,6 +69,17 @@ class SyncMultiGPUTrainerParameterServer(SingleCostTrainer):
        return []
+def SyncMultiGPUTrainer(towers):
+    """
+    Return a default multi-GPU trainer, if you don't care about the details.
+    It may not be the most efficient one for your task.
+    Args:
+        towers (list[int]): list of GPU ids.
+    """
+    return SyncMultiGPUTrainerParameterServer(towers, ps_device='gpu')
class AsyncMultiGPUTrainer(SingleCostTrainer):
    __doc__ = AsyncMultiGPUBuilder.__doc__
......
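
As a usage note (an editor-added sketch, not part of this commit): the `SyncMultiGPUTrainer` helper added above can be paired with `launch_train_with_config` when the parameter-server placement details don't matter; `config` is assumed to be an existing `TrainConfig`:

```python
# Hedged sketch: train on GPUs 0 and 1 with the default sync trainer,
# which delegates to SyncMultiGPUTrainerParameterServer(ps_device='gpu').
trainer = SyncMultiGPUTrainer([0, 1])
launch_train_with_config(config, trainer)
```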
@@ -19,7 +19,7 @@ def global_import(name):
_CURR_DIR = os.path.dirname(__file__)
-_SKIP = []
+_SKIP = ['utility']
for _, module_name, _ in iter_modules(
        [_CURR_DIR]):
    srcpath = os.path.join(_CURR_DIR, module_name + '.py')
......
#!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: base.py # File: base.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import tensorflow as tf
import weakref
import time import time
from six.moves import range import weakref
import six import six
from abc import abstractmethod, ABCMeta from six.moves import range
import tensorflow as tf
from .config import TrainConfig
from ..utils import logger from ..utils import logger
from ..utils.argtools import call_only_once, memoized from ..utils.develop import log_deprecated
from ..callbacks import Callback, Callbacks from ..callbacks import Callback, Callbacks
from ..callbacks.monitor import Monitors, TrainingMonitor from ..callbacks.monitor import Monitors, TrainingMonitor
from ..tfutils import get_global_step_value
from ..tfutils.model_utils import describe_trainable_vars from ..tfutils.model_utils import describe_trainable_vars
from ..tfutils.sessinit import JustCurrentSession
from ..tfutils.sesscreate import ReuseSessionCreator from ..tfutils.sesscreate import ReuseSessionCreator
from ..tfutils.tower import TowerFuncWrapper, get_current_tower_context from ..tfutils.sessinit import JustCurrentSession
from ..tfutils.gradproc import FilterNoneGrad from ..tfutils.tower import TowerFuncWrapper
from ..callbacks.steps import MaintainStepCounter
from ..input_source import PlaceholderInput
from ..graph_builder.predictor_factory import SimplePredictBuilder from ..graph_builder.predictor_factory import SimplePredictBuilder
from ..input_source import FeedfreeInput, PlaceholderInput
from ..predict.base import OnlinePredictor from ..predict.base import OnlinePredictor
from ..callbacks.steps import MaintainStepCounter
__all__ = ['Trainer', 'StopTraining']
class StopTraining(BaseException):
"""
An exception thrown to stop training.
"""
pass
class TrainLoop(object):
"""
Manage the double for loop.
"""
import tensorpack.train as old_train # noqa def __init__(self):
from ..train.base import StopTraining, TrainLoop self._epoch_num = 0
self._global_step = 0
self._local_step = -1
__all__ = ['Trainer', 'SingleCostTrainer', 'TowerTrainer'] def config(self, steps_per_epoch, starting_epoch, max_epoch):
"""
Configure the loop given the settings.
"""
self.starting_epoch = starting_epoch
self.max_epoch = max_epoch
self.steps_per_epoch = steps_per_epoch
self._epoch_num = starting_epoch - 1
def update_global_step(self):
"""
Update the Python-side global_step from TF.
This must be called under initialized default session.
"""
self._global_step = get_global_step_value()
@property
def epoch_num(self):
"""
The number of the currently ongoing epoch.
An epoch is defined to cover the moment before calling `before_epoch` until after calling `trigger_epoch`.
i.e., in the `trigger_epoch` of epoch 3, `self.epoch_num` is 3.
If you need to use `self.epoch_num` in your callback, you'll need to know this.
"""
return self._epoch_num
@property
def global_step(self):
"""
The tensorflow global_step, i.e. how many times ``hooked_sess.run`` has been called.
Note:
1. global_step is incremented **after** each ``hooked_sess.run`` returns from TF runtime.
2. If you make zero or more than one calls to ``hooked_sess.run`` in one
:meth:`run_step`, local_step and global_step may increment at different speed.
"""
return self._global_step
@property
def local_step(self):
"""
The number of steps that have finished in the current epoch.
"""
return self._local_step
class Trainer(object): class Trainer(object):
""" Base class for a trainer. """ Base class for a trainer.
Attributes:
config (TrainConfig): the config used in this trainer.
model (ModelDesc): alias for ``config.model``.
sess (tf.Session): the current session in use.
hooked_sess (tf.train.MonitoredSession): the session with hooks.
monitors (Monitors): the monitors. Other callbacks can use it for logging.
""" """
_API_VERSION = 2 _API_VERSION = 1
is_chief = True is_chief = True
"""
Whether this process is the chief worker in distributed training.
Only chief worker will run some callbacks.
"""
def __init__(self, config=None): def __init__(self, config):
""" """
config is only for compatibility reasons in case you're Args:
using custom trainers with old-style API. config (TrainConfig): the train config.
You should never use config.
""" """
assert isinstance(config, TrainConfig), type(config)
self._config = config
self.model = config.model
if self.model is not None:
def f(*inputs):
self.model.build_graph(inputs)
"""
Only to mimic the new trainer interface for inference.
"""
self.inputs_desc = self.model.get_inputs_desc()
self.tower_func = TowerFuncWrapper(f, self.inputs_desc)
self._callbacks = [] self._callbacks = []
self._monitors = []
self.loop = TrainLoop() self.loop = TrainLoop()
self._monitors = [] # Clarify the type. Don't change from list to monitors. self.loop.config(config.steps_per_epoch, config.starting_epoch, config.max_epoch)
# Hacks! self._setup() # subclass will setup the graph and InputSource
if config is not None:
logger.warn("You're initializing new trainer with old trainer API!") def register_callback(self, cb):
logger.warn("This could happen if you wrote a custom trainer before.")
logger.warn("It may work now through some hacks, but please switch to the new API!")
self._config = config
self.inputs_desc = config.model.get_inputs_desc()
self.tower_func = TowerFuncWrapper(
lambda *inputs: config.model.build_graph(inputs),
self.inputs_desc)
self._main_tower_vs_name = ""
def gp(input_names, output_names, tower=0):
return TowerTrainer.get_predictor(self, input_names, output_names, device=tower)
self.get_predictor = gp
old_train = self.train
def train():
return old_train(
config.callbacks, config.monitors,
config.session_creator, config.session_init,
config.steps_per_epoch, config.starting_epoch, config.max_epoch)
self.train = train
def _register_callback(self, cb):
""" """
Register a callback to the trainer. Register a callback to the trainer.
It can only be called before :meth:`Trainer.train` gets called. It can only be called before :meth:`Trainer.train` gets called.
@@ -86,7 +151,7 @@ class Trainer(object):
else: else:
self._callbacks.append(cb) self._callbacks.append(cb)
def _register_monitor(self, mon): def register_monitor(self, mon):
""" """
Register a monitor to the trainer. Register a monitor to the trainer.
It can only be called before :meth:`Trainer.train` gets called. It can only be called before :meth:`Trainer.train` gets called.
@@ -97,7 +162,18 @@ class Trainer(object):
if not self.is_chief and mon.chief_only: if not self.is_chief and mon.chief_only:
logger.warn("Monitor {} is chief-only, skipped.".format(str(mon))) logger.warn("Monitor {} is chief-only, skipped.".format(str(mon)))
else: else:
self._register_callback(mon) self._monitors.append(mon)
self.register_callback(mon)
@property
def monitors(self):
assert isinstance(self._monitors, Monitors), "Monitors haven't been setup!"
return self._monitors
def train(self):
""" Start training """
self.setup()
self.main_loop()
def run_step(self): def run_step(self):
""" """
@@ -113,44 +189,32 @@ class Trainer(object):
"of Trainer.run_step()!") "of Trainer.run_step()!")
self.hooked_sess.run(self.train_op) self.hooked_sess.run(self.train_op)
@call_only_once def setup(self):
def setup_callbacks(self, callbacks, monitors):
""" """
Setup callbacks and monitors. Must be called after the main graph is built. Setup the trainer and be ready for the main loop.
""" """
describe_trainable_vars() # TODO weird self.register_callback(MaintainStepCounter())
for cb in self._config.callbacks:
self.register_callback(cb)
for m in self._config.monitors:
self.register_monitor(m)
self._monitors = Monitors(self._monitors)
self.register_callback(self._monitors)
self._register_callback(MaintainStepCounter()) describe_trainable_vars()
for cb in callbacks:
self._register_callback(cb)
for m in monitors:
self._register_monitor(m)
self.monitors = Monitors(monitors)
self._register_callback(self.monitors) # monitors is also a callback
# some final operations that might modify the graph # some final operations that might modify the graph
logger.info("Setup callbacks graph ...") logger.info("Setup callbacks graph ...")
self._callbacks = Callbacks(self._callbacks) self._callbacks = Callbacks(self._callbacks)
self._callbacks.setup_graph(weakref.proxy(self)) self._callbacks.setup_graph(weakref.proxy(self))
self._config.session_init._setup_graph()
@call_only_once
def initialize(self, session_creator, session_init):
"""
Initialize self.sess and self.hooked_sess.
Must be called after callbacks are setup.
"""
session_init._setup_graph()
logger.info("Creating the session ...") logger.info("Creating the session ...")
self._create_session()
hooks = self._callbacks.get_hooks()
self.sess = session_creator.create_session()
self.hooked_sess = tf.train.MonitoredSession(
session_creator=ReuseSessionCreator(self.sess), hooks=hooks)
if self.is_chief: if self.is_chief:
logger.info("Initializing the session ...") logger.info("Initializing the session ...")
session_init._run_init(self.sess) self._config.session_init._run_init(self.sess)
else: else:
if not isinstance(self._config.session_init, JustCurrentSession): if not isinstance(self._config.session_init, JustCurrentSession):
logger.warn("This is not a chief worker, 'session_init' was ignored!") logger.warn("This is not a chief worker, 'session_init' was ignored!")
@@ -158,18 +222,35 @@ class Trainer(object):
self.sess.graph.finalize() self.sess.graph.finalize()
logger.info("Graph Finalized.") logger.info("Graph Finalized.")
@call_only_once def _create_session(self):
def main_loop(self, steps_per_epoch, starting_epoch=1, max_epoch=99999): """
Setup self.sess (the raw tf.Session)
and self.hooked_sess (the session with hooks and coordinator)
"""
hooks = self._callbacks.get_hooks()
self.sess = self._config.session_creator.create_session()
self.hooked_sess = tf.train.MonitoredSession(
session_creator=ReuseSessionCreator(self.sess), hooks=hooks)
def _setup(self):
"""
Build the entire graph for training.
Responsible for setup InputSource as well (including registering InputSource callbacks)
Since this method will get called in constructor only,
you can simply leave it empty and build your graph outside the trainer.
"""
pass
def main_loop(self):
""" """
Run the main training loop. Run the main training loop.
""" """
with self.sess.as_default(): with self.sess.as_default():
self.loop.config(steps_per_epoch, starting_epoch, max_epoch)
self.loop.update_global_step() self.loop.update_global_step()
try: try:
self._callbacks.before_train() self._callbacks.before_train()
# refresh global step (might have changed by callbacks) TODO ugly # refresh global step (might have changed by callbacks) TODO ugly
# what if gs is changed later?
self.loop.update_global_step() self.loop.update_global_step()
for self.loop._epoch_num in range( for self.loop._epoch_num in range(
self.loop.starting_epoch, self.loop.max_epoch + 1): self.loop.starting_epoch, self.loop.max_epoch + 1):
@@ -198,106 +279,18 @@ class Trainer(object):
self._callbacks.after_train() self._callbacks.after_train()
self.hooked_sess.close() self.hooked_sess.close()
def train(self, def get_predictor(self, input_names, output_names, tower=0):
callbacks, monitors,
session_creator, session_init,
steps_per_epoch, starting_epoch, max_epoch):
"""
Implemented by:
.. code-block:: python
self.setup_callbacks(callbacks, monitors)
self.initialize(session_creator, session_init)
self.main_loop(steps_per_epoch, starting_epoch, max_epoch)
You can call those methods by yourself to have better control on details if needed.
"""
self.setup_callbacks(callbacks, monitors)
self.initialize(session_creator, session_init)
self.main_loop(steps_per_epoch, starting_epoch, max_epoch)
# create the old trainer when called with TrainConfig
def __new__(cls, *args, **kwargs):
if (len(args) > 0 and isinstance(args[0], old_train.TrainConfig)) \
or 'config' in kwargs:
name = cls.__name__
try:
old_trainer = getattr(old_train, name)
except AttributeError:
# custom trainer. has to live with it
return super(Trainer, cls).__new__(cls)
else:
logger.warn("You're creating trainers with old trainer API!")
logger.warn("Now it returns the old trainer for you, please switch to the new API!")
logger.warn("'SomeTrainer(config, ...).train()' should be equivalent to "
"'launch_train_with_config(config, SomeTrainer(...))' in the new API.")
return old_trainer(*args, **kwargs)
else:
return super(Trainer, cls).__new__(cls)
def _get_property(name):
"""
Delegate property to self.loop
"""
ret = property(
lambda self: getattr(self.loop, name))
if six.PY3: # __doc__ is readonly in Py2
try:
ret.__doc__ = getattr(TrainLoop, name).__doc__
except AttributeError:
pass
return ret
for name in ['global_step', 'local_step', 'steps_per_epoch',
'epoch_num', 'starting_epoch', 'max_epoch']:
setattr(Trainer, name, _get_property(name))
class TowerTrainer(Trainer):
"""
Base trainers for models that can be built by calling a tower function under a :class:`TowerContext`.
This is required by some features that replicate the model
automatically, e.g. creating a predictor.
"""
tower_func = None
"""
A :class:`TowerFuncWrapper` instance.
A callable which takes some input tensors and builds one replicate of the model.
"""
@call_only_once
def set_tower_func(self, tower_func):
"""
Args:
tower_func (TowerFuncWrapper)
"""
assert isinstance(tower_func, TowerFuncWrapper), tower_func
self.tower_func = tower_func
@property
def inputs_desc(self):
"""
Returns:
list[InputDesc]: metainfo about the inputs to the tower.
"""
return self.tower_func.inputs_desc
def get_predictor(self, input_names, output_names, device=0):
""" """
Returns a callable predictor built under ``TowerContext(is_training=False)``. Returns a callable predictor built under ``TowerContext(is_training=False)``.
Args: Args:
input_names (list), output_names(list): list of names input_names (list), output_names(list): list of names
device (int): build the predictor on device '/gpu:{device}' or use -1 for '/cpu:0'. tower (int): build the predictor on device '/gpu:{tower}' or use -1 for '/cpu:0'.
Returns: Returns:
an :class:`OnlinePredictor`. an :class:`OnlinePredictor`.
""" """
device = tower
assert self.tower_func is not None, "Must set tower_func on the trainer to use get_predictor()!" assert self.tower_func is not None, "Must set tower_func on the trainer to use get_predictor()!"
tower_name = 'tower-pred-{}'.format(device) if device >= 0 else 'tower-pred-cpu' tower_name = 'tower-pred-{}'.format(device) if device >= 0 else 'tower-pred-cpu'
@@ -318,92 +311,48 @@ class TowerTrainer(Trainer):
@property @property
def _main_tower_vs_name(self): def _main_tower_vs_name(self):
""" # The vs name a predictor should be built under.
The vs name for the "main" copy of the model, # for internal use only. Should let graphbuilder return it.
to be used to build predictors.
"""
return "" return ""
@property
def config(self):
log_deprecated(
"Trainer.config",
"It is supposed to be private! Most of its attributes can be accessed by other means.",
"2017-12-31")
return self._config
# create new trainer when not called with TrainConfig
def __new__(cls, *args, **kwargs):
if (len(args) > 0 and isinstance(args[0], TrainConfig)) \
or 'config' in kwargs:
return super(Trainer, cls).__new__(cls)
else:
import tensorpack.train as new_train
name = cls.__name__
new_trainer = getattr(new_train, name)
logger.warn("You're calling old trainers with new trainer API!")
logger.warn("Now it returns the new trainer for you, please `export TENSORPACK_TRAIN_API=v2`"
" to import new trainers automatically.")
logger.warn("You can also ignore this warning and wait for new API to become the default.")
return new_trainer(*args, **kwargs)
@six.add_metaclass(ABCMeta)
class SingleCostTrainer(TowerTrainer):
"""
Base class for single-cost trainer.
Single-cost trainer has a :meth:`setup_graph` method which takes
(inputs_desc, input, get_cost_fn, get_opt_fn), and build the training operations from them.
To use a SingleCostTrainer object, call `trainer.setup_graph(...); trainer.train(...)`. def _get_property(name):
""" """
Delegate property to self.loop
"""
ret = property(
lambda self: getattr(self.loop, name))
if six.PY3: # __doc__ is readonly in Py2
try:
ret.__doc__ = getattr(TrainLoop, name).__doc__
except AttributeError:
pass
return ret
def train(self,
callbacks, monitors,
session_creator, session_init,
steps_per_epoch, starting_epoch, max_epoch):
"""
Same as :meth:`Trainer.train()`, except that the callbacks this
trainer needs are automatically added.
"""
callbacks = callbacks + self._internal_callbacks
super(SingleCostTrainer, self).train(
callbacks, monitors,
session_creator, session_init,
steps_per_epoch, starting_epoch, max_epoch)
@call_only_once
def setup_graph(self, inputs_desc, input, get_cost_fn, get_opt_fn):
"""
Responsible for building the main training graph for single-cost training.
Args:
inputs_desc ([InputDesc]):
input (InputSource):
get_cost_fn ([tf.Tensor] -> tf.Tensor): callable, takes some input tensors and returns a cost tensor.
Might get called multiple times for data-parallel training or inference.
get_opt_fn (-> tf.train.Optimizer): callable which returns an
optimizer. Will only be called once.
Returns:
[Callback]: a (possibly empty) list of callbacks needed for training.
These callbacks will be automatically added when you call `train()`.
So you can usually ignore the return value.
"""
get_cost_fn = TowerFuncWrapper(get_cost_fn, inputs_desc)
get_opt_fn = memoized(get_opt_fn)
self.set_tower_func(get_cost_fn)
input_callbacks = self._setup_input(inputs_desc, input)
train_callbacks = self._setup_graph(input, get_cost_fn, get_opt_fn)
self._internal_callbacks = input_callbacks + train_callbacks
return self._internal_callbacks
@abstractmethod
def _setup_graph(self, input, get_cost_fn, get_opt_fn):
pass
def _setup_input(self, inputs_desc, input):
assert not input.setup_done()
return input.setup(inputs_desc)
def _make_get_grad_fn(self, input, get_cost_fn, get_opt_fn): for name in ['global_step', 'local_step', 'steps_per_epoch',
""" 'epoch_num', 'starting_epoch', 'max_epoch']:
Returns: setattr(Trainer, name, _get_property(name))
a get_grad_fn for GraphBuilder to use.
"""
# internal use only
assert input.setup_done()
assert isinstance(input, FeedfreeInput), input
def get_grad_fn():
ctx = get_current_tower_context()
cost = get_cost_fn(*input.get_input_tensors())
varlist = ctx.filter_vars_by_vs_name(tf.trainable_variables())
opt = get_opt_fn()
grads = opt.compute_gradients(
cost, var_list=varlist,
gate_gradients=False, colocate_gradients_with_ops=True)
grads = FilterNoneGrad().process(grads)
return grads
return get_grad_fn
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: interface.py
__all__ = ['launch_train_with_config']
def launch_train_with_config(config, trainer):
from ..train.interface import launch_train_with_config as old_launch
old_launch(config, trainer)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: utility.py
# for backwards-compatibility
from ..graph_builder.utils import ( # noqa
OverrideToLocalVariable,
override_to_local_variable, LeastLoadedDeviceSetter)