Commit a2934281 authored by Yuxin Wu

docs update

parent a71ff4d7
@@ -62,7 +62,7 @@ You can also define your own trainer for non-standard training (e.g. GAN).
Dependencies:
+ Python 2 or 3
+ TensorFlow >= 1.0.0rc1
+ Python bindings for OpenCV
```
pip install --user -U git+https://github.com/ppwwyyxx/tensorpack.git
@@ -34,8 +34,9 @@ that we can measure the speed of this DataFlow in terms of "batch per second". B
will concatenate the data into a `numpy.ndarray`, but since images are originally of different shapes, we use
`use_list=True` so that it just produces lists.
On an SSD you probably can already observe good speed here (e.g. 5 it/s, that is 1280 samples/s), but on HDD the speed may be just 1 it/s,
because we're doing heavy random read on the filesystem (regardless of whether `shuffle` is True).
Note that for smaller datasets, random read + prefetching is usually enough.
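To make the `use_list=True` behavior concrete, here is a toy re-implementation of the batching step (a sketch for illustration only; `batch_datapoints` is a hypothetical helper, not tensorpack's `BatchData`):

```python
import numpy as np

def batch_datapoints(datapoints, batch_size, use_list=False):
    # Each datapoint is a list of components, e.g. [image, label].
    for i in range(0, len(datapoints) - batch_size + 1, batch_size):
        batch = datapoints[i:i + batch_size]
        components = list(zip(*batch))  # transpose to per-component tuples
        if use_list:
            # Keep plain lists: components may have different shapes.
            yield [list(c) for c in components]
        else:
            # Stack into one ndarray per component: shapes must agree.
            yield [np.stack(c) for c in components]

# Images of different shapes only work with use_list=True:
dps = [[np.zeros((2, 2)), 0], [np.zeros((3, 3)), 1]]
img_batch, label_batch = next(batch_datapoints(dps, 2, use_list=True))
```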
We'll now add the cheapest pre-processing to get an ndarray in the end instead of a list
(because TensorFlow will need ndarray eventually):
@@ -176,7 +177,7 @@ Both imdecode and the augmentors can be quite slow. We can parallelize them like
Since we are reading the database sequentially, having multiple identical instances of the
underlying DataFlow will result in biased data distribution. Therefore we use `PrefetchData` to
launch the underlying DataFlow in one independent process, and only parallelize the transformations.
(`PrefetchDataZMQ` is faster but not fork-safe, so the first prefetch has to be `PrefetchData`. This is [issue#138](https://github.com/ppwwyyxx/tensorpack/issues/138))
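The key point, a single reader process feeding everything downstream through a queue, can be sketched in plain Python (a toy stand-in using `multiprocessing` directly; `_read` and `run_pipeline` are hypothetical names, not tensorpack APIs):

```python
import multiprocessing as mp

def _read(queue, n):
    # Exactly one copy of the underlying DataFlow reads (and shuffles),
    # so the data distribution stays unbiased.
    for i in range(n):
        queue.put(i)
    queue.put(None)  # sentinel: end of stream

def run_pipeline(n, transform):
    # One independent reader process, like PrefetchData with nr_proc=1.
    q = mp.Queue(maxsize=128)
    reader = mp.Process(target=_read, args=(q, n))
    reader.start()
    out = []
    while True:
        item = q.get()
        if item is None:
            break
        # The expensive transformation happens downstream of the queue;
        # this is the step that can then be parallelized across processes.
        out.append(transform(item))
    reader.join()
    return out
```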
Let me summarize what the above DataFlow does:
1. One process reads the LMDB file, shuffles the datapoints in a buffer and puts them into a `multiprocessing.Queue` (used by `PrefetchData`).
@@ -186,7 +187,7 @@ Let me summarize what the above DataFlow does:
how the `Trainer` is implemented.
The above DataFlow can run at a speed of 5~10 batches per second, if you have good CPUs, RAM, disks and augmentors.
As a reference, tensorpack can train ResNet-18 (a shallow ResNet) at 4.4 batches (of 256 samples) per second on 4 old TitanX.
So DataFlow won't be a serious bottleneck if configured properly.
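In per-sample terms (assuming the batch size of 256 mentioned above; the ILSVRC12 training-set size below is an illustrative assumption):

```python
batches_per_sec = 4.4
batch_size = 256
samples_per_sec = batches_per_sec * batch_size   # about 1126 samples/s

epoch_size = 1281167  # assumed: ILSVRC12 training set
secs_per_epoch = epoch_size / samples_per_sec    # roughly 19 minutes/epoch
```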
## Larger Datasets?
@@ -28,15 +28,19 @@ Double-DQN runs at 18 batches/s (1152 frames/s) on TitanX.
## How to use
Download an [atari rom](https://github.com/openai/atari-py/tree/master/atari_py/atari_roms) to
`$TENSORPACK_DATASET/atari_rom/` (defaults to ~/tensorpack_data/atari_rom/), e.g.:
```
mkdir -p ~/tensorpack_data/atari_rom
wget https://github.com/openai/atari-py/raw/master/atari_py/atari_roms/breakout.bin -O ~/tensorpack_data/atari_rom/breakout.bin
```
Start Training:
```
./DQN.py --rom breakout.bin
# use `--algo` to select other DQN algorithms. See `-h` for more options.
```
Watch the agent play:
```
./DQN.py --rom breakout.bin --task play --load trained.model
```
@@ -6,11 +6,8 @@
import tensorflow as tf
import numpy as np
import time
from tensorpack import (FeedfreeTrainerBase, QueueInput, ModelDesc, DataFlow)
from tensorpack.tfutils.summary import add_moving_summary
class GANModelDesc(ModelDesc):
@@ -19,12 +16,8 @@ class GANModelDesc(ModelDesc):
Assign self.g_vars to the parameters under scope `g_scope`,
and same with self.d_vars.
"""
self.g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, g_scope)
self.d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, d_scope)
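Both the old and the new version select trainable variables by scope; the name-prefix matching they rely on can be illustrated with plain Python stand-ins (the `Var` class and the variable names below are hypothetical):

```python
class Var:
    # Minimal stand-in for a tf.Variable: only the name matters here.
    def __init__(self, name):
        self.name = name

all_vars = [Var('gen/fc0/W:0'), Var('gen/fc0/b:0'), Var('discrim/conv0/W:0')]

def collect(scope):
    # The pre-TF1.0 fallback: filter trainable variables by name prefix.
    # tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope) performs
    # equivalent matching (scope is treated as a regex anchored at the
    # start of the variable name).
    return [v for v in all_vars if v.name.startswith(scope + '/')]

g_vars = collect('gen')
d_vars = collect('discrim')
```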
def build_losses(self, logits_real, logits_fake):
"""D and G play two-player minimax game with value function V(G,D)
@@ -54,4 +54,4 @@ Train a simple GAN on mnist, conditioned on the class labels.
## WGAN-CelebA.py
Reproduce Wasserstein GAN with some small modifications to DCGAN-CelebA.py.
@@ -70,7 +70,7 @@ def get_config():
class WGANTrainer(FeedfreeTrainerBase):
""" A new trainer which runs two optimization ops with 5:1 ratio.
This is to be consistent with the original code, but I found that
running them 1:1 (i.e. just using the existing GANTrainer) also works well.
"""
def __init__(self, config):
self._input_method = QueueInput(config.dataflow)
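The 5:1 schedule this trainer implements can be sketched abstractly (a hypothetical `alternating_updates` helper; the two callables stand in for running the D and G optimization ops):

```python
def alternating_updates(run_d_op, run_g_op, steps, d_per_g=5):
    # For every generator update, run the critic/discriminator op
    # d_per_g times -- the ratio recommended in the WGAN paper.
    log = []
    for _ in range(steps):
        for _ in range(d_per_g):
            run_d_op()
            log.append('D')
        run_g_op()
        log.append('G')
    return log

schedule = alternating_updates(lambda: None, lambda: None, steps=2)
```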