Commit fde338ea authored by Yuxin Wu

docs and GAN trainer upgrade

parent 353bb6c5
# Build the Graph
### ModelDesc
`ModelDesc` is an abstraction over the most common type of model people train.
It assumes:
1. Training is single-cost optimization, driven by a single `tf.train.Optimizer`.
2. The graph can be trivially duplicated for data-parallel training or inference.
If your task is single-cost optimization,
you can subclass `ModelDesc` and implement several methods:
```python
class MyModel(ModelDesc):
    def _get_inputs(self):
        return [InputDesc(...), InputDesc(...)]

    def _build_graph(self, inputs):
        tensorA, tensorB = inputs
        # build the graph

    def _get_optimizer(self):
        return tf.train.GradientDescentOptimizer(0.1)
```
`_get_inputs` should define the metainfo of all the inputs your graph may need.
`_build_graph` should add tensors/operations to the graph, where
the argument `inputs` is the list of input tensors matching `_get_inputs`.
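For example, an MNIST-style classifier could declare an image and a label input. This is a sketch; the shapes and names are illustrative, in the style of tensorpack's MNIST example:
```python
def _get_inputs(self):
    # InputDesc(dtype, shape, name); the None dimension is the batch size
    return [InputDesc(tf.float32, (None, 28, 28), 'input'),
            InputDesc(tf.int32, (None,), 'label')]
```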
You can use any symbolic functions in `_build_graph`, including functions from the
TensorFlow core library and from other symbolic libraries.
Most tensorpack trainers expect a `ModelDesc`.
The trainers will call these methods to create the model,
connect an `InputSource` to the model, create the minimization op, and so on.
Data-parallel multi-GPU trainers will call `_build_graph` __multiple times__, once on each GPU.
A trainer may also make __extra calls__ to `_build_graph` for inference, if it is used by some callbacks.
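As a sketch of how a `ModelDesc` is typically consumed (assuming the `TrainConfig`/`QueueInputTrainer` API of this codebase; `my_dataflow` and the callback list are illustrative):
```python
from tensorpack import TrainConfig, QueueInputTrainer, ModelSaver

config = TrainConfig(
    model=MyModel(),
    dataflow=my_dataflow,      # any DataFlow producing the declared inputs
    callbacks=[ModelSaver()],  # optional callbacks
    max_epoch=100,
)
# the trainer builds the graph from the model and runs the training loop
QueueInputTrainer(config).train()
```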
### Build It Manually
When you need to deal with a complicated graph, it may be easier to build it manually.
You are free to do so, as long as you tell the trainer what to do in each step.
More details to come.
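As a rough sketch of the pattern (modeled on the `GANTrainer` changes in this commit, not official tutorial code; it assumes a single-cost model so that `get_cost()` is meaningful):
```python
from tensorpack import Trainer, QueueInput, TowerContext

class MyTrainer(Trainer):
    def __init__(self, config):
        self._input_source = QueueInput(config.dataflow)
        super(MyTrainer, self).__init__(config)

    def _setup(self):
        # connect the input pipeline, build the graph yourself,
        # and decide what a single training step runs
        self._setup_input_source(self._input_source)
        with TowerContext('', is_training=True):
            self.model.build_graph(self._input_source)
        opt = self.model.get_optimizer()
        self.train_op = opt.minimize(self.model.get_cost(), name='train_op')
```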
@@ -37,7 +37,8 @@ User Tutorials
   dataflow
   input-source
   efficient-dataflow
-  model
+  graph
+  symbolic
   trainer
   callback
   faq
...
-# Model
+# Symbolic Layers

-To define a model (i.e. the computation graph) that will be used for training,
-you'll need to subclass `ModelDesc` and implement several methods:
-
-```python
-class MyModel(ModelDesc):
-    def _get_inputs(self):
-        return [InputDesc(...), InputDesc(...)]
-
-    def _build_graph(self, inputs):
-        tensorA, tensorB = inputs
-        # build the graph
-
-    def _get_optimizer(self):
-        return tf.train.GradientDescentOptimizer(0.1)
-```
-
-`_get_inputs` should define the metainfo of all the inputs your graph may need.
-`_build_graph` should add tensors/operations to the graph, where
-the argument `inputs` is the list of input tensors matching `_get_inputs`.
-
-You can use any symbolic functions in `_build_graph`, including TensorFlow core library
-functions and other symbolic libraries.
+While you can use other symbolic libraries,
 tensorpack also contains a small collection of common model primitives,
 such as conv/deconv, fc, batch normalization, pooling layers, and some custom loss functions.
 Using the tensorpack implementations, you can also benefit from `argscope` and `LinearWrap` to
 simplify the code.
-
-Note that the layers were written because there were no other alternatives back at that time.
-In the future we may shift to `tf.layers` because they will be better maintained.

 ### argscope and LinearWrap
 `argscope` gives you a context with default arguments.
 `LinearWrap` allows you to simplify "linear structure" models by
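The hunk is cut off here; as a rough sketch of the `argscope`/`LinearWrap` pattern it describes (layer names and arguments are assumptions based on tensorpack's examples, not the elided lines):
```python
# defaults apply to every Conv2D inside the context
with argscope(Conv2D, kernel_shape=3, nl=tf.nn.relu, out_channel=32):
    l = (LinearWrap(image)              # `image` is an input tensor
         .Conv2D('conv0')
         .MaxPooling('pool0', 2)
         .Conv2D('conv1', out_channel=64)
         .FullyConnected('fc0', 512, nl=tf.nn.relu)
         .FullyConnected('fc1', 10, nl=tf.identity)())  # trailing () extracts the tensor
```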
@@ -63,9 +44,7 @@ l = FullyConnected('fc1', l, 10, nl=tf.identity)

 ### Use Models outside Tensorpack

-You can use tensorpack models alone as a simple symbolic function library, and write your own
-training code instead of using tensorpack trainers.
+You can use tensorpack models alone as a simple symbolic function library.
 To do this, just enter a [TowerContext](http://tensorpack.readthedocs.io/en/latest/modules/tfutils.html#tensorpack.tfutils.TowerContext)
 when you define your model:
 ```python
@@ -85,8 +64,10 @@ with tf.variable_scope(tf.get_variable_scope(), reuse=True), TowerContext('predi
 When defining the model you can construct the graph using whatever library you feel comfortable with.
 Usually, slim/tflearn/tensorlayer are just symbolic functions, calling them is nothing different
-from calling `tf.add`. However it is a bit different to use sonnet/Keras.
+from calling `tf.add`. You may need to be careful how regularizations/BN updates are supposed
+to be handled in those libraries, though.
+
+It is a bit different to use sonnet/Keras.
 sonnet/Keras manages the variable scope by their own model classes, and calling their symbolic functions
-always creates new variable scope. See the [Keras example](../examples/mnist-keras.py) for how to
-use them within tensorpack.
+always creates new variable scope. See the [Keras example](../examples/mnist-keras.py) for how to use it within tensorpack.
+The support is only preliminary for now.
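Concretely, a sketch of the pattern (`build_my_model` is a hypothetical function built from tensorpack layers; the scope/context idiom matches the hunk header above):
```python
import tensorflow as tf
from tensorpack import TowerContext

# build the training graph under a training TowerContext
with TowerContext('', is_training=True):
    image = tf.placeholder(tf.float32, [None, 28, 28], 'image')
    logits = build_my_model(image)

# reuse the same variables for an inference graph
with tf.variable_scope(tf.get_variable_scope(), reuse=True), \
        TowerContext('predict', is_training=False):
    pred = build_my_model(image)
```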
@@ -6,14 +6,15 @@
 import tensorflow as tf
 import numpy as np
 import time
-from tensorpack import (FeedfreeTrainerBase, QueueInput,
-                        ModelDesc, DataFlow, StagingInputWrapper,
+from tensorpack import (Trainer, QueueInput,
+                        ModelDescBase, DataFlow, StagingInputWrapper,
                         MultiGPUTrainerBase, LeastLoadedDeviceSetter,
                         TowerContext)
 from tensorpack.tfutils.summary import add_moving_summary
+from tensorpack.utils.argtools import memoized


-class GANModelDesc(ModelDesc):
+class GANModelDesc(ModelDescBase):
     def collect_variables(self, g_scope='gen', d_scope='discrim'):
         """
         Assign self.g_vars to the parameters under scope `g_scope`,
@@ -58,14 +59,18 @@ class GANModelDesc(ModelDesc):
         add_moving_summary(self.g_loss, self.d_loss, d_accuracy, g_accuracy)

+    @memoized
+    def get_optimizer(self):
+        return self._get_optimizer()
+

-class GANTrainer(FeedfreeTrainerBase):
+class GANTrainer(Trainer):
     def __init__(self, config):
         self._input_source = QueueInput(config.dataflow)
         super(GANTrainer, self).__init__(config)

     def _setup(self):
-        super(GANTrainer, self)._setup()
+        self._setup_input_source(self._input_source)
         with TowerContext('', is_training=True):
             self.model.build_graph(self._input_source)
         opt = self.model.get_optimizer()
@@ -77,7 +82,7 @@ class GANTrainer(FeedfreeTrainerBase):
         self.train_op = d_min


-class SeparateGANTrainer(FeedfreeTrainerBase):
+class SeparateGANTrainer(Trainer):
     """ A GAN trainer which runs two optimization ops with a certain ratio, one in each step. """
     def __init__(self, config, d_period=1, g_period=1):
         """
@@ -92,7 +97,7 @@ class SeparateGANTrainer(FeedfreeTrainerBase):
         super(SeparateGANTrainer, self).__init__(config)

     def _setup(self):
-        super(SeparateGANTrainer, self)._setup()
+        self._setup_input_source(self._input_source)
         with TowerContext('', is_training=True):
             self.model.build_graph(self._input_source)
@@ -111,19 +116,19 @@ class SeparateGANTrainer(FeedfreeTrainerBase):
         self._cnt += 1


-class MultiGPUGANTrainer(MultiGPUTrainerBase, FeedfreeTrainerBase):
+class MultiGPUGANTrainer(Trainer):
     """
     A replacement of GANTrainer (optimize d and g one by one) with multi-gpu support.
     """
     def __init__(self, config):
-        super(MultiGPUGANTrainer, self).__init__(config)
         self._nr_gpu = config.nr_tower
         assert self._nr_gpu > 1
-        self._raw_devices = ['/gpu:{}'.format(k) for k in self.config.tower]
+        self._raw_devices = ['/gpu:{}'.format(k) for k in config.tower]
         self._input_source = StagingInputWrapper(QueueInput(config.dataflow), self._raw_devices)
+        super(MultiGPUGANTrainer, self).__init__(config)

     def _setup(self):
-        super(MultiGPUGANTrainer, self)._setup()
+        self._setup_input_source(self._input_source)
         devices = [LeastLoadedDeviceSetter(d, self._raw_devices) for d in self._raw_devices]

         def get_cost():
...
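For example, a usage sketch of the ratio trainer above (assuming `config` is a `TrainConfig` whose model is a `GANModelDesc`):
```python
# d_period/g_period control how often each optimization op runs: here the
# discriminator updates every step, the generator only once every 3rd step.
SeparateGANTrainer(config, d_period=1, g_period=3).train()
```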
@@ -14,8 +14,7 @@ from ..tfutils.gradproc import FilterNoneGrad
 from .input_source_base import InputSource
 from ..models.regularize import regularize_cost_from_collection

-__all__ = ['InputDesc', 'ModelDesc']
-# don't expose ModelDescBase for use right now. API wasn't final.
+__all__ = ['InputDesc', 'ModelDesc', 'ModelDescBase']


 class InputDesc(
...
@@ -7,7 +7,7 @@ from ..callbacks import (
     ProgressBar, MergeAllSummaries,
     TFEventWriter, JSONWriter, ScalarPrinter, RunUpdateOps)
 from ..dataflow.base import DataFlow
-from ..graph_builder.model_desc import ModelDesc
+from ..graph_builder.model_desc import ModelDescBase
 from ..utils import logger
 from ..utils.develop import log_deprecated
 from ..tfutils import (JustCurrentSession,
@@ -39,7 +39,7 @@ class TrainConfig(object):
     Args:
         dataflow (DataFlow):
         data (InputSource):
-        model (ModelDesc):
+        model (ModelDescBase):
         callbacks (list): a list of :class:`Callback` to perform during training.
         extra_callbacks (list): the same as ``callbacks``. This argument
@@ -82,7 +82,7 @@ class TrainConfig(object):
         assert_type(self.data, InputSource)
         self.dataflow = None
         if model is not None:
-            assert_type(model, ModelDesc)
+            assert_type(model, ModelDescBase)
         self.model = model

         if callbacks is None:
...
@@ -15,6 +15,7 @@ __all__ = ['FeedfreeTrainerBase', 'SingleCostFeedfreeTrainer',
            'SimpleFeedfreeTrainer', 'QueueInputTrainer']


+# TODO deprecate it some time
 class FeedfreeTrainerBase(Trainer):
     """ A base trainer which runs iteration without feed_dict (therefore faster)
     Expect ``config.data`` to be a :class:`FeedfreeInput`.
...