Commit e49d4fd4 authored by Yuxin Wu's avatar Yuxin Wu

Add notes about using sonnet (fix #222)

parent 2ba9c3cd
@@ -33,8 +33,9 @@ The examples are not only for demonstration of the framework -- you can train th
 It's Yet Another TF wrapper, but different in:
 1. Not focus on models.
-   It includes only a few common models, and helpful tools such as `LinearWrap` to simplify large models.
-   But you can use any other TF wrappers here, such as slim/tflearn/tensorlayer.
+   There are already too many symbolic function wrappers.
+   Tensorpack includes only a few common models, and helpful tools such as `LinearWrap` to simplify large models.
+   But you can use any other wrappers within tensorpack, such as sonnet/Keras/slim/tflearn/tensorlayer/....
 2. Focus on large datasets.
    __DataFlow__ allows you to process large datasets such as ImageNet in Python without blocking the training.
...
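The idea behind DataFlow — preparing data in Python in parallel so the training loop never blocks — can be sketched with a plain producer thread and a bounded queue. This is only an illustration with stdlib primitives; `data_producer` and `train` are hypothetical names, not tensorpack's API:

```python
import queue
import threading

def data_producer(out_q, n_batches):
    """Toy stand-in for a DataFlow: prepares batches in a background thread."""
    for i in range(n_batches):
        batch = [x * x for x in range(i, i + 4)]  # pretend-expensive preprocessing
        out_q.put(batch)
    out_q.put(None)  # sentinel: no more data

def train(n_batches=3):
    q = queue.Queue(maxsize=2)  # bounded buffer: producer runs ahead of training
    t = threading.Thread(target=data_producer, args=(q, n_batches))
    t.start()
    batches = []
    while True:
        batch = q.get()  # the "training" loop only waits if the buffer is empty
        if batch is None:
            break
        batches.append(batch)
    t.join()
    return batches

print(train())  # → [[0, 1, 4, 9], [1, 4, 9, 16], [4, 9, 16, 25]]
```

Tensorpack's real DataFlow additionally handles multiprocessing and feeding into TF queues, but the overlap of preprocessing and consumption is the same principle.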
@@ -21,7 +21,7 @@ Basically, `_get_inputs` should define the metainfo of all the possible placehol
 the argument `input_tensors` is the list of input tensors matching `_get_inputs`.
 You can use any symbolic functions in `_build_graph`, including TensorFlow core library
-functions, TensorFlow slim layers, or functions in other packages such as tflean, tensorlayer.
+functions and other symbolic libraries (see below).
 tensorpack also contains a small collection of common model primitives,
 such as conv/deconv, fc, batch normalization, pooling layers, and some custom loss functions.
@@ -62,12 +62,23 @@ l = FullyConnected('fc1', l, 10, nl=tf.identity)
 ### Use Models outside Tensorpack
-You can use the tensorpack models alone as a simple symbolic function library, and write your own
+You can use tensorpack models alone as a simple symbolic function library, and write your own
 training code instead of using tensorpack trainers.
 To do this, just enter a [TowerContext](http://tensorpack.readthedocs.io/en/latest/modules/tfutils.html#tensorpack.tfutils.TowerContext)
 when you define your model:
 ```python
 with TowerContext('', is_training=True):
-    # call any tensorpack symbolic functions
+    # call any tensorpack layer
 ```
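The mechanism behind TowerContext — layers reading training/inference mode from an ambient context rather than from an argument — can be illustrated with a plain Python context manager. This is an analogy only; `tower_context` and `dropout` below are toy names, not tensorpack's implementation:

```python
import contextlib

_current_ctx = {"is_training": None}  # module-level state, like a current-tower registry

@contextlib.contextmanager
def tower_context(is_training):
    """Toy analogue of TowerContext: records the training mode for layers to read."""
    prev = _current_ctx["is_training"]
    _current_ctx["is_training"] = is_training
    try:
        yield
    finally:
        _current_ctx["is_training"] = prev  # restore on exit

def dropout(x, rate=0.5):
    """Toy layer: behaves differently depending on the ambient context."""
    if _current_ctx["is_training"] is None:
        raise RuntimeError("layers must be called inside a tower context")
    return x * (1 - rate) if _current_ctx["is_training"] else x

with tower_context(is_training=True):
    print(dropout(10.0))   # → 5.0 (training behavior)
with tower_context(is_training=False):
    print(dropout(10.0))   # → 10.0 (inference behavior)
```

This is why tensorpack layers raise an error when called without a TowerContext: they have no other way to know which mode they are in.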
+### Use Other Symbolic Libraries within Tensorpack
+When defining the model you can construct the graph using whatever library you feel comfortable with.
+Usually, slim/tflearn/tensorlayer are just symbolic functions, so calling them is no different
+from calling `tf.add`. However, using sonnet/Keras is a bit different:
+sonnet/Keras manage variables with their own model classes, and calling their symbolic functions
+always creates a new variable scope. See the [Keras example](../examples/mnist-keras.py) for how to
+use them within tensorpack.
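Why this requires memoizing the model builder can be shown in plain Python: building a Keras/sonnet model twice creates two independent sets of variables, while a memoized builder returns the same model to both the training and inference graphs, so weights are shared. The classes and names below are toys (using `functools.lru_cache` as a stand-in for tensorpack's `@memoized`), not the actual objects:

```python
import functools

class ToyModel:
    """Stand-in for a Keras model: constructing one allocates fresh 'variables'."""
    n_created = 0
    def __init__(self):
        ToyModel.n_created += 1
        self.ident = ToyModel.n_created

def build_model_unmemoized():
    # Each call builds a brand-new model -> a new variable scope each time.
    return ToyModel()

@functools.lru_cache(maxsize=None)
def build_model_memoized():
    # The first call builds the model; later calls return the same object.
    return ToyModel()

print(build_model_unmemoized() is build_model_unmemoized())  # → False
print(build_model_memoized() is build_model_memoized())      # → True
```

Without the memoization, the training tower and the inference tower would each call the builder and silently train one copy of the weights while evaluating another.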
@@ -32,7 +32,7 @@ class Model(ModelDesc):
             InputDesc(tf.int32, (None,), 'label'),
         ]

-    @memoized  # this is necessary for Keras to work under tensorpack
+    @memoized  # this is necessary for sonnet/Keras to work under tensorpack
     def _build_keras_model(self):
         M = Sequential()
         M.add(KL.Conv2D(32, 3, activation='relu', input_shape=[IMAGE_SIZE, IMAGE_SIZE, 1], padding='same'))
@@ -83,6 +83,7 @@ class Model(ModelDesc):
         return tf.train.AdamOptimizer(lr)

+# Keras needs an extra input
 class KerasCallback(Callback):
     def __init__(self, isTrain):
         self._isTrain = isTrain
...