Commit b33053cf authored by Yuxin Wu's avatar Yuxin Wu

rename input_data to input_source. Update docs

parent 18064a54
...@@ -34,16 +34,15 @@ It's Yet Another TF wrapper, but different in:
Tensorpack includes only a few common models, and helpful tools such as `LinearWrap` to simplify large models.
But you can use any other wrappers within tensorpack, such as sonnet/Keras/slim/tflearn/tensorlayer/....
2. Focus on __training speed__.
   + Tensorpack trainer is almost always faster than `feed_dict` based wrappers.
     Even on a tiny CNN example, the training runs [2x faster](https://gist.github.com/ppwwyyxx/8d95da79f8d97036a7d67c2416c851b6) than the equivalent Keras code.
   + Data-Parallel Multi-GPU training is off-the-shelf to use. It is as fast as Google's [benchmark code](https://github.com/tensorflow/benchmarks).
3. Focus on large datasets.
   + __DataFlow__ allows you to process large datasets such as ImageNet in pure Python without blocking the training.
   + DataFlow has a unified interface, so DataFlows can be composed and reused to perform complex preprocessing.
4. Interface of extensible __Callbacks__.
   Write a callback to implement everything you want to do apart from the training iterations, and
...@@ -59,7 +58,7 @@ It's Yet Another TF wrapper, but different in:
Dependencies:
+ Python 2 or 3
+ TensorFlow >= 1.0.0 (>=1.1.0 for Multi-GPU)
+ Python bindings for OpenCV
```
pip install -U git+https://github.com/ppwwyyxx/tensorpack.git
......
...@@ -24,6 +24,8 @@ TrainConfig(
  callbacks=[
    # save the model every epoch
    ModelSaver(),
    # backup the model with the best validation error
    MinSaver('val-error-top1'),
    # run inference on another Dataflow every epoch, compute top1/top5 classification error and save them in log
    InferenceRunner(dataset_val, [
      ClassificationError('wrong-top1', 'val-error-top1'),
...@@ -46,6 +48,8 @@ TrainConfig(
    ProgressBar(),
    # run `tf.summary.merge_all` every epoch and send results to monitors
    MergeAllSummaries(),
    # run ops in the GraphKeys.UPDATE_OPS collection along with training, if any
    RunUpdateOps(),
  ],
  monitors=[  # monitors are a special kind of callback; these are also enabled by default
    # write all monitor data to tensorboard
......
...@@ -48,9 +48,8 @@ for details.
-->
### Use DataFlow outside Tensorpack
DataFlow is independent of both tensorpack and TensorFlow.
You can simply use it as a data processing pipeline and plug it into any other framework.
To use a DataFlow independently, you will need to call `reset_state()` first to initialize it,
and then use the generator however you want:
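For example, a minimal sketch of standalone usage (`MyDataFlow` and `do_something` are hypothetical placeholders):
```python
df = MyDataFlow()           # any DataFlow instance
df.reset_state()            # required once before iterating
for dp in df.get_data():    # dp is a datapoint: a list of components, e.g. [image, label]
    do_something(dp)
```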
......
...@@ -9,7 +9,7 @@ A High Level Glance
It provides a uniform interface so that data processing modules can be chained together.
It allows you to load and process your data in pure Python and accelerate it by prefetching.
See also :doc:`input-source` and :doc:`efficient-dataflow` for more details about the efficiency of data
processing.
* You can use any TF-based symbolic function library to define a model in tensorpack.
...@@ -34,7 +34,7 @@ User Tutorials
   :maxdepth: 1

   dataflow
   input-source
   efficient-dataflow
   model
   trainer
......
# Input Sources

This tutorial covers how data goes from a DataFlow or other sources to the TensorFlow graph.
You don't have to know these details, but they may help you improve efficiency.

`InputSource` is an abstract interface in tensorpack describing where the input comes from and how it enters the graph.
For example, the input may:

1. Come from a DataFlow and be fed to the graph.
2. Come from a DataFlow and be prefetched on CPU by a TF queue.
3. Come from a DataFlow, be prefetched on CPU by a TF queue, then prefetched on GPU by a TF StagingArea.
4. Come from some TF native reading pipeline.
5. Come from some ZMQ pipe.

For most tasks, DataFlow with some prefetch is fast enough. You can use the `TrainConfig(data=)` option
to customize your `InputSource`.
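For instance, a sketch of passing an explicit `InputSource` (here `df` is an existing DataFlow and `MyModel` a hypothetical `ModelDesc` subclass):
```python
from tensorpack import TrainConfig
from tensorpack.train.input_source import QueueInput, StagingInputWrapper

# Case 3 above: prefetch on CPU with a TF queue, then stage onto GPU 0.
input_source = StagingInputWrapper(QueueInput(df), ['/gpu:0'])

config = TrainConfig(
    data=input_source,     # instead of the usual dataflow=df
    model=MyModel(),       # hypothetical ModelDesc subclass
)
```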
## Use Prefetch

In general, `feed_dict` is slow and should never appear in your critical loop.
That is, when you use TensorFlow without any wrappers, you should avoid loops like this:
```python
while True:
  X, y = get_some_data()
  minimize_op.run(feed_dict={'X': X, 'y': y})
```
However, when you need to load data from the Python side, this is the only available interface in frameworks such as Keras and tflearn.
This is part of the reason why [tensorpack is faster](https://gist.github.com/ppwwyyxx/8d95da79f8d97036a7d67c2416c851b6) than examples from other frameworks.

You should use something like this instead, to prefetch data into the graph in one thread and hide the copy latency:
```python
# Thread 1:
while True:
...@@ -29,27 +40,28 @@ while True:
  minimize_op.run()  # minimize_op was built from dequeued tensors
```
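As a self-contained sketch of this two-thread pattern in raw TF 1.x (the placeholder shapes and `get_some_data` are hypothetical, matching the pseudo-code above):
```python
import threading
import tensorflow as tf

x_ph = tf.placeholder(tf.float32, [None, 224, 224, 3])
y_ph = tf.placeholder(tf.int32, [None])
queue = tf.FIFOQueue(capacity=50, dtypes=[tf.float32, tf.int32])
enqueue_op = queue.enqueue([x_ph, y_ph])
X, y = queue.dequeue()   # build minimize_op from these dequeued tensors

def feed_loop(sess):
    # Thread 1: keep the queue full, so data loading overlaps with training
    while True:
        xv, yv = get_some_data()   # hypothetical, as above
        sess.run(enqueue_op, feed_dict={x_ph: xv, y_ph: yv})

sess = tf.Session()
threading.Thread(target=feed_loop, args=(sess,), daemon=True).start()
# Thread 2 (main): while True: sess.run(minimize_op)
```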
This is now automatically handled by tensorpack trainers already; see [Trainer](trainer.md) for details.
TensorFlow StagingArea can further hide the H2D (CPU->GPU) copy latency.
It is also automatically included in tensorpack when you use synchronous multi-GPU training.
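For reference, a rough sketch of what a StagingArea does, using `tf.contrib.staging` (the contrib path may move between TF versions; the batch here is fake):
```python
import tensorflow as tf

# In practice `images` would come from the CPU-side queue.
images = tf.random_normal([64, 224, 224, 3])

area = tf.contrib.staging.StagingArea(dtypes=[tf.float32])
stage_op = area.put([images])    # asynchronously copy the next batch to GPU
staged_images, = area.get()      # consume the batch staged in the previous step
# Running `stage_op` together with the train op each step overlaps
# the CPU->GPU copy of step N+1 with the computation of step N.
```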
You can also avoid `feed_dict` by using TensorFlow native operators to read data, which is also supported in tensorpack.
It probably allows you to reach the best performance,
but at the cost of implementing the reading / preprocessing ops in C++ if there isn't one for your task.
## Figure out the bottleneck

For training, we only worry about throughput, not latency.
Threads 1 & 2 run in parallel, and the faster one will block to wait for the slower one.
So the overall throughput will be that of the slower one.

There is no way to accurately benchmark the two dependent threads while they are running
without introducing overhead. However, there are ways to understand which one is the bottleneck:

1. Use the average occupancy (size) of the queue. This information is summarized by default.
   If the queue is nearly empty (default size 50), then the input source is the bottleneck.
2. Benchmark them separately. You can use `TestDataSpeed` to benchmark a DataFlow, and
   use `FakeData` as a fast replacement in a dry run to benchmark the training iterations;
   a sketch follows this list.
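A minimal sketch of both measurements, assuming an existing DataFlow `df` (the `FakeData` shapes below are illustrative):
```python
import time
from tensorpack.dataflow import FakeData

# 1. Benchmark the DataFlow alone (roughly what `TestDataSpeed` does for you):
df.reset_state()
start, n = time.time(), 0
for dp in df.get_data():
    n += 1
    if n == 1000:
        break
print('%.1f datapoints/sec' % (n / (time.time() - start)))

# 2. Benchmark training alone: swap the real DataFlow for constant fake data
#    of the same shapes, e.g. [image, label] batches:
fake_df = FakeData([[64, 224, 224, 3], [64]], size=1000)
```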
If you find that your input is the bottleneck, you'll need to think about how to speed up your data.
You may either change `InputSource`, or look at [Efficient DataFlow](http://tensorpack.readthedocs.io/en/latest/tutorial/efficient-dataflow.html).
...@@ -11,11 +11,13 @@ These trainers will by default minimize `ModelDesc.cost`.
Therefore, you can use these trainers as long as you set `self.cost` in `ModelDesc._build_graph()`,
as most examples do.

Existing trainers were implemented with prefetch mechanisms,
which run significantly faster than a naive `sess.run(..., feed_dict={...})`.

There are also multi-GPU trainers which include the logic of data-parallel multi-GPU training.
You can enable them by just changing one line; all the necessary logic to achieve the best
performance is baked into the trainers already.
For example, SyncMultiGPUTrainer can train ResNet50 as fast as the [official benchmark](https://github.com/tensorflow/benchmarks).

To use trainers, pass a `TrainConfig` to configure them:
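As a minimal sketch (with `MyModel` a hypothetical `ModelDesc` subclass and `df` an existing DataFlow):
```python
from tensorpack import TrainConfig, SyncMultiGPUTrainer, ModelSaver

config = TrainConfig(
    dataflow=df,               # or data=some InputSource
    model=MyModel(),           # hypothetical ModelDesc subclass that sets self.cost
    callbacks=[ModelSaver()],
    max_epoch=100,
)
# Switching trainers is the "one line" change mentioned above:
SyncMultiGPUTrainer(config).train()
```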
...@@ -40,5 +42,3 @@ Trainers just run some iterations, so there is no limit to where the data come from
or what to do in an iteration.

For example, the [GAN trainer](../examples/GAN/GAN.py) minimizes
two cost functions alternately.
...@@ -56,7 +56,8 @@ class GANModelDesc(ModelDesc):
class GANTrainer(FeedfreeTrainerBase):
    def __init__(self, config):
        # TODO design better
        self._input_source = QueueInput(config.dataflow)
        super(GANTrainer, self).__init__(config)

    def _setup(self):
...@@ -79,7 +80,7 @@ class SeparateGANTrainer(FeedfreeTrainerBase):
            d_period(int): period of each d_opt run
            g_period(int): period of each g_opt run
        """
        self._input_source = QueueInput(config.dataflow)
        self._d_period = int(d_period)
        self._g_period = int(g_period)
        assert min(d_period, g_period) == 1
......
...@@ -17,7 +17,7 @@ from ..utils import logger, get_tqdm_kwargs
from ..dataflow import DataFlow
from ..tfutils.common import get_op_tensor_name, get_tensors_by_names
from ..tfutils.tower import TowerContext
from ..train.input_source import TensorInput, FeedInput
from ..predict import PredictorTowerBuilder
from .base import Callback
...@@ -59,14 +59,14 @@ class InferenceRunnerBase(Callback):
    def __init__(self, input, infs, input_names=None, prefix='', extra_hooks=None):
        """
        Args:
            input (InputSource): the input to use. Must have ``size()``.
            infs (list): list of :class:`Inferencer` to run.
            input_names (list): must be a subset of the names in InputDesc.
            prefix(str): a prefix used to build the tower. Must be set
                differently if more than one :class:`InferenceRunner` is used.
            extra_hooks (list): extra ``SessionRunHook`` to run with the evaluation.
        """
        self._input_source = input
        if not isinstance(infs, list):
            self.infs = [infs]
        else:
...@@ -102,7 +102,7 @@ class InferenceRunnerBase(Callback):
    #     return x.name

    def _setup_graph(self):
        self._input_source.setup(self.trainer.model)
        self._setup_input_names()
        # Use predict_tower in train config. either gpuid or -1
        self._predict_tower_id = self.trainer.config.predict_tower[0]
...@@ -142,9 +142,9 @@ class InferenceRunnerBase(Callback):
            inf.before_inference()

        # iterate over the data, and run the hooked session
        self._input_source.reset_state()
        for _ in tqdm.trange(self._input_source.size(), **get_tqdm_kwargs()):
            dp = self._input_source.next_feed()
            feed = dict(zip(self._feed_tensors, dp))
            self._hooked_sess.run(fetches=[], feed_dict=feed)
        summary_inferencer(self.trainer, self.infs)
...@@ -209,7 +209,7 @@ class FeedfreeInferenceRunner(InferenceRunnerBase):
                "[FeedfreeInferenceRunner] name {} is not a model input!".format(n)

    def _find_input_tensors(self):
        tensors = self._input_source.get_input_tensors()
        assert len(self.input_names) == len(tensors), \
            "[FeedfreeInferenceRunner] Input names must match the " \
...@@ -251,7 +251,7 @@ class DataParallelInferenceRunner(InferenceRunner):
    def _setup_graph(self):
        model = self.trainer.model
        self._input_source.setup(model)
        self._setup_input_names()
        # build graph
...@@ -318,21 +318,21 @@ class DataParallelInferenceRunner(InferenceRunner):
        for inf in self.infs:
            inf.before_inference()

        self._input_source.reset_state()
        total = self._input_source.size()
        nr_tower = len(self._gpus)
        with tqdm.tqdm(total=total, **get_tqdm_kwargs()) as pbar:
            while total >= nr_tower:
                dps = []
                for k in self._gpus:
                    dps.extend(self._input_source.next_feed())
                feed = dict(zip(self._feed_tensors, dps))
                self._parallel_hooked_sess.run(fetches=[], feed_dict=feed)
                pbar.update(nr_tower)
                total -= nr_tower
            # take care of the rest
            while total > 0:
                dp = self._input_source.next_feed()
                feed = dict(zip(self._feed_tensors[:len(dp)], dp))
                self._hooked_sess.run(fetches=[], feed_dict=feed)
        summary_inferencer(self.trainer, self.infs)
...@@ -16,7 +16,7 @@ from ..tfutils import (JustCurrentSession,
                       get_default_sess_config, SessionInit)
from ..tfutils.sesscreate import NewSessionCreator
from ..tfutils.optimizer import apply_grad_processors
from .input_source import InputSource

__all__ = ['TrainConfig']
...@@ -38,7 +38,7 @@ class TrainConfig(object):
        """
        Args:
            dataflow (DataFlow): the dataflow to train.
            data (InputSource): an `InputSource` instance. Only one of ``dataflow``
                or ``data`` has to be present.
            model (ModelDesc): the model to train.
            callbacks (list): a list of :class:`Callback` to perform during training.
...@@ -78,7 +78,7 @@ class TrainConfig(object):
            self.data = None
        else:
            self.data = data
            assert_type(self.data, InputSource)
            self.dataflow = None

        if callbacks is None:
...@@ -6,7 +6,7 @@
import tensorflow as tf

from ..tfutils.tower import TowerContext, get_current_tower_context
from .input_source import QueueInput, FeedfreeInput
from .base import Trainer
...@@ -20,10 +20,10 @@ class FeedfreeTrainerBase(Trainer):
    """
    def build_train_tower(self):
        """
        Get input tensors from `self.input_source` and build the forward graph.
        """
        def f():
            self._input_tensors = self._input_source.get_input_tensors()
            self.model.build_graph(self._input_tensors)
        ctx = get_current_tower_context()
        if ctx is None:  # call without a context, use a default one
...@@ -34,8 +34,8 @@ class FeedfreeTrainerBase(Trainer):
            f()

    def _setup(self):
        assert isinstance(self._input_source, FeedfreeInput), type(self._input_source)
        self._input_source.setup_training(self)

    def run_step(self):
        """ Simply run ``self.train_op``."""
...@@ -85,8 +85,8 @@ class SimpleFeedfreeTrainer(SingleCostFeedfreeTrainer):
            config (TrainConfig): ``config.data`` must exist and is a
                :class:`FeedfreeInput`.
        """
        self._input_source = config.data
        assert isinstance(self._input_source, FeedfreeInput), self._input_source
        super(SimpleFeedfreeTrainer, self).__init__(config)
        assert len(self.config.tower) == 1, \
            "Got nr_tower={}, but doesn't support multigpu!" \
...@@ -118,7 +118,7 @@ def QueueInputTrainer(config, input_queue=None):
    assert isinstance(config.data, QueueInput), config.data

    # debug
    # from tensorpack.train.input_source import StagingInputWrapper, DummyConstantInput
    # config.data = StagingInputWrapper(config.data, ['/gpu:0'])
    # config.data = DummyConstantInput([[128,224,224,3], [128]])
    return SimpleFeedfreeTrainer(config)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: input_source.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

import tensorflow as tf
...@@ -24,15 +24,15 @@ from ..utils.concurrency import ShareSessionThread
from ..callbacks.concurrency import StartProcOrThread
from ..callbacks.base import Callback

__all__ = ['InputSource', 'FeedfreeInput',
           'QueueInput', 'BatchQueueInput',
           'ZMQInput',
           'DummyConstantInput', 'TensorInput', 'StagingInputWrapper']


@six.add_metaclass(ABCMeta)
class InputSource(object):
    """ Base class for the abstract InputSource. """

    @abstractmethod
    def get_input_tensors(self):
...@@ -56,7 +56,7 @@ class InputSource(object):
        return []


class FeedInput(InputSource):
    """ Input by iterating over a DataFlow and feed datapoints. """

    def __init__(self, ds):
        """
...@@ -87,7 +87,7 @@ class FeedInput(InputSource):
        return next(self.data_producer)


class FeedfreeInput(InputSource):
    """ Abstract base for input without feed,
        e.g. by queue or other operations. """
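To illustrate the interface, a minimal hypothetical `FeedfreeInput` subclass (not part of the library) might look like:
```python
class RandomTensorInput(FeedfreeInput):
    """ Hypothetical example: produce random input tensors directly
        in the graph, e.g. as a dry-run replacement for real inputs. """

    def __init__(self, shapes):
        self._shapes = shapes

    def get_input_tensors(self):
        # Called when building the training tower; one tensor per model input.
        return [tf.random_normal(shape) for shape in self._shapes]
```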
......
...@@ -18,7 +18,7 @@ from ..tfutils.gradproc import FilterNoneGrad, ScaleGradient

from .base import Trainer
from .feedfree import SingleCostFeedfreeTrainer
from .input_source import QueueInput, StagingInputWrapper

__all__ = ['SyncMultiGPUTrainer', 'AsyncMultiGPUTrainer']
...@@ -100,17 +100,17 @@ class SyncMultiGPUTrainerParameterServer(MultiGPUTrainer, SingleCostFeedfreeTrai
        """
        if config.dataflow is not None:
            # use QueueInput by default. May need to avoid this in the future (when more input types are available)
            self._input_source = QueueInput(config.dataflow)
        else:
            self._input_source = config.data

        if len(config.tower) > 1:
            assert tf.test.is_gpu_available()

            # seems to only improve on >1 GPUs
            if not isinstance(self._input_source, StagingInputWrapper):
                devices = ['/gpu:{}'.format(k) for k in config.tower]
                self._input_source = StagingInputWrapper(self._input_source, devices)

        assert ps_device in ['gpu', 'cpu'], ps_device
        self._ps_device = ps_device
...@@ -192,9 +192,9 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
            effective learning rate.
        """
        if config.dataflow is not None:
            self._input_source = QueueInput(config.dataflow)
        else:
            self._input_source = config.data
        super(AsyncMultiGPUTrainer, self).__init__(config)

        self._scale_gradient = scale_gradient
......
...@@ -5,7 +5,7 @@
from .base import Trainer
from ..tfutils import TowerContext
from .input_source import FeedInput

__all__ = ['SimpleTrainer']
...@@ -21,19 +21,19 @@ class SimpleTrainer(Trainer):
        """
        super(SimpleTrainer, self).__init__(config)
        if config.dataflow is None:
            self._input_source = config.data
            assert isinstance(self._input_source, FeedInput), type(self._input_source)
        else:
            self._input_source = FeedInput(config.dataflow)

    def run_step(self):
        """ Feed data into the graph and run the updates. """
        dp = self._input_source.next_feed()
        feed = dict(zip(self.inputs, dp))
        self.hooked_sess.run(self.train_op, feed_dict=feed)

    def _setup(self):
        self._input_source.setup_training(self)
        model = self.model
        self.inputs = model.get_reused_placehdrs()
        with TowerContext('', is_training=True):