Commit f3923789 authored by Yuxin Wu

some docs

parent 63bf3d12
# Dataflow

Dataflow is a unified interface to produce data.
A Dataflow has a `get_data()` generator method, which yields a `datapoint` when called.
A datapoint must be a **list** of Python objects, each of which is called a `component` of the datapoint.
For example, to train on the MNIST dataset, you can define a Dataflow
that produces datapoints of shape `[(BATCH, 28, 28), (BATCH,)]`.
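As a minimal illustration (the arrays below are made-up numpy data, not something produced by tensorpack), such a datapoint is nothing more than a plain Python list of components:

```python
import numpy as np

images = np.zeros((128, 28, 28), dtype='float32')   # component 0, shape (BATCH, 28, 28)
labels = np.zeros((128,), dtype='int32')             # component 1, shape (BATCH,)
datapoint = [images, labels]                          # one datapoint: a list of components
```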
### Composition of DataFlow

One good thing about having a standard interface is being able to provide
the greatest code reusability.
There are a lot of existing modules in tensorpack which you can use to compose
complex Dataflow instances with a long pre-processing pipeline. A whole pipeline usually
includes __reading from disk (or other sources), augmentations, grouping into batches,
prefetching__, etc. An example is as follows:
````python
# define a Dataflow which produces image-label pairs from a caffe lmdb database
ds = CaffeLMDB('/path/to/caffe/lmdb', shuffle=False)
# resize the image component of each datapoint
ds = AugmentImageComponent(ds, [imgaug.Resize((225, 225))])
# group data into batches of size 128
ds = BatchData(ds, 128)
# start 3 processes to run the dataflow in parallel, and transfer the data with ZeroMQ
ds = PrefetchDataZMQ(ds, 3)
````
Another complicated example is the [ResNet training script](https://github.com/ppwwyyxx/tensorpack/blob/master/examples/ResNet/imagenet-resnet.py)
with all the data preprocessing.
All these modules are written in Python,
so you can easily implement whatever operations/transformations you need,
without worrying about adding data reading operators to TensorFlow.
In the meantime, thanks to the prefetching, it can still run fast enough for
tasks as large as ImageNet training.
### Reuse in other frameworks
Another good thing about Dataflow is that it is independent of
tensorpack internals. You can just use it as an efficient data processing pipeline,
and plug it into other frameworks.
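As a rough sketch (reusing the `ds` pipeline composed above; the unpacking into `images, labels` assumes the image-label datapoints produced by `CaffeLMDB`), any plain Python loop can consume a Dataflow directly:

```python
ds.reset_state()                 # initialize the Dataflow (e.g. its RNG) before use
for dp in ds.get_data():         # dp is one datapoint: a list of components
    images, labels = dp
    # feed `images` and `labels` to whatever framework or library you like
```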
### Write your own Dataflow
There are several existing Dataflows, e.g. ImageFromFile, DataFromList, which you can
use to read images or load data from a list.
But in general, you'll probably need to write a new Dataflow to produce data for your task.
Dataflow implementations for several well-known datasets are provided in the
[dataflow.dataset](http://tensorpack.readthedocs.io/en/latest/modules/tensorpack.dataflow.dataset.html)
module, which you can take as a reference.
A Dataflow has a `get_data()` method which yields a datapoint every time.
```python
class MyDataFlow(DataFlow):
    def get_data(self):
        for k in range(100):
            # yield one datapoint: a list of components, e.g. [image, label]
            yield [np.random.rand(28, 28), np.random.randint(10)]
```
Optionally, Dataflow can implement the following two methods:

+ `size()`. Return the number of elements the generator can produce. Certain modules might require this to be
  implemented. For example, only Dataflows with the same number of elements can be joined together.
+ `reset_state()`. It is necessary if your Dataflow uses RNG. This
  method should reset the internal state of this Dataflow (including RNG). It gets called after a fork, so that different child
  processes will have different random seeds.

With this "low-level" Dataflow implemented, you can then compose it with existing modules.
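For example, here is a rough sketch of the two optional methods added to the `MyDataFlow` example above and composed with `BatchData` (the random data and the batch size of 32 are made up for illustration; it assumes `DataFlow` and `BatchData` can be imported from `tensorpack.dataflow`, and is not how tensorpack datasets are actually implemented):

```python
import numpy as np
from tensorpack.dataflow import DataFlow, BatchData

class MyDataFlow(DataFlow):
    def __init__(self):
        self.rng = np.random.RandomState()

    def size(self):
        # number of datapoints get_data() will yield in one pass
        return 100

    def reset_state(self):
        # called after a fork, so each child process gets a fresh seed
        self.rng = np.random.RandomState()

    def get_data(self):
        for k in range(self.size()):
            # each datapoint: a fake image and a fake label
            yield [self.rng.rand(28, 28), self.rng.randint(10)]

ds = MyDataFlow()
ds = BatchData(ds, 32)   # compose the "low-level" Dataflow with an existing module
```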
@@ -3,21 +3,6 @@
The following guide introduces some core concepts of TensorPack. In contrast to several other libraries, TensorPack consists of several modules to build complex deep learning algorithms and train models with high accuracy and high speed.
### DataFlow
To train neural network architectures on extensive training data, this library provides a data flow mechanism. This consists of several readers, mappings (e.g. image augmentations) and efficient prefetching.
The following code reads images from a database produced in the fashion of Caffe and adds several modifications such as resizing them to $225\times 225$ and converting them to gray-scale images.
````python
ds = CaffeLMDB('/path/to/caffe/lmdb', shuffle=False)
ds = AugmentImageComponent(ds, [imgaug.Resize((225, 225))])
ds = MapData(ds, lambda dp: [np.dot(dp[0], [0.299, 0.587, 0.114])[:, :]])
ds = BatchData(ds, 128)
ds = PrefetchData(ds, 3, 2)
````
In addition, the input data is gathered in batches of 128 entries and prefetched in an extra process to avoid slow-downs due to GIL.
### Layers and Architectures

The library also contains several pre-implemented neural network modules and layers:

- Convolution, Deconvolution
@@ -51,7 +36,7 @@ You only need to configure your training protocol like
````python
config = TrainConfig(
    dataflow=my_dataflow,
    optimizer=tf.train.AdamOptimizer(lr),
    callbacks=Callbacks([ModelSaver(), ...]),
    model=Model())
````
@@ -69,4 +54,4 @@ SyncMultiGPUTrainer(config).train()
### Callbacks

The use of callbacks adds the flexibility to execute code during training. These callbacks are triggered on several events such as after each step or at the end of one training epoch.
\ No newline at end of file
@@ -30,7 +30,6 @@ param.corpus = 'input.txt'
class CharRNNData(RNGDataFlow):
    def __init__(self, input_file, size):
        self.seq_length = param.seq_len
        self._size = size
@@ -61,7 +60,6 @@ class CharRNNData(RNGDataFlow):
class Model(ModelDesc):
    def _get_input_vars(self):
        return [InputVar(tf.int32, (None, param.seq_len), 'input'),
                InputVar(tf.int32, (None, param.seq_len), 'nextinput')]
...
@@ -130,9 +130,9 @@ class Callback(object):
class Triggerable(Callback):
    """
    Base class for "triggerable" callback. It has a method :meth:`Triggerable.trigger()`
    which can be called either inside an epoch or between epochs.
    Other higher-level wrappers will take the responsibility to determine **when**
    to call the trigger.

    If a triggerable is used as a callback directly (instead of under another
    higher-level wrapper to control the trigger), it will by default trigger after
@@ -143,11 +143,7 @@ class Triggerable(Callback):
""" """
Trigger something. Trigger something.
Note that this method may be called both inside an epoch and after an epoch. Note that this method may be called both inside an epoch and after an epoch.
Some operations (e.g. writing scalar stats) currently will cause
problems if run inside an epoch. This will be fixed in the future.
""" """
# TODO
self._trigger() self._trigger()
@abstractmethod @abstractmethod
...
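For illustration only, a minimal subclass might look like the following sketch (the class name and the print statement are made up; it assumes the abstract method being decorated above is `_trigger`, as the `self._trigger()` call suggests):

```python
class PrintSomething(Triggerable):
    def _trigger(self):
        # defines *what* to do; a higher-level wrapper (or the default
        # per-epoch behaviour) decides *when* this is called
        print("triggered")
```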
@@ -26,7 +26,7 @@ class ModelSaver(Triggerable):
            keep_recent(int): see ``tf.train.Saver`` documentation.
            keep_freq(int): see ``tf.train.Saver`` documentation.
            checkpoint_dir (str): Defaults to ``logger.LOG_DIR``.
            var_collections (str or list of str): collection of the variables (or list of collections) to save.
        """
        self.keep_recent = keep_recent
        self.keep_freq = keep_freq
...