Commit ea0f1b90 authored by Yuxin Wu

split trainer to tower.py

parent 0b8727b0
# Trainer
Tensorpack follows the "define-and-run" paradigm. A training has two steps:

1. __Build graph__ for the model.
   Users can call whatever TensorFlow functions to set up the graph.
   Users may or may not use tensorpack `InputSource`, `ModelDesc` to build the graph.
   This step defines "what to run" in every training step.
   It can happen either inside or outside the trainer.
2. Train the model (the [Trainer.train() method](http://tensorpack.readthedocs.io/en/latest/modules/train.html#tensorpack.train.Trainer.train)):
......
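As a concrete illustration of the two steps, here is a sketch using the `SingleCostTrainer` interface introduced later in this commit. It is illustrative only: `df` stands for some `DataFlow` you must define, and the exact import path of the v2 `SimpleTrainer` may differ in this transitional version of the API.

```python
import tensorflow as tf
from tensorpack import InputDesc, QueueInput
from tensorpack.train.trainers import SimpleTrainer  # assumed path for the v2 trainer

def get_cost_fn(image, label):
    # Step 1 material: called under a TowerContext to build the graph.
    logits = tf.layers.dense(tf.reshape(image, [-1, 28 * 28]), 10)
    return tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(labels=label, logits=logits),
        name='cost')

def get_opt_fn():
    return tf.train.GradientDescentOptimizer(0.1)

inputs_desc = [InputDesc(tf.float32, [None, 28, 28], 'image'),
               InputDesc(tf.int32, [None], 'label')]

trainer = SimpleTrainer()
# Step 1: build the graph.
trainer.setup_graph(inputs_desc, QueueInput(df), get_cost_fn, get_opt_fn)
# Step 2: run the training loop (callback/monitor/session arguments elided).
# trainer.train(...)
```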
# Training Interface
Tensorpack trainers have an interface for maximum flexibility.
There are also interfaces built on top of trainers to simplify the use,
when you don't want to customize too much.
### With ModelDesc and TrainConfig
[SingleCost trainers](trainer.html#single-cost-trainers)
expect 4 arguments to build the graph: `InputDesc`, `InputSource`, get_cost function, and optimizer.
`ModelDesc` describes a model by packing 3 of them together into one object:
```python
class MyModel(ModelDesc):
......@@ -25,9 +25,9 @@ class MyModel(ModelDesc):
        return tf.train.GradientDescentOptimizer(0.1)
```
`_get_inputs` should define the metainfo of all the inputs your graph will take to build.
`_build_graph` takes a list of `inputs` tensors which will match `_get_inputs`.
You can use any symbolic functions in `_build_graph`, including TensorFlow core library
functions and other symbolic libraries.
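The collapsed example above can be fleshed out as follows. This is a minimal sketch of the interface described here, not code from this commit; the network itself (one dense layer) is arbitrary, and `self.cost` is the attribute single-cost trainers conventionally read the cost from.

```python
import tensorflow as tf
from tensorpack import ModelDesc, InputDesc

class MyModel(ModelDesc):
    def _get_inputs(self):
        # metainfo (dtype, shape, name) of every input the graph takes
        return [InputDesc(tf.float32, [None, 28, 28], 'image'),
                InputDesc(tf.int32, [None], 'label')]

    def _build_graph(self, inputs):
        image, label = inputs  # tensors matching _get_inputs, in order
        logits = tf.layers.dense(tf.reshape(image, [-1, 28 * 28]), 10)
        cost = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=label, logits=logits)
        self.cost = tf.reduce_mean(cost, name='total_cost')

    def _get_optimizer(self):
        return tf.train.GradientDescentOptimizer(0.1)
```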
......
......@@ -7,7 +7,6 @@ import weakref
import time
from six.moves import range
import six
from abc import abstractmethod, ABCMeta
from ..callbacks import (
Callback, Callbacks, Monitors, TrainingMonitor,
......@@ -15,26 +14,29 @@ from ..callbacks import (
ProgressBar, MergeAllSummaries,
TFEventWriter, JSONWriter, ScalarPrinter, RunUpdateOps)
from ..utils import logger
from ..utils.argtools import call_only_once
from ..tfutils.tower import TowerFuncWrapper
from ..tfutils.model_utils import describe_trainable_vars
from ..tfutils.sessinit import JustCurrentSession
from ..tfutils.sesscreate import ReuseSessionCreator, NewSessionCreator
from ..tfutils.gradproc import FilterNoneGrad
from ..callbacks.steps import MaintainStepCounter
from ..graph_builder.predictor_factory import SimplePredictBuilder
from ..input_source import PlaceholderInput
from ..predict.base import OnlinePredictor
import tensorpack.trainv1 as old_train # noqa
from ..trainv1.base import StopTraining, TrainLoop
from ..trainv1.config import TrainConfig
__all__ = ['TrainConfig', 'Trainer']
def DEFAULT_CALLBACKS():
    """
    Return the default callbacks. They are:

    1. MovingAverageSummary()
    2. ProgressBar()
    3. MergeAllSummaries()
    4. RunUpdateOps()
    """
    return [
        MovingAverageSummary(),
        ProgressBar(),
......@@ -43,6 +45,13 @@ def DEFAULT_CALLBACKS():
def DEFAULT_MONITORS():
    """
    Return the default monitors. They are:

    1. TFEventWriter()
    2. JSONWriter()
    3. ScalarPrinter()
    """
    return [TFEventWriter(), JSONWriter(), ScalarPrinter()]
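These two helpers supply the fallback callbacks and monitors when a training configuration does not specify them. A hedged illustration of spelling the same defaults out by hand (assuming the trainv1 `TrainConfig` fields `extra_callbacks` and `monitors`; `MyModel` and `df` are placeholders):

```python
from tensorpack import TrainConfig
from tensorpack.callbacks import (
    MovingAverageSummary, ProgressBar, MergeAllSummaries, RunUpdateOps,
    TFEventWriter, JSONWriter, ScalarPrinter)

config = TrainConfig(
    model=MyModel(),   # placeholder ModelDesc subclass
    dataflow=df,       # placeholder DataFlow
    # equivalent to relying on DEFAULT_CALLBACKS() / DEFAULT_MONITORS() above:
    extra_callbacks=[MovingAverageSummary(), ProgressBar(),
                     MergeAllSummaries(), RunUpdateOps()],
    monitors=[TFEventWriter(), JSONWriter(), ScalarPrinter()],
)
```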
......@@ -77,6 +86,7 @@ class Trainer(object):
        self._main_tower_vs_name = ""

        def gp(input_names, output_names, tower=0):
            # backward-compatibility shim: delegates to the TowerTrainer
            # implementation; imported locally to avoid a circular import
            from .tower import TowerTrainer
            return TowerTrainer.get_predictor(self, input_names, output_names, device=tower)
        self.get_predictor = gp
......@@ -314,151 +324,3 @@ def _get_property(name):
for name in ['global_step', 'local_step', 'steps_per_epoch',
             'epoch_num', 'starting_epoch', 'max_epoch']:
    setattr(Trainer, name, _get_property(name))
class TowerTrainer(Trainer):
    """
    Base trainer for models that can be built by calling a tower function under a :class:`TowerContext`.

    This is required by some features that replicate the model
    automatically, e.g. creating a predictor.
    """

    tower_func = None
    """
    A :class:`TowerFuncWrapper` instance.
    A callable which takes some input tensors and builds one replica of the model.
    """

    @call_only_once
    def set_tower_func(self, tower_func):
        """
        Args:
            tower_func (TowerFuncWrapper)
        """
        assert isinstance(tower_func, TowerFuncWrapper), tower_func
        self.tower_func = tower_func

    @property
    def inputs_desc(self):
        """
        Returns:
            list[InputDesc]: metainfo about the inputs to the tower.
        """
        return self.tower_func.inputs_desc

    def get_predictor(self, input_names, output_names, device=0):
        """
        Returns a callable predictor built under ``TowerContext(is_training=False)``.

        Args:
            input_names (list), output_names (list): lists of tensor names.
            device (int): build the predictor on device '/gpu:{device}' or use -1 for '/cpu:0'.

        Returns:
            an :class:`OnlinePredictor`.
        """
        assert self.tower_func is not None, "Must set tower_func on the trainer to use get_predictor()!"
        tower_name = 'tower-pred-{}'.format(device) if device >= 0 else 'tower-pred-cpu'
        try:
            tower = self.tower_func.towers[tower_name]
        except KeyError:
            input = PlaceholderInput()
            input.setup(self.inputs_desc)
            with tf.variable_scope(tf.get_variable_scope(), reuse=True):
                SimplePredictBuilder(
                    ns_name=tower_name, vs_name=self._main_tower_vs_name,
                    device=device).build(input, self.tower_func)
            tower = self.tower_func.towers[tower_name]
        input_tensors = tower.get_tensors(input_names)
        output_tensors = tower.get_tensors(output_names)
        return OnlinePredictor(input_tensors, output_tensors)

    @property
    def _main_tower_vs_name(self):
        """
        The variable scope name for the "main" copy of the model,
        to be used to build predictors.
        """
        return ""
@six.add_metaclass(ABCMeta)
class SingleCostTrainer(TowerTrainer):
    """
    Base class for single-cost trainers.

    A single-cost trainer has a :meth:`setup_graph` method which takes
    (inputs_desc, input, get_cost_fn, get_opt_fn), and builds the training operations from them.

    To use a :class:`SingleCostTrainer` object, call `trainer.setup_graph(...); trainer.train(...)`.
    """

    @call_only_once
    def setup_graph(self, inputs_desc, input, get_cost_fn, get_opt_fn):
        """
        Responsible for building the main training graph for single-cost training.

        Args:
            inputs_desc ([InputDesc]):
            input (InputSource):
            get_cost_fn ([tf.Tensor] -> tf.Tensor): callable, takes some input tensors and returns a cost tensor.
            get_opt_fn (-> tf.train.Optimizer): callable which returns an
                optimizer. Will only be called once.

        Note:
            1. `get_cost_fn` will always be called under a :class:`TowerContext`,
               which contains information about reuse,
               training/inference, scope name, etc.
            2. `get_cost_fn` might get called multiple times for data-parallel training or inference.
            3. To respect variable reuse, use `tf.get_variable` instead of
               `tf.Variable` in `get_cost_fn`.
        """
        get_cost_fn = TowerFuncWrapper(get_cost_fn, inputs_desc)
        get_opt_fn = memoized(get_opt_fn)
        self.set_tower_func(get_cost_fn)

        input_callbacks = self._setup_input(inputs_desc, input)
        train_callbacks = self._setup_graph(input, get_cost_fn, get_opt_fn)
        internal_callbacks = input_callbacks + train_callbacks
        for cb in internal_callbacks:
            self._register_callback(cb)
        # TODO register directly instead of return?

    @abstractmethod
    def _setup_graph(self, input, get_cost_fn, get_opt_fn):
        """
        Implement the logic to build the graph, with an :class:`InputSource`
        that has been set up already.

        Returns:
            [Callback]: a list of callbacks needed.
        """

    def _setup_input(self, inputs_desc, input):
        assert not input.setup_done()
        return input.setup(inputs_desc)

    def _make_get_grad_fn(self, input, get_cost_fn, get_opt_fn):
        """
        Returns:
            a get_grad_fn for GraphBuilder to use.
        """
        # internal use only
        assert input.setup_done()

        def get_grad_fn():
            ctx = get_current_tower_context()
            cost = get_cost_fn(*input.get_input_tensors())

            varlist = ctx.filter_vars_by_vs_name(tf.trainable_variables())
            opt = get_opt_fn()
            grads = opt.compute_gradients(
                cost, var_list=varlist,
                gate_gradients=False, colocate_gradients_with_ops=True)
            grads = FilterNoneGrad().process(grads)
            return grads

        return get_grad_fn
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: tower.py
import tensorflow as tf
import six
from abc import abstractmethod, ABCMeta
from ..utils.argtools import call_only_once, memoized
from ..graph_builder.predictor_factory import SimplePredictBuilder
from ..input_source import PlaceholderInput
from ..predict.base import OnlinePredictor
from ..tfutils.tower import TowerFuncWrapper, get_current_tower_context
from ..tfutils.gradproc import FilterNoneGrad
from .base import Trainer
__all__ = ['SingleCostTrainer', 'TowerTrainer']
class TowerTrainer(Trainer):
    """
    Base trainer for models that can be built by calling a tower function under a :class:`TowerContext`.

    This is required by some features that replicate the model
    automatically, e.g. creating a predictor.
    """

    tower_func = None
    """
    A :class:`TowerFuncWrapper` instance.
    A callable which takes some input tensors and builds one replica of the model.
    """

    @call_only_once
    def set_tower_func(self, tower_func):
        """
        Args:
            tower_func (TowerFuncWrapper)
        """
        assert isinstance(tower_func, TowerFuncWrapper), tower_func
        self.tower_func = tower_func

    @property
    def inputs_desc(self):
        """
        Returns:
            list[InputDesc]: metainfo about the inputs to the tower.
        """
        return self.tower_func.inputs_desc

    def get_predictor(self, input_names, output_names, device=0):
        """
        Returns a callable predictor built under ``TowerContext(is_training=False)``.

        Args:
            input_names (list), output_names (list): lists of tensor names.
            device (int): build the predictor on device '/gpu:{device}' or use -1 for '/cpu:0'.

        Returns:
            an :class:`OnlinePredictor`.
        """
        assert self.tower_func is not None, "Must set tower_func on the trainer to use get_predictor()!"
        tower_name = 'tower-pred-{}'.format(device) if device >= 0 else 'tower-pred-cpu'
        try:
            tower = self.tower_func.towers[tower_name]
        except KeyError:
            input = PlaceholderInput()
            input.setup(self.inputs_desc)
            with tf.variable_scope(tf.get_variable_scope(), reuse=True):
                SimplePredictBuilder(
                    ns_name=tower_name, vs_name=self._main_tower_vs_name,
                    device=device).build(input, self.tower_func)
            tower = self.tower_func.towers[tower_name]
        input_tensors = tower.get_tensors(input_names)
        output_tensors = tower.get_tensors(output_names)
        return OnlinePredictor(input_tensors, output_tensors)

    @property
    def _main_tower_vs_name(self):
        """
        The variable scope name for the "main" copy of the model,
        to be used to build predictors.
        """
        return ""
@six.add_metaclass(ABCMeta)
class SingleCostTrainer(TowerTrainer):
    """
    Base class for single-cost trainers.

    A single-cost trainer has a :meth:`setup_graph` method which takes
    (inputs_desc, input, get_cost_fn, get_opt_fn), and builds the training operations from them.

    To use a :class:`SingleCostTrainer` object, call `trainer.setup_graph(...); trainer.train(...)`.
    """

    @call_only_once
    def setup_graph(self, inputs_desc, input, get_cost_fn, get_opt_fn):
        """
        Responsible for building the main training graph for single-cost training.

        Args:
            inputs_desc ([InputDesc]):
            input (InputSource):
            get_cost_fn ([tf.Tensor] -> tf.Tensor): callable, takes some input tensors and returns a cost tensor.
            get_opt_fn (-> tf.train.Optimizer): callable which returns an
                optimizer. Will only be called once.

        Note:
            1. `get_cost_fn` will always be called under a :class:`TowerContext`,
               which contains information about reuse,
               training/inference, scope name, etc.
            2. `get_cost_fn` might get called multiple times for data-parallel training or inference.
            3. To respect variable reuse, use `tf.get_variable` instead of
               `tf.Variable` in `get_cost_fn`.
        """
        get_cost_fn = TowerFuncWrapper(get_cost_fn, inputs_desc)
        get_opt_fn = memoized(get_opt_fn)
        self.set_tower_func(get_cost_fn)

        input_callbacks = self._setup_input(inputs_desc, input)
        train_callbacks = self._setup_graph(input, get_cost_fn, get_opt_fn)
        internal_callbacks = input_callbacks + train_callbacks
        for cb in internal_callbacks:
            self._register_callback(cb)
        # TODO register directly instead of return?

    @abstractmethod
    def _setup_graph(self, input, get_cost_fn, get_opt_fn):
        """
        Implement the logic to build the graph, with an :class:`InputSource`
        that has been set up already.

        Returns:
            [Callback]: a list of callbacks needed.
        """

    def _setup_input(self, inputs_desc, input):
        assert not input.setup_done()
        return input.setup(inputs_desc)

    def _make_get_grad_fn(self, input, get_cost_fn, get_opt_fn):
        """
        Returns:
            a get_grad_fn for GraphBuilder to use.
        """
        # internal use only
        assert input.setup_done()

        def get_grad_fn():
            ctx = get_current_tower_context()
            cost = get_cost_fn(*input.get_input_tensors())

            varlist = ctx.filter_vars_by_vs_name(tf.trainable_variables())
            opt = get_opt_fn()
            grads = opt.compute_gradients(
                cost, var_list=varlist,
                gate_gradients=False, colocate_gradients_with_ops=True)
            grads = FilterNoneGrad().process(grads)
            return grads

        return get_grad_fn
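Point 3 of the `setup_graph` note above deserves a concrete illustration, since `get_cost_fn` may be invoked once per tower. The following sketch (hypothetical layer code, not from this commit) shows the variable-creation style that cooperates with tower reuse:

```python
import tensorflow as tf

def get_cost_fn(image, label):
    # OK: tf.get_variable honors the reuse flag set by TowerContext, so a
    # second (reused) tower shares this weight instead of creating a copy.
    w = tf.get_variable('w', shape=[28 * 28, 10],
                        initializer=tf.random_normal_initializer(stddev=0.01))
    # NOT OK under reuse: tf.Variable unconditionally creates a new variable,
    # so each data-parallel tower would train its own private copy.
    # w = tf.Variable(tf.zeros([28 * 28, 10]), name='w')
    logits = tf.matmul(tf.reshape(image, [-1, 28 * 28]), w)
    return tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(labels=label, logits=logits),
        name='cost')
```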
......@@ -21,7 +21,7 @@ from ..graph_builder.training import (
from ..graph_builder.distributed import DistributedReplicatedBuilder
from ..graph_builder.utils import override_to_local_variable
from .tower import SingleCostTrainer
__all__ = ['SimpleTrainer',
'QueueInputTrainer',
......