Commit c86cd15a authored by Yuxin Wu

update docs

parent 190ac8cf
@@ -48,11 +48,17 @@ script:
   - cd $TRAVIS_BUILD_DIR && python tests/test_examples.py
 notifications:
-  email:
+  - email:
     recipients:
       - ppwwyyxxc@gmail.com
     on_success: never
     on_failure: change
+  - webhooks:
+      urls:
+        - https://webhooks.gitter.im/e/cede9dbbf6630b3704b3
+      on_success: change  # options: [always|never|change] default: always
+      on_failure: always  # options: [always|never|change] default: always
+      on_start: never  # options: [always|never|change] default: always
 deploy:
   - provider: pypi
...
@@ -24,7 +24,7 @@ TrainConfig(
     callbacks=[
         # save the model every epoch
         ModelSaver(),
-        # run inference on another Dataflow every epoch, compute top1/top5 classification error and save them
+        # run inference on another Dataflow every epoch, compute top1/top5 classification error and save them in log
        InferenceRunner(dataset_val, [
            ClassificationError('wrong-top1', 'val-error-top1'),
            ClassificationError('wrong-top5', 'val-error-top5')]),
@@ -39,12 +39,12 @@ TrainConfig(
                -d body={val-error-top1} > /dev/null 2>&1',
            'val-error-top1')
    ],
-    extra_callbacks=[  # these are already enabled by default
+    extra_callbacks=[  # these callbacks are already enabled by default
        # maintain and summarize moving average of some tensors (e.g. training loss, training error)
        MovingAverageSummary(),
        # draw a nice progress bar
        ProgressBar(),
-       # print all the statistics I've created and scalar tensors I've summarized
+       # print all the statistics I've created, and scalar tensors I've summarized
        StatPrinter(),
    ]
)
...
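The `extra_callbacks` shown above ship with tensorpack; custom behavior is added by subclassing `Callback`. Below is a minimal sketch, assuming this version's `Callback` base class provides a per-epoch hook named `_trigger_epoch` and an `epoch_num` attribute (both names are assumptions, not guaranteed by this commit):

```python
from tensorpack.callbacks import Callback

class PrintEpoch(Callback):
    """A toy callback that runs once at the end of every epoch."""
    def _trigger_epoch(self):
        # epoch_num is assumed to be maintained by the Callback base class
        print("finished epoch {}".format(self.epoch_num))
```

Such a callback would then be appended to the `callbacks` list of `TrainConfig`, next to the ones above.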
@@ -7,8 +7,9 @@
 A Dataflow has a `get_data()` generator method,
 which yields a `datapoint` when called.
 A datapoint must be a **list** of Python objects which I called the `components` of this datapoint.
-For example, to train on MNIST dataset, you can define a Dataflow
-that produces datapoints of two elements: a numpy array of shape (64, 28, 28), and an array of shape (64,).
+For example, to train on MNIST dataset, you can build a Dataflow
+that produces datapoints of two elements (components):
+a numpy array of shape (64, 28, 28), and an array of shape (64,).
 
 ### Composition of DataFlow
 One good thing about having a standard interface is to be able to provide
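To make the "list of components" contract concrete, one such MNIST datapoint could look as follows (the shapes are taken from the text above; the dtypes are an assumption):

```python
import numpy as np

# One datapoint = a list of components:
images = np.zeros((64, 28, 28), dtype='float32')  # a batch of 64 images
labels = np.zeros((64,), dtype='int32')           # the 64 corresponding labels
datapoint = [images, labels]
```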
@@ -25,10 +26,10 @@
 df = CaffeLMDB('/path/to/caffe/lmdb', shuffle=False)
 df = AugmentImageComponent(df, [imgaug.Resize((225, 225))])
 # group data into batches of size 128
 df = BatchData(df, 128)
-# start 3 processes to run the dataflow in parallel, and transfer the data with ZeroMQ
+# start 3 processes to run the dataflow in parallel, and transfer data with ZeroMQ
 df = PrefetchDataZMQ(df, 3)
 ````
-Another complicated example is the [ResNet training script](../examples/ResNet/imagenet-resnet.py)
+A more complicated example is the [ResNet training script](../examples/ResNet/imagenet-resnet.py)
 with all the data preprocessing.
 
 All these modules are written in Python,
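A quick way to check that such a composed pipeline is fast enough is tensorpack's `TestDataSpeed` (the same class patched later in this commit). A sketch, assuming the `start_test()` method and positional size argument of this era:

```python
from tensorpack.dataflow import TestDataSpeed

# Pull 1000 datapoints from the composed `df` above and time them.
TestDataSpeed(df, 1000).start_test()
```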
@@ -60,7 +61,9 @@ A Dataflow has a `get_data()` method which yields a datapoint every time.
 class MyDataFlow(DataFlow):
   def get_data(self):
     for k in range(100):
-      yield datapoint
+      digit = np.random.rand(28, 28)
+      label = np.random.randint(10)
+      yield [digit, label]
 ```
 
 Optionally, Dataflow can implement the following two methods:
...
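The snippet above assumes `numpy` has been imported as `np`. A hypothetical way to consume this Dataflow, following the convention that `reset_state()` is called once before reading:

```python
df = MyDataFlow()
df.reset_state()   # DataFlow convention: initialize state (e.g. RNG) before reading
for digit, label in df.get_data():
    assert digit.shape == (28, 28) and 0 <= label < 10
```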
@@ -3,6 +3,3 @@
 The following guide introduces some core concepts of TensorPack. In contrast to several other libraries TensorPack contains of several modules to build complex deep learning algorithms and train models with high accuracy and high speed.
-### Callbacks
-The use of callbacks add the flexibility to execute code during training. These callbacks are triggered on several events such as after each step or at the end of one training epoch.
 # Trainers
-## Trainer
 Training is basically **running something again and again**.
 Tensorpack base trainer implements the logic of *running the iteration*,
 and other trainers implement *what the iteration is*.
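The division of labor described in those lines can be pictured with a toy sketch; this only illustrates the idea and is not tensorpack's actual class hierarchy:

```python
class ToyBaseTrainer:
    """Runs the iteration again and again; subclasses define what it is."""
    def train(self, steps_per_epoch, max_epoch):
        for epoch in range(1, max_epoch + 1):
            for step in range(steps_per_epoch):
                self.run_step()      # the subclass decides what one step does

class ToyTrainer(ToyBaseTrainer):
    def run_step(self):
        pass                         # e.g. run one optimization step here
```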
@@ -54,7 +52,7 @@ The existing trainers should be enough for single-cost optimization tasks. If you
 want to do something inside the trainer, consider writing it as a callback, or
 write an issue to see if there is a better solution than creating new trainers.
-For other tasks, you might need a new trainer.
+For certain tasks, you might need a new trainer.
 The [GAN trainer](../examples/GAN/GAN.py) is one example of how to implement
 new trainers.
...
@@ -175,7 +175,7 @@ def get_data(train_or_test):
     ds = AugmentImageComponent(ds, augmentors)
     ds = BatchData(ds, BATCH_SIZE, remainder=not isTrain)
     if isTrain:
-        ds = PrefetchDataZMQ(ds, min(30, multiprocessing.cpu_count()))
+        ds = PrefetchDataZMQ(ds, min(20, multiprocessing.cpu_count()))
     return ds
...
@@ -41,7 +41,7 @@ class TestDataSpeed(ProxyDataFlow):
         with get_tqdm(total=self.test_size, leave=True) as pbar:
             for idx, dp in enumerate(self.ds.get_data()):
                 pbar.update()
-                if idx == self.test_size:
+                if idx == self.test_size - 1:
                     break
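The fix above addresses a zero-based off-by-one: `enumerate` starts at 0, so the loop has already produced `self.test_size` datapoints when `idx == self.test_size - 1`; the old condition pulled one extra datapoint. A standalone illustration:

```python
test_size = 5
consumed = 0
for idx, dp in enumerate(range(100)):  # stand-in for ds.get_data()
    consumed += 1
    if idx == test_size - 1:           # idx is 4 after the 5th datapoint
        break
assert consumed == test_size
```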
@@ -439,7 +439,8 @@ class LocallyShuffleData(ProxyDataFlow, RNGDataFlow):
         Args:
             ds (DataFlow): input DataFlow.
             cache_size (int): size of the cache.
-            nr_reuse (int): reuse each datapoints several times to improve speed.
+            nr_reuse (int): reuse each datapoints several times to improve
+                speed, but may hurt your model.
         """
         ProxyDataFlow.__init__(self, ds)
         self.q = deque(maxlen=cache_size)
...
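The amended docstring warns that `nr_reuse` trades shuffling quality for speed: repeated copies of a datapoint stay close together in the bounded cache. A toy sketch of the local-shuffle idea only, not the actual implementation of `LocallyShuffleData`:

```python
import random

def locally_shuffled(gen, cache_size, nr_reuse=1):
    """Yield datapoints from `gen` in a locally shuffled order."""
    cache = []
    for dp in gen:
        for _ in range(nr_reuse):          # reusing datapoints is faster,
            cache.append(dp)               # but correlates nearby outputs
            if len(cache) < cache_size:
                continue                   # warm-up: fill the cache first
            i = random.randrange(len(cache))
            cache[i], cache[-1] = cache[-1], cache[i]
            yield cache.pop()              # emit a random cached datapoint
    random.shuffle(cache)                  # drain what is left at the end
    for dp in cache:
        yield dp
```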