Commit fde338ea authored by Yuxin Wu

docs and GAN trainer upgrade

parent 353bb6c5
# Build the Graph
### ModelDesc
`ModelDesc` is an abstraction over the most common type of model people train.
It assumes:
1. Training is single-cost optimization, driven by a single `tf.train.Optimizer`.
2. The graph can be trivially duplicated for data-parallel training or inference.
If your task is single-cost optimization,
you can subclass `ModelDesc` and implement several methods:
```python
class MyModel(ModelDesc):
    def _get_inputs(self):
        return [InputDesc(...), InputDesc(...)]

    def _build_graph(self, inputs):
        tensorA, tensorB = inputs
        # build the graph

    def _get_optimizer(self):
        return tf.train.GradientDescentOptimizer(0.1)
```
`_get_inputs` should define the metainfo of all the inputs your graph may need.
`_build_graph` should add tensors/operations to the graph, where
the argument `inputs` is the list of input tensors matching `_get_inputs`.
You can use any symbolic functions in `_build_graph`, including those from the TensorFlow core
library and from other symbolic libraries.
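To make this concrete, here is a minimal sketch of a filled-in `ModelDesc`. The input shapes, the layer choice, and the use of `self.cost` as the single cost are illustrative assumptions, not something this commit prescribes:
```python
import tensorflow as tf
from tensorpack import ModelDesc, InputDesc, FullyConnected

class MyModel(ModelDesc):
    def _get_inputs(self):
        # dtype, shape and name of each input the graph expects
        return [InputDesc(tf.float32, (None, 784), 'image'),
                InputDesc(tf.int32, (None,), 'label')]

    def _build_graph(self, inputs):
        image, label = inputs
        logits = FullyConnected('fc0', image, 10, nl=tf.identity)
        # assumption: single-cost trainers read the cost from `self.cost`
        self.cost = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=label, logits=logits),
            name='cost')

    def _get_optimizer(self):
        return tf.train.GradientDescentOptimizer(0.1)
```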
Most tensorpack trainers expect a `ModelDesc`.
The trainers will call these methods to create the model,
connect `InputSource` to the model, create the minimization op, and so on.
Data-parallel multi-GPU trainers will call `_build_graph` __multiple times__, once for each GPU.
A trainer may also make __extra calls__ to `_build_graph` for inference, if it is used by some callbacks.
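Handing such a model to a trainer then looks roughly like this (a hedged sketch: `my_dataflow` is a placeholder, and `SimpleTrainer` is just one possible trainer):
```python
from tensorpack import TrainConfig, SimpleTrainer

config = TrainConfig(
    model=MyModel(),
    dataflow=my_dataflow,   # any DataFlow yielding the inputs declared above
)
SimpleTrainer(config).train()
```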
### Build It Manually
When you need to deal with a complicated graph, it may be easier to build the graph manually.
You are free to do so as long as you tell the trainer what to do in each step.
More details to come.
......@@ -37,7 +37,8 @@ User Tutorials
  dataflow
  input-source
  efficient-dataflow
-  model
+  graph
+  symbolic
  trainer
  callback
  faq
......
# Model
To define a model (i.e. the computation graph) that will be used for training,
you'll need to subclass `ModelDesc` and implement several methods:
```python
class MyModel(ModelDesc):
    def _get_inputs(self):
        return [InputDesc(...), InputDesc(...)]

    def _build_graph(self, inputs):
        tensorA, tensorB = inputs
        # build the graph

    def _get_optimizer(self):
        return tf.train.GradientDescentOptimizer(0.1)
```
`_get_inputs` should define the metainfo of all the inputs your graph may need.
`_build_graph` should add tensors/operations to the graph, where
the argument `inputs` is the list of input tensors matching `_get_inputs`.
You can use any symbolic functions in `_build_graph`, including TensorFlow core library
functions and other symbolic libraries.
# Symbolic Layers
While you can use other symbolic libraries,
tensorpack also contains a small collection of common model primitives,
such as conv/deconv, fc, batch normalization, pooling layers, and some custom loss functions.
Using the tensorpack implementations, you can also benefit from `argscope` and `LinearWrap` to
simplify the code.
Note that these layers were written because there were no other alternatives back at that time.
In the future we may shift to `tf.layers`, as they will be better maintained.
### argscope and LinearWrap
`argscope` gives you a context with default arguments.
`LinearWrap` allows you to simplify "linear structure" models by chaining the layer calls.
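A minimal sketch of the two together (the layer arguments are illustrative; this is meant to run inside `_build_graph`, where a `TowerContext` is active):
```python
import tensorflow as tf
from tensorpack import argscope, LinearWrap, Conv2D, MaxPooling, FullyConnected

with argscope(Conv2D, kernel_shape=3, nl=tf.nn.relu, out_channel=32):
    # LinearWrap replaces repeated `l = Layer('name', l, ...)` assignments
    # with method chaining; the trailing () unwraps the final tensor
    logits = (LinearWrap(image)
              .Conv2D('conv0')
              .MaxPooling('pool0', 2)
              .Conv2D('conv1')
              .FullyConnected('fc1', 10, nl=tf.identity)())
```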
......@@ -63,9 +44,7 @@ l = FullyConnected('fc1', l, 10, nl=tf.identity)
### Use Models outside Tensorpack
-You can use tensorpack models alone as a simple symbolic function library, and write your own
-training code instead of using tensorpack trainers.
+You can use tensorpack models alone as a simple symbolic function library.
To do this, just enter a [TowerContext](http://tensorpack.readthedocs.io/en/latest/modules/tfutils.html#tensorpack.tfutils.TowerContext)
when you define your model:
......@@ -85,8 +64,10 @@ with tf.variable_scope(tf.get_variable_scope(), reuse=True), TowerContext('predi
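The code block here is collapsed in the diff view. A minimal sketch of the idea (the placeholder and layer are illustrative):
```python
import tensorflow as tf
from tensorpack import FullyConnected, TowerContext

image = tf.placeholder(tf.float32, [None, 784], 'image')
with TowerContext('', is_training=False):
    # inside a TowerContext, tensorpack layers work as plain symbolic functions
    logits = FullyConnected('fc0', image, 10, nl=tf.identity)
```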
When defining the model you can construct the graph using whatever library you feel comfortable with.
Usually, slim/tflearn/tensorlayer are just symbolic functions, calling them is nothing different
-from calling `tf.add`. However it is a bit different to use sonnet/Keras.
+from calling `tf.add`. You may need to be careful how regularizations/BN updates are supposed
+to be handled in those libraries, though.
+
+It is a bit different to use sonnet/Keras.
sonnet/Keras manages the variable scope by their own model classes, and calling their symbolic functions
-always creates new variable scope. See the [Keras example](../examples/mnist-keras.py) for how to
-use them within tensorpack.
+always creates new variable scope. See the [Keras example](../examples/mnist-keras.py) for how to use it within tensorpack.
The support is only preliminary for now.
......@@ -6,14 +6,15 @@
import tensorflow as tf
import numpy as np
import time
-from tensorpack import (FeedfreeTrainerBase, QueueInput,
-                        ModelDesc, DataFlow, StagingInputWrapper,
+from tensorpack import (Trainer, QueueInput,
+                        ModelDescBase, DataFlow, StagingInputWrapper,
                         MultiGPUTrainerBase, LeastLoadedDeviceSetter,
                         TowerContext)
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.utils.argtools import memoized


-class GANModelDesc(ModelDesc):
+class GANModelDesc(ModelDescBase):
    def collect_variables(self, g_scope='gen', d_scope='discrim'):
        """
        Assign self.g_vars to the parameters under scope `g_scope`,
......@@ -58,14 +59,18 @@ class GANModelDesc(ModelDesc):
        add_moving_summary(self.g_loss, self.d_loss, d_accuracy, g_accuracy)

+    @memoized
+    def get_optimizer(self):
+        return self._get_optimizer()


-class GANTrainer(FeedfreeTrainerBase):
+class GANTrainer(Trainer):
    def __init__(self, config):
        self._input_source = QueueInput(config.dataflow)
        super(GANTrainer, self).__init__(config)

    def _setup(self):
-        super(GANTrainer, self)._setup()
+        self._setup_input_source(self._input_source)
        with TowerContext('', is_training=True):
            self.model.build_graph(self._input_source)
        opt = self.model.get_optimizer()
......@@ -77,7 +82,7 @@ class GANTrainer(FeedfreeTrainerBase):
        self.train_op = d_min
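The lines producing `d_min` are collapsed in the diff above. A sketch of the pattern they implement, inferred from the surrounding code rather than copied from the commit:
```python
# sketch: run the generator update, then the discriminator update, as one op;
# g_vars/d_vars come from GANModelDesc.collect_variables above
opt = self.model.get_optimizer()
with tf.name_scope('optimize'):
    g_min = opt.minimize(self.model.g_loss, var_list=self.model.g_vars, name='g_op')
    with tf.control_dependencies([g_min]):
        d_min = opt.minimize(self.model.d_loss, var_list=self.model.d_vars, name='d_op')
# `self.train_op = d_min` (the line above) then runs both updates per step
```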
-class SeparateGANTrainer(FeedfreeTrainerBase):
+class SeparateGANTrainer(Trainer):
    """ A GAN trainer which runs two optimization ops with a certain ratio, one in each step. """
    def __init__(self, config, d_period=1, g_period=1):
        """
......@@ -92,7 +97,7 @@ class SeparateGANTrainer(FeedfreeTrainerBase):
        super(SeparateGANTrainer, self).__init__(config)

    def _setup(self):
-        super(SeparateGANTrainer, self)._setup()
+        self._setup_input_source(self._input_source)
        with TowerContext('', is_training=True):
            self.model.build_graph(self._input_source)
......@@ -111,19 +116,19 @@ class SeparateGANTrainer(FeedfreeTrainerBase):
        self._cnt += 1
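Most of the `run_step` body is collapsed above. A sketch of the period-based alternation the class docstring describes (the attribute and op names are assumptions based on the `__init__` arguments):
```python
def run_step(self):
    # run the discriminator op every `d_period` steps,
    # and the generator op every `g_period` steps
    if self._cnt % self._d_period == 0:
        self.hooked_sess.run(self.d_min)
    if self._cnt % self._g_period == 0:
        self.hooked_sess.run(self.g_min)
    self._cnt += 1
```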
-class MultiGPUGANTrainer(MultiGPUTrainerBase, FeedfreeTrainerBase):
+class MultiGPUGANTrainer(Trainer):
    """
    A replacement of GANTrainer (optimize d and g one by one) with multi-gpu support.
    """
    def __init__(self, config):
-        super(MultiGPUGANTrainer, self).__init__(config)
        self._nr_gpu = config.nr_tower
        assert self._nr_gpu > 1
-        self._raw_devices = ['/gpu:{}'.format(k) for k in self.config.tower]
+        self._raw_devices = ['/gpu:{}'.format(k) for k in config.tower]
        self._input_source = StagingInputWrapper(QueueInput(config.dataflow), self._raw_devices)
+        super(MultiGPUGANTrainer, self).__init__(config)

    def _setup(self):
-        super(MultiGPUGANTrainer, self)._setup()
+        self._setup_input_source(self._input_source)
        devices = [LeastLoadedDeviceSetter(d, self._raw_devices) for d in self._raw_devices]

        def get_cost():
......
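The body of `get_cost` and the rest of `_setup` are collapsed. A sketch of the usual data-parallel pattern at this point; the tower-building helper and its exact name are assumptions:
```python
def get_cost():
    # build one tower of the model and return its two losses
    self.model.build_graph(self._input_source)
    return [self.model.d_loss, self.model.g_loss]

# replicate the tower on every GPU (helper name assumed),
# giving one (d_loss, g_loss) pair per tower to average
cost_list = MultiGPUTrainerBase.build_on_multi_tower(
    self.config.tower, get_cost, devices)
```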
......@@ -14,8 +14,7 @@ from ..tfutils.gradproc import FilterNoneGrad
from .input_source_base import InputSource
from ..models.regularize import regularize_cost_from_collection
-__all__ = ['InputDesc', 'ModelDesc']
-# don't expose ModelDescBase for use right now. API wasn't final.
+__all__ = ['InputDesc', 'ModelDesc', 'ModelDescBase']
class InputDesc(
......
......@@ -7,7 +7,7 @@ from ..callbacks import (
ProgressBar, MergeAllSummaries,
TFEventWriter, JSONWriter, ScalarPrinter, RunUpdateOps)
from ..dataflow.base import DataFlow
-from ..graph_builder.model_desc import ModelDesc
+from ..graph_builder.model_desc import ModelDescBase
from ..utils import logger
from ..utils.develop import log_deprecated
from ..tfutils import (JustCurrentSession,
......@@ -39,7 +39,7 @@ class TrainConfig(object):
    Args:
        dataflow (DataFlow):
        data (InputSource):
-        model (ModelDesc):
+        model (ModelDescBase):
        callbacks (list): a list of :class:`Callback` to perform during training.
        extra_callbacks (list): the same as ``callbacks``. This argument
......@@ -82,7 +82,7 @@ class TrainConfig(object):
            assert_type(self.data, InputSource)
            self.dataflow = None

        if model is not None:
-            assert_type(model, ModelDesc)
+            assert_type(model, ModelDescBase)
        self.model = model

        if callbacks is None:
......
......@@ -15,6 +15,7 @@ __all__ = ['FeedfreeTrainerBase', 'SingleCostFeedfreeTrainer',
           'SimpleFeedfreeTrainer', 'QueueInputTrainer']


+# TODO deprecate it some time
class FeedfreeTrainerBase(Trainer):
    """ A base trainer which runs iteration without feed_dict (therefore faster)

        Expect ``config.data`` to be a :class:`FeedfreeInput`.
......