Commit 73c66c18 authored by Yuxin Wu

Merge branch 'master' into model-redesign

parents 9268bc8c ba4e3178
......@@ -8,6 +8,14 @@ so you won't need to look at here very often.
Here is a list of things that were changed, starting from an early version.
TensorFlow itself also changed APIs before 1.0 and those are not listed here.
+ [2017/10/21]
tensorpack is gradually switching to a new Trainer API.
Compatibility is kept in most ways but not guaranteed.
To switch to the new API, the easiest way is to:
1. `export TENSORPACK_TRAIN_API=v2` (will be the default in the future).
2. Replace `SomeTrainer(config, ...).train()` with `launch_train_with_config(config, SomeTrainer(...))`.
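For example, a minimal sketch of the switch (the trainer class and `nr_gpu` are illustrative; see the examples updated in this commit):
```python
import os
os.environ['TENSORPACK_TRAIN_API'] = 'v2'   # step 1 (or export it in the shell)
from tensorpack import *

# step 2: replace the old call
#   SyncMultiGPUTrainerParameterServer(config).train()
# with:
launch_train_with_config(config, SyncMultiGPUTrainerParameterServer(list(range(nr_gpu))))
```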
+ [2017/10/18]
`TrainConfig(predict_tower)` was deprecated. You can set the inference device directly when creating the `InferenceRunner` callback.
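A hedged sketch of the replacement (the exact `device` keyword on `InferenceRunner`, as well as `dataset_val` and `ScalarStats`, are assumptions made for illustration based on this note):
```python
# before: TrainConfig(..., predict_tower=[0])
# now: pick the inference device when creating the callback
callbacks = [InferenceRunner(dataset_val, [ScalarStats('cost')], device=0)]
```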
+ [2017/10/12](https://github.com/ppwwyyxx/tensorpack/commit/7e963996f615b85f7459455596b4ee9bbd0bce8e).
......
......@@ -22,9 +22,11 @@ In other words, an "epoch" in tensorpack is the __default period to run callback
### Common Trainers
Most neural network training tasks are single-cost optimization.
Tensorpack provides some trainer implementations for such tasks.
These trainers will build the graph based on the given `ModelDesc`, and minimize `ModelDesc.cost`.
<!--
-Most neural network training tasks are single-cost optimization.
-Tensorpack provides some trainer implementations for such tasks.
-These trainers will build the graph based on the given `ModelDesc`, and minimizes `ModelDesc.cost`.
-->
<!--
-To use trainers, pass a `TrainConfig` to configure them:
......@@ -49,7 +51,7 @@ These trainers will build the graph based on the given `ModelDesc`, and minimize
-in the [Input Pipeline](input-source.html) tutorial.
-You can set the InputSource instead, to customize this behavior.
-->
Trainers are being redesigned, so the recommended API will likely be changed soon.
Trainers are being redesigned; this page will be updated soon.
Existing multi-GPU trainers include the logic of data-parallel training.
You can enable them with just one line, and all the logic needed to achieve the best performance is already baked into the trainers.
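For instance, with the new API (a sketch based on the examples updated in this commit; `config` and `nr_gpu` are illustrative):
```python
trainer = QueueInputTrainer() if nr_gpu <= 1 \
    else SyncMultiGPUTrainerParameterServer(list(range(nr_gpu)))
launch_train_with_config(config, trainer)
```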
......
......@@ -2,12 +2,13 @@
# -*- coding: UTF-8 -*-
# File: cifar-convnet.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from tensorpack import *
import tensorflow as tf
import argparse
import numpy as np
import os
os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon
from tensorpack import *
import tensorpack.tfutils.symbolic_functions as symbf
from tensorpack.tfutils.summary import *
from tensorpack.dataflow import dataset
......@@ -151,8 +152,7 @@ if __name__ == '__main__':
if args.load:
config.session_init = SaverRestore(args.load)
config.nr_tower = max(len(args.gpu.split(',')), 1)
if config.nr_tower <= 1:
QueueInputTrainer(config).train()
else:
SyncMultiGPUTrainerParameterServer(config).train()
nr_gpu = len(args.gpu.split(','))
trainer = QueueInputTrainer() if nr_gpu <= 1 \
else SyncMultiGPUTrainerParameterServer(list(range(nr_gpu)))
launch_train_with_config(config, trainer)
......@@ -12,6 +12,7 @@ MNIST ConvNet example.
about 0.6% validation error after 30 epochs.
"""
os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon
# Just import everything into current namespace
from tensorpack import *
from tensorpack.tfutils import summary
......@@ -142,4 +143,4 @@ if __name__ == '__main__':
config.session_init = SaverRestore(args.load)
# SimpleTrainer is slow, this is just a demo.
# You can use QueueInputTrainer instead
SimpleTrainer(config).train()
launch_train_with_config(config, SimpleTrainer())
[flake8]
max-line-length = 120
ignore = F403,F401,F405,F841,E401
ignore = F403,F401,F405,F841,E401,E402
exclude = private,
FasterRCNN/utils
......@@ -18,9 +18,9 @@ if _HAS_TF:
# In development. Default to v1
if _os.environ.get('TENSORPACK_TRAIN_API', 'v1') == 'v2':
from tensorpack.trainv2 import *
else:
from tensorpack.train import *
else:
from tensorpack.trainv1 import *
from tensorpack.graph_builder import InputDesc, ModelDesc, ModelDescBase
from tensorpack.input_source import *
from tensorpack.predict import *
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: base.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import time
import tensorflow as tf
import weakref
import six
import time
from six.moves import range
import six
from abc import abstractmethod, ABCMeta
import tensorflow as tf
from .config import TrainConfig
from ..utils import logger
from ..utils.develop import log_deprecated
from ..utils.argtools import call_only_once, memoized
from ..callbacks import Callback, Callbacks
from ..callbacks.monitor import Monitors, TrainingMonitor
from ..tfutils import get_global_step_value
from ..tfutils.model_utils import describe_trainable_vars
from ..tfutils.sesscreate import ReuseSessionCreator
from ..tfutils.sessinit import JustCurrentSession
from ..tfutils.tower import TowerFuncWrapper
from ..tfutils.sesscreate import ReuseSessionCreator
from ..tfutils.tower import TowerFuncWrapper, get_current_tower_context
from ..tfutils.gradproc import FilterNoneGrad
from ..callbacks.steps import MaintainStepCounter
from ..input_source import PlaceholderInput
from ..graph_builder.predictor_factory import SimplePredictBuilder
from ..input_source import PlaceholderInput
from ..predict.base import OnlinePredictor
from ..callbacks.steps import MaintainStepCounter
__all__ = ['Trainer', 'StopTraining']
class StopTraining(BaseException):
"""
An exception thrown to stop training.
"""
pass
class TrainLoop(object):
"""
Manage the double for loop.
"""
def __init__(self):
self._epoch_num = 0
self._global_step = 0
self._local_step = -1
def config(self, steps_per_epoch, starting_epoch, max_epoch):
"""
Configure the loop given the settings.
"""
self.starting_epoch = starting_epoch
self.max_epoch = max_epoch
self.steps_per_epoch = steps_per_epoch
import tensorpack.trainv1 as old_train # noqa
from ..trainv1.base import StopTraining, TrainLoop
from ..trainv1.config import TrainConfig
self._epoch_num = starting_epoch - 1
def update_global_step(self):
"""
Update the Python-side global_step from TF.
This must be called under initialized default session.
"""
self._global_step = get_global_step_value()
@property
def epoch_num(self):
"""
The number of the currently ongoing epoch.
An epoch is defined to cover the moment before calling `before_epoch` until after calling `trigger_epoch`.
i.e., in the `trigger_epoch` of epoch 3, `self.epoch_num` is 3.
If you need to use `self.epoch_num` in your callback, you'll need to know this.
"""
return self._epoch_num
@property
def global_step(self):
"""
The tensorflow global_step, i.e. how many times ``hooked_sess.run`` has been called.
Note:
1. global_step is incremented **after** each ``hooked_sess.run`` returns from TF runtime.
2. If you make zero or more than one call to ``hooked_sess.run`` in one
:meth:`run_step`, local_step and global_step may increment at different speeds.
"""
return self._global_step
@property
def local_step(self):
"""
The number of steps that have finished in the current epoch.
"""
return self._local_step
__all__ = ['TrainConfig', 'Trainer', 'SingleCostTrainer', 'TowerTrainer']
class Trainer(object):
""" Base class for a trainer.
Attributes:
config (TrainConfig): the config used in this trainer.
model (ModelDesc): alias for ``config.model``.
sess (tf.Session): the current session in use.
hooked_sess (tf.train.MonitoredSession): the session with hooks.
monitors (Monitors): the monitors. Other callbacks can use it for logging.
"""
_API_VERSION = 1
_API_VERSION = 2
is_chief = True
"""
Whether this process is the chief worker in distributed training.
Only chief worker will run some callbacks.
"""
def __init__(self, config):
def __init__(self, config=None):
"""
Args:
config (TrainConfig): the train config.
config is only for compatibility reasons in case you're
using custom trainers with old-style API.
You should never use config.
"""
assert isinstance(config, TrainConfig), type(config)
self._config = config
self.model = config.model
if self.model is not None:
def f(*inputs):
self.model.build_graph(inputs)
"""
Only to mimic the new trainer interface for inference.
"""
self.inputs_desc = self.model.get_inputs_desc()
self.tower_func = TowerFuncWrapper(f, self.inputs_desc)
self._callbacks = []
self._monitors = []
self.loop = TrainLoop()
self.loop.config(config.steps_per_epoch, config.starting_epoch, config.max_epoch)
self._setup() # subclass will setup the graph and InputSource
def register_callback(self, cb):
self._monitors = [] # Clarify the type. Don't change from list to monitors.
# Hacks!
if config is not None:
logger.warn("You're initializing new trainer with old trainer API!")
logger.warn("This could happen if you wrote a custom trainer before.")
logger.warn("It may work now through some hacks, but please switch to the new API!")
self._config = config
self.inputs_desc = config.model.get_inputs_desc()
self.tower_func = TowerFuncWrapper(
lambda *inputs: config.model.build_graph(inputs),
self.inputs_desc)
self._main_tower_vs_name = ""
def gp(input_names, output_names, tower=0):
return TowerTrainer.get_predictor(self, input_names, output_names, device=tower)
self.get_predictor = gp
old_train = self.train
def train():
return old_train(
config.callbacks, config.monitors,
config.session_creator, config.session_init,
config.steps_per_epoch, config.starting_epoch, config.max_epoch)
self.train = train
def _register_callback(self, cb):
"""
Register a callback to the trainer.
It can only be called before :meth:`Trainer.train` gets called.
......@@ -151,7 +87,7 @@ class Trainer(object):
else:
self._callbacks.append(cb)
def register_monitor(self, mon):
def _register_monitor(self, mon):
"""
Register a monitor to the trainer.
It can only be called before :meth:`Trainer.train` gets called.
......@@ -162,18 +98,7 @@ class Trainer(object):
if not self.is_chief and mon.chief_only:
logger.warn("Monitor {} is chief-only, skipped.".format(str(mon)))
else:
self._monitors.append(mon)
self.register_callback(mon)
@property
def monitors(self):
assert isinstance(self._monitors, Monitors), "Monitors haven't been setup!"
return self._monitors
def train(self):
""" Start training """
self.setup()
self.main_loop()
self._register_callback(mon)
def run_step(self):
"""
......@@ -189,32 +114,44 @@ class Trainer(object):
"of Trainer.run_step()!")
self.hooked_sess.run(self.train_op)
def setup(self):
@call_only_once
def setup_callbacks(self, callbacks, monitors):
"""
Setup the trainer and be ready for the main loop.
Setup callbacks and monitors. Must be called after the main graph is built.
"""
self.register_callback(MaintainStepCounter())
for cb in self._config.callbacks:
self.register_callback(cb)
for m in self._config.monitors:
self.register_monitor(m)
self._monitors = Monitors(self._monitors)
self.register_callback(self._monitors)
describe_trainable_vars() # TODO weird
describe_trainable_vars()
self._register_callback(MaintainStepCounter())
for cb in callbacks:
self._register_callback(cb)
for m in monitors:
self._register_monitor(m)
self.monitors = Monitors(monitors)
self._register_callback(self.monitors) # monitors is also a callback
# some final operations that might modify the graph
logger.info("Setup callbacks graph ...")
self._callbacks = Callbacks(self._callbacks)
self._callbacks.setup_graph(weakref.proxy(self))
self._config.session_init._setup_graph()
@call_only_once
def initialize(self, session_creator, session_init):
"""
Initialize self.sess and self.hooked_sess.
Must be called after callbacks are setup.
"""
session_init._setup_graph()
logger.info("Creating the session ...")
self._create_session()
hooks = self._callbacks.get_hooks()
self.sess = session_creator.create_session()
self.hooked_sess = tf.train.MonitoredSession(
session_creator=ReuseSessionCreator(self.sess), hooks=hooks)
if self.is_chief:
logger.info("Initializing the session ...")
self._config.session_init._run_init(self.sess)
session_init._run_init(self.sess)
else:
if not isinstance(self._config.session_init, JustCurrentSession):
logger.warn("This is not a chief worker, 'session_init' was ignored!")
......@@ -222,35 +159,18 @@ class Trainer(object):
self.sess.graph.finalize()
logger.info("Graph Finalized.")
def _create_session(self):
"""
Setup self.sess (the raw tf.Session)
and self.hooked_sess (the session with hooks and coordinator)
"""
hooks = self._callbacks.get_hooks()
self.sess = self._config.session_creator.create_session()
self.hooked_sess = tf.train.MonitoredSession(
session_creator=ReuseSessionCreator(self.sess), hooks=hooks)
def _setup(self):
"""
Build the entire graph for training.
Responsible for setting up the InputSource as well (including registering InputSource callbacks).
Since this method will only get called in the constructor,
you can simply leave it empty and build your graph outside the trainer.
"""
pass
def main_loop(self):
@call_only_once
def main_loop(self, steps_per_epoch, starting_epoch=1, max_epoch=99999):
"""
Run the main training loop.
"""
with self.sess.as_default():
self.loop.config(steps_per_epoch, starting_epoch, max_epoch)
self.loop.update_global_step()
try:
self._callbacks.before_train()
# refresh global step (might have changed by callbacks) TODO ugly
# what if gs is changed later?
self.loop.update_global_step()
for self.loop._epoch_num in range(
self.loop.starting_epoch, self.loop.max_epoch + 1):
......@@ -279,18 +199,106 @@ class Trainer(object):
self._callbacks.after_train()
self.hooked_sess.close()
def get_predictor(self, input_names, output_names, tower=0):
def train(self,
callbacks, monitors,
session_creator, session_init,
steps_per_epoch, starting_epoch, max_epoch):
"""
Implemented by:
.. code-block:: python
self.setup_callbacks(callbacks, monitors)
self.initialize(session_creator, session_init)
self.main_loop(steps_per_epoch, starting_epoch, max_epoch)
You can call those methods by yourself to have better control on details if needed.
"""
self.setup_callbacks(callbacks, monitors)
self.initialize(session_creator, session_init)
self.main_loop(steps_per_epoch, starting_epoch, max_epoch)
# create the old trainer when called with TrainConfig
def __new__(cls, *args, **kwargs):
if (len(args) > 0 and isinstance(args[0], TrainConfig)) \
or 'config' in kwargs:
name = cls.__name__
try:
old_trainer = getattr(old_train, name)
except AttributeError:
# custom trainer. has to live with it
return super(Trainer, cls).__new__(cls)
else:
logger.warn("You're calling new trainers with old trainer API!")
logger.warn("Now it returns the old trainer for you, please switch to use new trainers correctly!")
logger.warn("'SomeTrainer(config, ...).train()' should be equivalent to "
"'launch_train_with_config(config, SomeTrainer(...))' in the new API.")
return old_trainer(*args, **kwargs)
else:
return super(Trainer, cls).__new__(cls)
def _get_property(name):
"""
Delegate property to self.loop
"""
ret = property(
lambda self: getattr(self.loop, name))
if six.PY3: # __doc__ is readonly in Py2
try:
ret.__doc__ = getattr(TrainLoop, name).__doc__
except AttributeError:
pass
return ret
for name in ['global_step', 'local_step', 'steps_per_epoch',
'epoch_num', 'starting_epoch', 'max_epoch']:
setattr(Trainer, name, _get_property(name))
class TowerTrainer(Trainer):
"""
Base trainer for models that can be built by calling a tower function under a :class:`TowerContext`.
This is required by some features that replicate the model
automatically, e.g. creating a predictor.
"""
tower_func = None
"""
A :class:`TowerFuncWrapper` instance.
A callable which takes some input tensors and builds one replica of the model.
"""
@call_only_once
def set_tower_func(self, tower_func):
"""
Args:
tower_func (TowerFuncWrapper)
"""
assert isinstance(tower_func, TowerFuncWrapper), tower_func
self.tower_func = tower_func
@property
def inputs_desc(self):
"""
Returns:
list[InputDesc]: metainfo about the inputs to the tower.
"""
return self.tower_func.inputs_desc
def get_predictor(self, input_names, output_names, device=0):
"""
Returns a callable predictor built under ``TowerContext(is_training=False)``.
Args:
input_names (list), output_names(list): list of names
tower (int): build the predictor on device '/gpu:{tower}' or use -1 for '/cpu:0'.
device (int): build the predictor on device '/gpu:{device}' or use -1 for '/cpu:0'.
Returns:
an :class:`OnlinePredictor`.
"""
device = tower
assert self.tower_func is not None, "Must set tower_func on the trainer to use get_predictor()!"
tower_name = 'tower-pred-{}'.format(device) if device >= 0 else 'tower-pred-cpu'
......@@ -311,33 +319,91 @@ class Trainer(object):
@property
def _main_tower_vs_name(self):
# The vs name a predictor should be built under.
# for internal use only. Should let graphbuilder return it.
"""
The vs name for the "main" copy of the model,
to be used to build predictors.
"""
return ""
@property
def config(self):
log_deprecated(
"Trainer.config",
"It is supposed to be private! Most of its attributes can be accessed by other means.",
"2017-12-31")
return self._config
def _get_property(name):
@six.add_metaclass(ABCMeta)
class SingleCostTrainer(TowerTrainer):
"""
Delegate property to self.loop
Base class for single-cost trainer.
Single-cost trainer has a :meth:`setup_graph` method which takes
(inputs_desc, input, get_cost_fn, get_opt_fn), and builds the training operations from them.
To use a SingleCostTrainer object, call `trainer.setup_graph(...); trainer.train(...)`.
"""
ret = property(
lambda self: getattr(self.loop, name))
if six.PY3: # __doc__ is readonly in Py2
try:
ret.__doc__ = getattr(TrainLoop, name).__doc__
except AttributeError:
pass
return ret
def train(self,
callbacks, monitors,
session_creator, session_init,
steps_per_epoch, starting_epoch, max_epoch):
"""
Same as :meth:`Trainer.train()`, except that the callbacks this
trainer needs are automatically added.
"""
callbacks = callbacks + self._internal_callbacks
super(SingleCostTrainer, self).train(
callbacks, monitors,
session_creator, session_init,
steps_per_epoch, starting_epoch, max_epoch)
@call_only_once
def setup_graph(self, inputs_desc, input, get_cost_fn, get_opt_fn):
"""
Responsible for building the main training graph for single-cost training.
for name in ['global_step', 'local_step', 'steps_per_epoch',
'epoch_num', 'starting_epoch', 'max_epoch']:
setattr(Trainer, name, _get_property(name))
Args:
inputs_desc ([InputDesc]):
input (InputSource):
get_cost_fn ([tf.Tensor] -> tf.Tensor): callable, takes some input tensors and returns a cost tensor.
Might get called multiple times for data-parallel training or inference.
get_opt_fn (-> tf.train.Optimizer): callable which returns an
optimizer. Will only be called once.
Returns:
[Callback]: a (possibly empty) list of callbacks needed for training.
These callbacks will be automatically added when you call `train()`.
So you can usually ignore the return value.
"""
get_cost_fn = TowerFuncWrapper(get_cost_fn, inputs_desc)
get_opt_fn = memoized(get_opt_fn)
self.set_tower_func(get_cost_fn)
input_callbacks = self._setup_input(inputs_desc, input)
train_callbacks = self._setup_graph(input, get_cost_fn, get_opt_fn)
self._internal_callbacks = input_callbacks + train_callbacks
return self._internal_callbacks
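# A hypothetical usage sketch of the single-cost API (not part of this file; all names below are illustrative):
#   trainer = SyncMultiGPUTrainerParameterServer([0, 1])
#   trainer.setup_graph(inputs_desc, QueueInput(dataflow), get_cost_fn, get_opt_fn)
#   trainer.train(callbacks, monitors, session_creator, session_init,
#                 steps_per_epoch, starting_epoch, max_epoch)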
@abstractmethod
def _setup_graph(self, input, get_cost_fn, get_opt_fn):
pass
def _setup_input(self, inputs_desc, input):
assert not input.setup_done()
return input.setup(inputs_desc)
def _make_get_grad_fn(self, input, get_cost_fn, get_opt_fn):
"""
Returns:
a get_grad_fn for GraphBuilder to use.
"""
# internal use only
assert input.setup_done()
def get_grad_fn():
ctx = get_current_tower_context()
cost = get_cost_fn(*input.get_input_tensors())
varlist = ctx.filter_vars_by_vs_name(tf.trainable_variables())
opt = get_opt_fn()
grads = opt.compute_gradients(
cost, var_list=varlist,
gate_gradients=False, colocate_gradients_with_ops=True)
grads = FilterNoneGrad().process(grads)
return grads
return get_grad_fn
......@@ -7,11 +7,11 @@ import tensorflow as tf
from ..input_source import (
InputSource, FeedInput, QueueInput, StagingInputWrapper, DummyConstantInput)
from ..train.config import TrainConfig
from ..trainv1.config import TrainConfig
from .base import SingleCostTrainer
from .trainers import SimpleTrainer, DistributedTrainerReplicated
__all__ = ['launch_train_with_config', 'TrainConfig', 'apply_default_prefetch']
__all__ = ['launch_train_with_config', 'apply_default_prefetch']
def apply_default_prefetch(input_source_or_dataflow, trainer, towers):
......
......@@ -24,6 +24,7 @@ from .base import SingleCostTrainer
__all__ = ['SimpleTrainer',
'QueueInputTrainer',
'SyncMultiGPUTrainer',
'SyncMultiGPUTrainerReplicated',
'SyncMultiGPUTrainerParameterServer',
'AsyncMultiGPUTrainer',
......@@ -68,6 +69,17 @@ class SyncMultiGPUTrainerParameterServer(SingleCostTrainer):
return []
def SyncMultiGPUTrainer(towers):
"""
Return a default multi-GPU trainer if you don't care about the details.
It may not be the most efficient one for your task.
Args:
towers (list[int]): list of GPU ids.
"""
return SyncMultiGPUTrainerParameterServer(towers, ps_device='gpu')
class AsyncMultiGPUTrainer(SingleCostTrainer):
__doc__ = AsyncMultiGPUBuilder.__doc__
......
......@@ -19,7 +19,7 @@ def global_import(name):
_CURR_DIR = os.path.dirname(__file__)
_SKIP = []
_SKIP = ['utility']
for _, module_name, _ in iter_modules(
[_CURR_DIR]):
srcpath = os.path.join(_CURR_DIR, module_name + '.py')
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: base.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import tensorflow as tf
import weakref
import time
from six.moves import range
import weakref
import six
from abc import abstractmethod, ABCMeta
from six.moves import range
import tensorflow as tf
from .config import TrainConfig
from ..utils import logger
from ..utils.argtools import call_only_once, memoized
from ..utils.develop import log_deprecated
from ..callbacks import Callback, Callbacks
from ..callbacks.monitor import Monitors, TrainingMonitor
from ..tfutils import get_global_step_value
from ..tfutils.model_utils import describe_trainable_vars
from ..tfutils.sessinit import JustCurrentSession
from ..tfutils.sesscreate import ReuseSessionCreator
from ..tfutils.tower import TowerFuncWrapper, get_current_tower_context
from ..tfutils.gradproc import FilterNoneGrad
from ..callbacks.steps import MaintainStepCounter
from ..tfutils.sessinit import JustCurrentSession
from ..tfutils.tower import TowerFuncWrapper
from ..input_source import PlaceholderInput
from ..graph_builder.predictor_factory import SimplePredictBuilder
from ..input_source import FeedfreeInput, PlaceholderInput
from ..predict.base import OnlinePredictor
from ..callbacks.steps import MaintainStepCounter
__all__ = ['Trainer', 'StopTraining']
class StopTraining(BaseException):
"""
An exception thrown to stop training.
"""
pass
class TrainLoop(object):
"""
Manage the double for loop.
"""
import tensorpack.train as old_train # noqa
from ..train.base import StopTraining, TrainLoop
def __init__(self):
self._epoch_num = 0
self._global_step = 0
self._local_step = -1
__all__ = ['Trainer', 'SingleCostTrainer', 'TowerTrainer']
def config(self, steps_per_epoch, starting_epoch, max_epoch):
"""
Configure the loop given the settings.
"""
self.starting_epoch = starting_epoch
self.max_epoch = max_epoch
self.steps_per_epoch = steps_per_epoch
self._epoch_num = starting_epoch - 1
def update_global_step(self):
"""
Update the Python-side global_step from TF.
This must be called under initialized default session.
"""
self._global_step = get_global_step_value()
@property
def epoch_num(self):
"""
The number of the currently ongoing epoch.
An epoch is defined to cover the moment before calling `before_epoch` until after calling `trigger_epoch`.
i.e., in the `trigger_epoch` of epoch 3, `self.epoch_num` is 3.
If you need to use `self.epoch_num` in your callback, you'll need to know this.
"""
return self._epoch_num
@property
def global_step(self):
"""
The tensorflow global_step, i.e. how many times ``hooked_sess.run`` has been called.
Note:
1. global_step is incremented **after** each ``hooked_sess.run`` returns from TF runtime.
2. If you make zero or more than one call to ``hooked_sess.run`` in one
:meth:`run_step`, local_step and global_step may increment at different speeds.
"""
return self._global_step
@property
def local_step(self):
"""
The number of steps that have finished in the current epoch.
"""
return self._local_step
class Trainer(object):
""" Base class for a trainer.
Attributes:
config (TrainConfig): the config used in this trainer.
model (ModelDesc): alias for ``config.model``.
sess (tf.Session): the current session in use.
hooked_sess (tf.train.MonitoredSession): the session with hooks.
monitors (Monitors): the monitors. Other callbacks can use it for logging.
"""
_API_VERSION = 2
_API_VERSION = 1
is_chief = True
"""
Whether this process is the chief worker in distributed training.
Only chief worker will run some callbacks.
"""
def __init__(self, config=None):
def __init__(self, config):
"""
config is only for compatibility reasons in case you're
using custom trainers with old-style API.
You should never use config.
Args:
config (TrainConfig): the train config.
"""
assert isinstance(config, TrainConfig), type(config)
self._config = config
self.model = config.model
if self.model is not None:
def f(*inputs):
self.model.build_graph(inputs)
"""
Only to mimic the new trainer interface for inference.
"""
self.inputs_desc = self.model.get_inputs_desc()
self.tower_func = TowerFuncWrapper(f, self.inputs_desc)
self._callbacks = []
self._monitors = []
self.loop = TrainLoop()
self._monitors = [] # Clarify the type. Don't change from list to monitors.
# Hacks!
if config is not None:
logger.warn("You're initializing new trainer with old trainer API!")
logger.warn("This could happen if you wrote a custom trainer before.")
logger.warn("It may work now through some hacks, but please switch to the new API!")
self._config = config
self.inputs_desc = config.model.get_inputs_desc()
self.tower_func = TowerFuncWrapper(
lambda *inputs: config.model.build_graph(inputs),
self.inputs_desc)
self._main_tower_vs_name = ""
def gp(input_names, output_names, tower=0):
return TowerTrainer.get_predictor(self, input_names, output_names, device=tower)
self.get_predictor = gp
old_train = self.train
def train():
return old_train(
config.callbacks, config.monitors,
config.session_creator, config.session_init,
config.steps_per_epoch, config.starting_epoch, config.max_epoch)
self.train = train
def _register_callback(self, cb):
self.loop.config(config.steps_per_epoch, config.starting_epoch, config.max_epoch)
self._setup() # subclass will setup the graph and InputSource
def register_callback(self, cb):
"""
Register a callback to the trainer.
It can only be called before :meth:`Trainer.train` gets called.
......@@ -86,7 +151,7 @@ class Trainer(object):
else:
self._callbacks.append(cb)
def _register_monitor(self, mon):
def register_monitor(self, mon):
"""
Register a monitor to the trainer.
It can only be called before :meth:`Trainer.train` gets called.
......@@ -97,7 +162,18 @@ class Trainer(object):
if not self.is_chief and mon.chief_only:
logger.warn("Monitor {} is chief-only, skipped.".format(str(mon)))
else:
self._register_callback(mon)
self._monitors.append(mon)
self.register_callback(mon)
@property
def monitors(self):
assert isinstance(self._monitors, Monitors), "Monitors haven't been setup!"
return self._monitors
def train(self):
""" Start training """
self.setup()
self.main_loop()
def run_step(self):
"""
......@@ -113,44 +189,32 @@ class Trainer(object):
"of Trainer.run_step()!")
self.hooked_sess.run(self.train_op)
@call_only_once
def setup_callbacks(self, callbacks, monitors):
def setup(self):
"""
Setup callbacks and monitors. Must be called after the main graph is built.
Setup the trainer and be ready for the main loop.
"""
describe_trainable_vars() # TODO weird
self.register_callback(MaintainStepCounter())
for cb in self._config.callbacks:
self.register_callback(cb)
for m in self._config.monitors:
self.register_monitor(m)
self._monitors = Monitors(self._monitors)
self.register_callback(self._monitors)
self._register_callback(MaintainStepCounter())
for cb in callbacks:
self._register_callback(cb)
for m in monitors:
self._register_monitor(m)
self.monitors = Monitors(monitors)
self._register_callback(self.monitors) # monitors is also a callback
describe_trainable_vars()
# some final operations that might modify the graph
logger.info("Setup callbacks graph ...")
self._callbacks = Callbacks(self._callbacks)
self._callbacks.setup_graph(weakref.proxy(self))
@call_only_once
def initialize(self, session_creator, session_init):
"""
Initialize self.sess and self.hooked_sess.
Must be called after callbacks are setup.
"""
session_init._setup_graph()
self._config.session_init._setup_graph()
logger.info("Creating the session ...")
hooks = self._callbacks.get_hooks()
self.sess = session_creator.create_session()
self.hooked_sess = tf.train.MonitoredSession(
session_creator=ReuseSessionCreator(self.sess), hooks=hooks)
self._create_session()
if self.is_chief:
logger.info("Initializing the session ...")
session_init._run_init(self.sess)
self._config.session_init._run_init(self.sess)
else:
if not isinstance(self._config.session_init, JustCurrentSession):
logger.warn("This is not a chief worker, 'session_init' was ignored!")
......@@ -158,18 +222,35 @@ class Trainer(object):
self.sess.graph.finalize()
logger.info("Graph Finalized.")
@call_only_once
def main_loop(self, steps_per_epoch, starting_epoch=1, max_epoch=99999):
def _create_session(self):
"""
Setup self.sess (the raw tf.Session)
and self.hooked_sess (the session with hooks and coordinator)
"""
hooks = self._callbacks.get_hooks()
self.sess = self._config.session_creator.create_session()
self.hooked_sess = tf.train.MonitoredSession(
session_creator=ReuseSessionCreator(self.sess), hooks=hooks)
def _setup(self):
"""
Build the entire graph for training.
Responsible for setting up the InputSource as well (including registering InputSource callbacks).
Since this method will only get called in the constructor,
you can simply leave it empty and build your graph outside the trainer.
"""
pass
def main_loop(self):
"""
Run the main training loop.
"""
with self.sess.as_default():
self.loop.config(steps_per_epoch, starting_epoch, max_epoch)
self.loop.update_global_step()
try:
self._callbacks.before_train()
# refresh global step (might have changed by callbacks) TODO ugly
# what if gs is changed later?
self.loop.update_global_step()
for self.loop._epoch_num in range(
self.loop.starting_epoch, self.loop.max_epoch + 1):
......@@ -198,106 +279,18 @@ class Trainer(object):
self._callbacks.after_train()
self.hooked_sess.close()
def train(self,
callbacks, monitors,
session_creator, session_init,
steps_per_epoch, starting_epoch, max_epoch):
"""
Implemented by:
.. code-block:: python
self.setup_callbacks(callbacks, monitors)
self.initialize(session_creator, session_init)
self.main_loop(steps_per_epoch, starting_epoch, max_epoch)
You can call those methods by yourself to have better control on details if needed.
"""
self.setup_callbacks(callbacks, monitors)
self.initialize(session_creator, session_init)
self.main_loop(steps_per_epoch, starting_epoch, max_epoch)
# create the old trainer when called with TrainConfig
def __new__(cls, *args, **kwargs):
if (len(args) > 0 and isinstance(args[0], old_train.TrainConfig)) \
or 'config' in kwargs:
name = cls.__name__
try:
old_trainer = getattr(old_train, name)
except AttributeError:
# custom trainer. has to live with it
return super(Trainer, cls).__new__(cls)
else:
logger.warn("You're creating trainers with old trainer API!")
logger.warn("Now it returns the old trainer for you, please switch to the new API!")
logger.warn("'SomeTrainer(config, ...).train()' should be equivalent to "
"'launch_train_with_config(config, SomeTrainer(...))' in the new API.")
return old_trainer(*args, **kwargs)
else:
return super(Trainer, cls).__new__(cls)
def _get_property(name):
"""
Delegate property to self.loop
"""
ret = property(
lambda self: getattr(self.loop, name))
if six.PY3: # __doc__ is readonly in Py2
try:
ret.__doc__ = getattr(TrainLoop, name).__doc__
except AttributeError:
pass
return ret
for name in ['global_step', 'local_step', 'steps_per_epoch',
'epoch_num', 'starting_epoch', 'max_epoch']:
setattr(Trainer, name, _get_property(name))
class TowerTrainer(Trainer):
"""
Base trainer for models that can be built by calling a tower function under a :class:`TowerContext`.
This is required by some features that replicate the model
automatically, e.g. creating a predictor.
"""
tower_func = None
"""
A :class:`TowerFuncWrapper` instance.
A callable which takes some input tensors and builds one replica of the model.
"""
@call_only_once
def set_tower_func(self, tower_func):
"""
Args:
tower_func (TowerFuncWrapper)
"""
assert isinstance(tower_func, TowerFuncWrapper), tower_func
self.tower_func = tower_func
@property
def inputs_desc(self):
"""
Returns:
list[InputDesc]: metainfo about the inputs to the tower.
"""
return self.tower_func.inputs_desc
def get_predictor(self, input_names, output_names, device=0):
def get_predictor(self, input_names, output_names, tower=0):
"""
Returns a callable predictor built under ``TowerContext(is_training=False)``.
Args:
input_names (list), output_names(list): list of names
device (int): build the predictor on device '/gpu:{device}' or use -1 for '/cpu:0'.
tower (int): build the predictor on device '/gpu:{tower}' or use -1 for '/cpu:0'.
Returns:
an :class:`OnlinePredictor`.
"""
device = tower
assert self.tower_func is not None, "Must set tower_func on the trainer to use get_predictor()!"
tower_name = 'tower-pred-{}'.format(device) if device >= 0 else 'tower-pred-cpu'
......@@ -318,92 +311,48 @@ class TowerTrainer(Trainer):
@property
def _main_tower_vs_name(self):
"""
The vs name for the "main" copy of the model,
to be used to build predictors.
"""
# The vs name a predictor should be built under.
# for internal use only. Should let graphbuilder return it.
return ""
@property
def config(self):
log_deprecated(
"Trainer.config",
"It is supposed to be private! Most of its attributes can be accessed by other means.",
"2017-12-31")
return self._config
# create new trainer when not called with TrainConfig
def __new__(cls, *args, **kwargs):
if (len(args) > 0 and isinstance(args[0], TrainConfig)) \
or 'config' in kwargs:
return super(Trainer, cls).__new__(cls)
else:
import tensorpack.train as new_train
name = cls.__name__
new_trainer = getattr(new_train, name)
logger.warn("You're calling old trainers with new trainer API!")
logger.warn("Now it returns the new trainer for you, please `export TENSORPACK_TRAIN_API=v2`"
" to import new trainers automatically.")
logger.warn("You can also ignore this warning and wait for new API to become the default.")
return new_trainer(*args, **kwargs)
@six.add_metaclass(ABCMeta)
class SingleCostTrainer(TowerTrainer):
"""
Base class for single-cost trainer.
Single-cost trainer has a :meth:`setup_graph` method which takes
(inputs_desc, input, get_cost_fn, get_opt_fn), and builds the training operations from them.
To use a SingleCostTrainer object, call `trainer.setup_graph(...); trainer.train(...)`.
def _get_property(name):
"""
Delegate property to self.loop
"""
ret = property(
lambda self: getattr(self.loop, name))
if six.PY3: # __doc__ is readonly in Py2
try:
ret.__doc__ = getattr(TrainLoop, name).__doc__
except AttributeError:
pass
return ret
def train(self,
callbacks, monitors,
session_creator, session_init,
steps_per_epoch, starting_epoch, max_epoch):
"""
Same as :meth:`Trainer.train()`, except that the callbacks this
trainer needs are automatically added.
"""
callbacks = callbacks + self._internal_callbacks
super(SingleCostTrainer, self).train(
callbacks, monitors,
session_creator, session_init,
steps_per_epoch, starting_epoch, max_epoch)
@call_only_once
def setup_graph(self, inputs_desc, input, get_cost_fn, get_opt_fn):
"""
Responsible for building the main training graph for single-cost training.
Args:
inputs_desc ([InputDesc]):
input (InputSource):
get_cost_fn ([tf.Tensor] -> tf.Tensor): callable, takes some input tensors and returns a cost tensor.
Might get called multiple times for data-parallel training or inference.
get_opt_fn (-> tf.train.Optimizer): callable which returns an
optimizer. Will only be called once.
Returns:
[Callback]: a (possibly empty) list of callbacks needed for training.
These callbacks will be automatically added when you call `train()`.
So you can usually ignore the return value.
"""
get_cost_fn = TowerFuncWrapper(get_cost_fn, inputs_desc)
get_opt_fn = memoized(get_opt_fn)
self.set_tower_func(get_cost_fn)
input_callbacks = self._setup_input(inputs_desc, input)
train_callbacks = self._setup_graph(input, get_cost_fn, get_opt_fn)
self._internal_callbacks = input_callbacks + train_callbacks
return self._internal_callbacks
@abstractmethod
def _setup_graph(self, input, get_cost_fn, get_opt_fn):
pass
def _setup_input(self, inputs_desc, input):
assert not input.setup_done()
return input.setup(inputs_desc)
def _make_get_grad_fn(self, input, get_cost_fn, get_opt_fn):
"""
Returns:
a get_grad_fn for GraphBuilder to use.
"""
# internal use only
assert input.setup_done()
assert isinstance(input, FeedfreeInput), input
def get_grad_fn():
ctx = get_current_tower_context()
cost = get_cost_fn(*input.get_input_tensors())
varlist = ctx.filter_vars_by_vs_name(tf.trainable_variables())
opt = get_opt_fn()
grads = opt.compute_gradients(
cost, var_list=varlist,
gate_gradients=False, colocate_gradients_with_ops=True)
grads = FilterNoneGrad().process(grads)
return grads
return get_grad_fn
for name in ['global_step', 'local_step', 'steps_per_epoch',
'epoch_num', 'starting_epoch', 'max_epoch']:
setattr(Trainer, name, _get_property(name))
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: interface.py
__all__ = ['launch_train_with_config']
def launch_train_with_config(config, trainer):
from ..train.interface import launch_train_with_config as old_launch
old_launch(config, trainer)
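# A minimal usage sketch (mirrors the examples updated in this commit; `Model` and `dataset_train` are illustrative):
#   config = TrainConfig(model=Model(), dataflow=dataset_train,
#                        callbacks=[...], max_epoch=100)
#   launch_train_with_config(config, SimpleTrainer())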
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: utility.py
# for backwards-compatibility
from ..graph_builder.utils import ( # noqa
OverrideToLocalVariable,
override_to_local_variable, LeastLoadedDeviceSetter)