Commit ab8503e8 authored by Yuxin Wu

update docs

parent 47c98ec2
@@ -39,7 +39,7 @@ without worrying about adding operators to TensorFlow.
Unless you are working with standard data types (image folders, LMDB, etc),
you would usually want to write your own DataFlow.
See [another tutorial](http://tensorpack.readthedocs.io/en/latest/tutorial/extend/dataflow.html)
for details on handling your own data format.
<!--
- TODO mention RL, distributed data, and zmq operator in the future.
@@ -3,25 +3,25 @@
This tutorial gives an overview of how to build an efficient DataFlow, using the ImageNet
dataset as an example.
Our goal in the end is to have
a __Python generator__ which yields preprocessed ImageNet images and labels as fast as possible.
Since it is simply a generator interface, you can use the DataFlow in other Python-based frameworks (e.g. Keras)
or your own code as well.
We use the ILSVRC12 training set, which contains 1.28 million images.
The original images (JPEG compressed) are 140G in total.
The average resolution is about 400x350 <sup>[[1]]</sup>.
Following the [ResNet example](../examples/ResNet), we need images in their original resolution,
so we will read the original dataset (instead of a down-sampled version), and
then apply complicated preprocessing to it.
We will need to reach a speed of roughly 1k ~ 2k images per second to keep GPUs busy.
Note that the actual performance would depend not only on the disk, but also on
memory (for caching) and CPU (for data processing).
You may need to tune the parameters (#processes, #threads, size of buffer, etc.)
or change the pipeline for new tasks and new machines to achieve the best performance.
This tutorial is quite complicated, because you do need knowledge of hardware & system to run fast on an ImageNet-sized dataset.
However, for __smaller datasets__ (e.g. several GBs of space, or lightweight preprocessing), a simple reader plus some prefetch should work well enough.
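For that simpler case, the whole pipeline can be a rough sketch like the one below. `MyImageDataFlow` is a hypothetical custom DataFlow (see the tutorial on writing a DataFlow), not a tensorpack built-in:
```python
from tensorpack import *

# MyImageDataFlow is a placeholder for your own small-dataset reader.
ds = MyImageDataFlow('/path/to/small/dataset', shuffle=True)
# run reading + light preprocessing in a few extra processes and prefetch the results
ds = PrefetchDataZMQ(ds, nr_proc=4)
ds = BatchData(ds, 64)
TestDataSpeed(ds).start()
```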
## Random Read
@@ -30,7 +30,7 @@ We start from a simple DataFlow:
from tensorpack import *
ds0 = dataset.ILSVRC12('/path/to/ILSVRC12', 'train', shuffle=True)
ds1 = BatchData(ds0, 256, use_list=True)
TestDataSpeed(ds1).start()
```
Here `ds0` reads original images from the filesystem. It is implemented simply by:
@@ -44,11 +44,11 @@ By default, `BatchData`
will stack the datapoints into a `numpy.ndarray`, but since original images are of different shapes, we use
`use_list=True` so that it just produces lists.
On an SSD you probably can already observe good speed here (e.g. 5 it/s, that is 1280 images/s), but on an HDD the speed may be just 1 it/s,
because we are doing heavy random read on the filesystem (regardless of whether `shuffle` is True).
We will now add the cheapest pre-processing to get an ndarray in the end instead of a list
(because training will need ndarray eventually):
```eval_rst
.. code-block:: python
    :emphasize-lines: 2,3
@@ -68,26 +68,29 @@ Now it's time to add threads or processes:
ds = PrefetchDataZMQ(ds1, nr_proc=25)
ds = BatchData(ds, 256)
```
Here we start 25 processes to run `ds1`, and collect their output through the ZMQ IPC protocol.
Using ZMQ to transfer data is faster than `multiprocessing.Queue`, but data copy (even
within one process) can still be quite expensive when you're dealing with large data.
For example, to reduce copy overhead, the ResNet example deliberately moves certain pre-processing (the mean/std normalization) from DataFlow to the graph.
This way the DataFlow only transfers uint8 images, as opposed to float32 which takes 4x more memory.
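As a rough sketch of that idea: keep the DataFlow output as uint8 and do the normalization on the TensorFlow side. The mean/std constants below are the usual ImageNet statistics, used only for illustration, and `image_uint8` stands for the uint8 image tensor coming from the input pipeline:
```python
import tensorflow as tf

# image_uint8: the uint8 image tensor produced by the input pipeline (placeholder name)
image = tf.cast(image_uint8, tf.float32) * (1.0 / 255)
mean = tf.constant([0.485, 0.456, 0.406], dtype=tf.float32)   # illustrative ImageNet mean
std = tf.constant([0.229, 0.224, 0.225], dtype=tf.float32)    # illustrative ImageNet std
image = (image - mean) / std   # runs inside the graph, so only uint8 crosses the DataFlow boundary
```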
Alternatively to the multi-process approach above, you can use multi-threading like this:
```eval_rst
.. code-block:: python
    :emphasize-lines: 3-6

    ds0 = dataset.ILSVRC12('/path/to/ILSVRC12', 'train', shuffle=True)
    augmentor = AugmentorList(lots_of_augmentors)
    ds1 = ThreadedMapData(
        ds0, nr_thread=25,
        map_func=lambda x: augmentor.augment(x), buffer_size=1000)
    # ds1 = PrefetchDataZMQ(ds1, nr_proc=1)
    ds = BatchData(ds1, 256)
```
Since no `fork()` is happening here, there'll be only one instance of `ds0`.
25 threads will all fetch data from the same `ds0` instance, run the augmentor function and
put results into a buffer of size 1000.
To reduce the effect of the GIL, you will want to uncomment that line so that everything above it (including all the
threads) happens in an independent process.
There is no general answer to whether threads or processes are faster.
@@ -98,7 +101,7 @@ You can also try a combination of both (several processes each with several thre
## Sequential Read
Random read is usually not a good idea, especially if the data is not on an SSD.
We can also dump the dataset into one single LMDB file and read it sequentially.
```python
from tensorpack import *
@@ -123,16 +126,16 @@ ds1 = PrefetchDataZMQ(ds0, nr_proc=1)
dftools.dump_dataflow_to_lmdb(ds1, '/path/to/ILSVRC-train.lmdb')
```
The above script builds a DataFlow which produces jpeg-encoded ImageNet data.
We store the jpeg string as a numpy array because the function `cv2.imdecode` expects this format.
We use 1 prefetch process to speed up. If `nr_proc>1`, `ds1` will take data
from several forks of `ds0` and will therefore no longer be identical to `ds0`.
It will generate a database file of 140G. We build a DataFlow to read the LMDB file sequentially:
```
from tensorpack import *
ds = LMDBData('/path/to/ILSVRC-train.lmdb', shuffle=False)
ds = BatchData(ds, 256, use_list=True)
TestDataSpeed(ds).start()
```
Depending on whether the OS has cached the file for you (and how large the RAM is), the above script
can run at a speed of 10~130 it/s, roughly corresponding to 250MB~3.5GB/s bandwidth. You can test
@@ -165,7 +168,7 @@ Then we add necessary transformations:
ds = BatchData(ds, 256)
```
1. `LMDBDataPoint` deserializes the datapoints (from raw bytes to [jpeg_string, label] -- what we dumped in `RawILSVRC12`)
2. Use OpenCV to decode the first component into an ndarray
3. Apply augmentations to the ndarray
@@ -187,18 +190,20 @@ Both imdecode and the augmentors can be quite slow. We can parallelize them like
Since we are reading the database sequentially, having multiple identical instances of the
underlying DataFlow will result in biased data distribution. Therefore we use `PrefetchData` to
launch the underlying DataFlow in one independent process, and only parallelize the transformations.
(`PrefetchDataZMQ` is faster but not fork-safe, so the first prefetch has to be `PrefetchData`. This
is supposed to get fixed in the future.)
Let me summarize what the above DataFlow does:
1. One process reads the LMDB file, shuffles the datapoints in a buffer and puts them into a `multiprocessing.Queue` (used by `PrefetchData`).
2. 25 processes take items from the queue, decode and process them into [image, label] pairs, and
send them through a ZMQ IPC pipe.
3. The main process takes data from the pipe, makes batches and feeds them into the graph,
according to what [InputSource](http://tensorpack.readthedocs.io/en/latest/tutorial/input-source.html) is used.
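Putting the pieces together, the parallelized pipeline has roughly the structure below. This is only a sketch: depending on the tensorpack version, `LMDBDataPoint` may take the LMDB path directly instead of an `LMDBData` instance, and `lots_of_augmentors` is the augmentor list from the earlier sections:
```python
import cv2
from tensorpack import *

ds = LMDBData('/path/to/ILSVRC-train.lmdb', shuffle=False)
ds = LocallyShuffleData(ds, 50000)      # shuffle inside a fixed-size buffer
ds = PrefetchData(ds, 5000, 1)          # step 1: one reader process feeding a multiprocessing.Queue
ds = LMDBDataPoint(ds)                  # deserialize raw bytes into [jpeg_string, label]
ds = MapDataComponent(ds, lambda x: cv2.imdecode(x, cv2.IMREAD_COLOR), 0)
ds = AugmentImageComponent(ds, lots_of_augmentors)
ds = PrefetchDataZMQ(ds, 25)            # step 2: 25 worker processes, results sent over ZMQ IPC
ds = BatchData(ds, 256)                 # step 3: the main process makes batches
```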
The above DataFlow can run at a speed of 1k ~ 2k images per second if you have good CPUs, RAM, disks and augmentors.
As a reference, tensorpack can train ResNet-18 (a shallow ResNet) at 4.5 batches (1.2k images) per second on 4 old TitanX.
A DGX-1 (8 P100) can train ResNet-50 at 1.7k images/s according to the [official benchmark](https://www.tensorflow.org/performance/benchmarks).
So DataFlow will not be a serious bottleneck if configured properly.
## More Efficient DataFlow
@@ -225,7 +230,7 @@ send_dataflow_zmq(df, 'ipc:///tmp/ipc-socket')
```python
# Training Machine, training process
df = RemoteDataZMQ('ipc:///tmp/ipc-socket', 'tcp://0.0.0.0:8877')
TestDataSpeed(df).start()
```
### Write an image augmentor
The first thing to note: __you never have to write an augmentor__.
An augmentor is a part of the DataFlow, so you can always
[write a DataFlow](http://tensorpack.readthedocs.io/en/latest/tutorial/extend/dataflow.html)
to do whatever operations on your data, rather than writing an augmentor.
Augmentors just sometimes make things easier.
An augmentor maps images to images.
If you have such a mapping function `f` already, you can simply use `imgaug.MapImage(f)` as the
augmentor, or use `MapDataComponent(dataflow, f, index)` as the DataFlow.
In other words, for simple mappings you do not need to write an augmentor.
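For instance, a simple brightness scaling can be used either way. This is just an illustrative sketch; the function `f` below is ours, not a tensorpack built-in:
```python
import numpy as np
from tensorpack import *

def f(img):
    # example mapping: scale brightness, keep the uint8 range
    return np.clip(img * 1.2, 0, 255).astype('uint8')

aug = imgaug.MapImage(f)                  # use f as an augmentor ...
# df = MapDataComponent(df, f, index=0)   # ... or apply f directly to component 0 of a DataFlow `df`
```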
An augmentor may do more than apply a mapping. The interface you will need to implement
@@ -28,7 +29,7 @@ It does the following extra things for you:
1. `self.rng` is a `np.random.RandomState` object,
guaranteed to have different seeds when you use multiprocess prefetch.
In multiprocess settings, you have to use this rng to generate random numbers.
2. Random parameter generation and the actual augmentation are separated. This allows you to apply the
same transformation to several images together (with `AugmentImageComponents`),
@@ -33,10 +33,14 @@ TODO how to access the tensors already defined.
Can be used to run some manual initialization of variables, or start some services for the whole training.

* `_after_train(self)`
Do some finalization work.

* `_before_epoch(self)`, `_after_epoch(self)`
Use them only when you really need something to happen __immediately__ before/after an epoch.
Usually `_trigger_epoch` should be enough.

* `_before_run(self, ctx)`, `_after_run(self, ctx, values)`
@@ -56,6 +60,11 @@ The training loops would become `sess.run([training_op, my_op])`.
This is different from `sess.run(training_op); sess.run(my_op);`,
which is what you would get if you run the op in `_trigger_step`.

* `_trigger_step(self)`
Do something (including running ops) after each step has finished.
Be careful to only do light work here because it could affect training speed.

* `_trigger_epoch(self)`
Do something after each epoch has finished. Will call `self.trigger()` by default.
@@ -63,9 +72,5 @@ Do something after each epoch has finished. Will call `self.trigger()` by defaul
* `_trigger(self)`
By default will get called by `_trigger_epoch`,
but you can customize the scheduling of this callback with
`PeriodicTrigger`, to let this method run every k steps or every k epochs.
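Putting several of these hooks together, a minimal custom callback might look like the sketch below. The tensor name and the logging are placeholders for illustration, not part of any example:
```python
import tensorflow as tf
from tensorpack.callbacks import Callback
from tensorpack.utils import logger

class FetchTensorDuringTrain(Callback):
    """Fetch an extra tensor in the same sess.run() as the training op, and log it every epoch."""

    def _setup_graph(self):
        # 'my_tensor:0' is a placeholder name; pick a tensor that exists in your graph
        self._tensor = tf.get_default_graph().get_tensor_by_name('my_tensor:0')
        self._last_value = None

    def _before_run(self, ctx):
        # ask the hooked session to fetch self._tensor together with the training op
        return tf.train.SessionRunArgs(fetches=self._tensor)

    def _after_run(self, ctx, values):
        self._last_value = values.results

    def _trigger(self):
        # called after every epoch by default; reschedule with PeriodicTrigger if needed
        logger.info("my_tensor = {}".format(self._last_value))
```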
@@ -25,7 +25,7 @@ Optionally, DataFlow can implement the following two methods:
A typical situation is when your DataFlow uses a random number generator (RNG). Then you would need to reset the RNG here.
Otherwise, child processes will have the same random seed. The `RNGDataFlow` base class does this for you.
With a "low-level" DataFlow defined like above, you can then compose it with existing modules (e.g. batching, prefetching, ...).
DataFlow implementations for several well-known datasets are provided in the
[dataflow.dataset](http://tensorpack.readthedocs.io/en/latest/modules/tensorpack.dataflow.dataset.html)
## Write a layer
The first thing to note: __you never have to write a layer__.
Tensorpack layers are nothing but wrappers of symbolic functions.
You can use any symbolic functions you have written or seen elsewhere with or without tensorpack layers.
You can use symbolic functions from slim/tflearn/tensorlayer, and even Keras/sonnet ([with some tricks](../../examples/mnist-keras.py)).
If you would like, you can make a symbolic function become a "layer" by following some simple rules, and then gain benefits from the framework.
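As a sketch of that idea, registering an ordinary symbolic function could look like this. The activation function and its name are our own example, and the exact import path of `layer_register` may vary slightly between versions:
```python
import tensorflow as tf
from tensorpack.models import layer_register   # import path is an assumption; check your version

@layer_register()          # registration is what turns the symbolic function into a "layer"
def Swish(x):
    """An ordinary symbolic function; the first argument is the input tensor."""
    return x * tf.sigmoid(x)

# once registered, it is called with a name like other tensorpack layers:
# y = Swish('swish1', x)
```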
@@ -12,7 +12,7 @@ for more details.
If you think:
1. The framework has limitations in its interface so your XYZ cannot be supported, OR
2. Your XYZ is super common / very well-defined / very useful, so it would be nice to include it.
Then it is a good time to open an issue.
@@ -40,7 +40,7 @@ decide which one to use from a file name.)
Doing transfer learning is straightforward. Variable restoring is completely based on name match between
the current graph and the `SessionInit` initializer.
Therefore, if you want to load some model, just use the same variable name.
If you want to re-train some layer, just rename it.
Unmatched variables on both sides will be printed as a warning.
@@ -51,7 +51,7 @@ See the [Efficient DataFlow](http://tensorpack.readthedocs.io/en/latest/tutorial
When you use Python to load/preprocess data, TF `QueueBase` can help hide the "Copy to TF" latency,
and TF `StagingArea` can help hide the "Copy to GPU" latency.
They are used by most examples in tensorpack;
however, most other TensorFlow wrappers are designed to be `feed_dict` based -- no latency hiding at all.
This is the major reason why tensorpack is [faster](https://gist.github.com/ppwwyyxx/8d95da79f8d97036a7d67c2416c851b6).
## InputSource
@@ -21,7 +21,7 @@ Basically, `_get_inputs` should define the metainfo of all the possible placehol
the argument `inputs` is the list of input tensors matching `_get_inputs`.
You can use any symbolic functions in `_build_graph`, including TensorFlow core library
functions and other symbolic libraries.
tensorpack also contains a small collection of common model primitives,
such as conv/deconv, fc, batch normalization, pooling layers, and some custom loss functions.
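A minimal sketch of such a model is shown below. The `InputDesc`/`ModelDesc` spellings and the layer signatures follow the examples of roughly this era and may differ in other versions; the tiny network itself is just an illustration:
```python
import tensorflow as tf
from tensorpack import *

class MyModel(ModelDesc):
    def _get_inputs(self):
        # metainfo (type, shape, name) of every placeholder the graph may need
        return [InputDesc(tf.float32, (None, 28, 28), 'input'),
                InputDesc(tf.int32, (None,), 'label')]

    def _build_graph(self, inputs):
        image, label = inputs                     # tensors matching _get_inputs
        image = tf.reshape(image, [-1, 28 * 28])
        logits = FullyConnected('fc0', image, 10, nl=tf.identity)
        # single-cost trainers pick up self.cost
        self.cost = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label),
            name='cost')
```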
@@ -2,8 +2,8 @@
# Trainer
Training is **running something again and again**.
The tensorpack base trainer implements the logic of __running the iteration__,
and derived trainers implement __what the iteration is__.
Most neural network training tasks are single-cost optimization.
Tensorpack provides some trainer implementations for such tasks.
@@ -17,7 +17,7 @@ which will run significantly faster than a naive `sess.run(..., feed_dict={...})
There are also Multi-GPU trainers which include the logic of data-parallel Multi-GPU training.
You can enable them by just changing one line, and all the necessary logic to achieve the best
performance has already been baked into the trainers.
For example, SyncMultiGPUTrainer can train ResNet50 as fast as the [official tensorflow benchmark](https://github.com/tensorflow/benchmarks).
To use trainers, pass a `TrainConfig` to configure them:
@@ -38,7 +38,7 @@ QueueInputTrainer(config).train()
# SyncMultiGPUTrainer(config).train()
```
Trainers just run __some__ iterations, so there is no limit on where the data comes from
or what to do in an iteration.
For example, the [GAN trainer](../examples/GAN/GAN.py) minimizes
two cost functions alternately.
@@ -270,7 +270,7 @@ def main():
    else:
        column = args.column.strip().split(',')
        for k in column:
            assert k[0] in ['x', 'y', 'n']
        assert nr_column == len(column), "Column and data doesn't have same length. {}!={}".format(nr_column, len(column))
        args.y_column = [v for v in column if v[0] == 'y']
        args.y_column_idx = [idx for idx, v in enumerate(column) if v[0] == 'y']