Commit 00c47fa0 authored by Yuxin Wu's avatar Yuxin Wu

add batchqueueinput

parent c86cd15a
......@@ -67,10 +67,10 @@ class LMDBData(RNGDataFlow):
lmdb_path (str): a directory or a file.
shuffle (bool): shuffle the keys or not.
keys (list of str or str): list of str as the keys, used only when shuffle is True.
It can also be a format string e.g. `'{:0>8d}'` which will be
formatted with the indices from 0 to `total_size - 1`.
It can also be a format string e.g. ``{:0>8d}`` which will be
formatted with the indices from 0 to *total_size - 1*.
If not provided, it will then look in the database for `__keys__` which
If not provided, it will then look in the database for ``__keys__`` which
:func:`dump_dataflow_to_lmdb` used to store the list of keys.
If still not found, it will iterate over the database to find
all the keys.
......@@ -177,7 +177,7 @@ def CaffeLMDB(lmdb_path, shuffle=True, keys=None):
a :class:`LMDBDataDecoder` instance.
Example:
ds = CaffeLMDB("/tmp/validation", keys='{:0>8d}')
``ds = CaffeLMDB("/tmp/validation", keys='{:0>8d}')``
"""
cpb = get_caffe_pb()
......
......@@ -105,26 +105,22 @@ class SimpleFeedfreeTrainer(
# self.train_op = tf.group(*self.dequed_inputs)
class QueueInputTrainer(SimpleFeedfreeTrainer):
def QueueInputTrainer(config, input_queue=None, predict_tower=None):
"""
A trainer which automatically wraps ``config.dataflow`` by a
A wrapper trainer which automatically wraps ``config.dataflow`` by a
:class:`QueueInput`.
"""
def __init__(self, config, input_queue=None, predict_tower=None):
"""
Single tower Trainer, takes input from a queue
It is an equivalent of ``SimpleFeedfreeTrainer(config)`` with ``config.data = QueueInput(dataflow)``.
Args:
config(TrainConfig): a `TrainConfig` instance. config.dataflow must exist.
input_queue(tf.QueueBase): an input queue. Defaults to the
:class:`QueueInput` default.
"""
config.data = QueueInput(config.dataflow, input_queue)
if predict_tower is not None:
logger.warn("[Deprecated] Argument `predict_tower` is deprecated for trainer. "
"Use TrainConfig(predict_tower=...) instead!")
config.predict_tower = predict_tower
assert len(config.tower) == 1, \
"QueueInputTrainer doesn't support multigpu! Use Sync/AsyncMultiGPUTrainer instead."
super(QueueInputTrainer, self).__init__(config)
Args:
config(TrainConfig): a `TrainConfig` instance. config.dataflow must exist.
input_queue(tf.QueueBase): an input queue. Defaults to the
:class:`QueueInput` default.
"""
config.data = QueueInput(config.dataflow, input_queue)
if predict_tower is not None:
logger.warn("[Deprecated] Argument `predict_tower` is deprecated for trainer. "
"Use TrainConfig(predict_tower=...) instead!")
config.predict_tower = predict_tower
assert len(config.tower) == 1, \
"QueueInputTrainer doesn't support multigpu! Use Sync/AsyncMultiGPUTrainer instead."
return SimpleFeedfreeTrainer(config)
......@@ -10,11 +10,13 @@ import six
from ..dataflow import DataFlow, RepeatedData
from ..tfutils.summary import add_moving_summary
from ..tfutils import get_op_tensor_name
from ..utils import logger
from ..callbacks.concurrency import StartProcOrThread
__all__ = ['InputData', 'QueueInput', 'FeedfreeInput', 'TensorInput',
'DummyConstantInput']
__all__ = ['InputData', 'FeedfreeInput',
'QueueInput', 'BatchQueueInput',
'TensorInput', 'DummyConstantInput']
@six.add_metaclass(ABCMeta)
......@@ -90,9 +92,9 @@ class EnqueueThread(threading.Thread):
self.size_op, tf.float32, name='input_queue_size'))
def run(self):
self.dataflow.reset_state()
with self.sess.as_default():
try:
try:
self.dataflow.reset_state()
with self.sess.as_default():
while True:
for dp in self.dataflow.get_data():
if self.coord.should_stop():
......@@ -100,22 +102,23 @@ class EnqueueThread(threading.Thread):
feed = dict(zip(self.placehdrs, dp))
# print 'qsize:', self.sess.run([self.op, self.size_op], feed_dict=feed)[1]
self.op.run(feed_dict=feed)
except tf.errors.CancelledError:
except tf.errors.CancelledError:
pass
except Exception:
logger.exception("Exception in EnqueueThread:")
finally:
self.coord.request_stop()
try:
self.sess.run(self.close_op)
except RuntimeError: # session already closed
pass
except Exception:
logger.exception("Exception in EnqueueThread:")
finally:
self.coord.request_stop()
try:
self.sess.run(self.close_op)
except RuntimeError: # session already closed
pass
logger.info("Enqueue Thread Exited.")
logger.info("Enqueue Thread Exited.")
class QueueInput(FeedfreeInput):
""" Input by enqueueing datapoints from a DataFlow to a TF queue, and dequeue
tensors to the graph. """
""" Enqueue datapoints from a DataFlow to a TF queue.
And the model receives dequeued tensors.
"""
def __init__(self, ds, queue=None):
"""
......@@ -144,6 +147,7 @@ class QueueInput(FeedfreeInput):
def _get_input_tensors(self):
ret = self.queue.dequeue(name='input_deque')
print(ret)
if isinstance(ret, tf.Tensor): # only one input
ret = [ret]
assert len(ret) == len(self.input_placehdrs)
......@@ -158,6 +162,70 @@ class QueueInput(FeedfreeInput):
return ret
class BatchQueueInput(FeedfreeInput):
""" Enqueue datapoints from a DataFlow to a TF queue.
And the model receives batches formed by concatenating
dequeued tensors.
"""
def __init__(self, ds, batch_size, queue=None):
"""
Args:
ds(DataFlow): the input DataFlow.
batch_size(int): the batch size.
queue (tf.QueueBase): Defaults to a FIFO queue of size 3000.
"""
assert isinstance(ds, DataFlow), ds
self.queue = queue
self.ds = ds
self.batch_size = int(batch_size)
def size(self):
return self.ds.size() // self.batch_size
def _setup(self, trainer):
self.input_placehdrs = trainer.model.get_input_vars()
assert len(self.input_placehdrs) > 0, \
"QueueInput can only be used with input placeholders!"
# prepare placeholders without the first dimension
placehdrs_nobatch = []
for p in self.input_placehdrs:
placehdrs_nobatch.append(tf.placeholder(
dtype=p.dtype, shape=p.get_shape().as_list()[1:],
name=get_op_tensor_name(p.name)[0] + '-nobatch'))
# dequeue_many requires fully-defined shapes
shape_err = "Use of BatchQueueInput requires input variables to have fully-defined "
"shapes except for the batch dimension"
shapes = []
for p in placehdrs_nobatch:
assert p.get_shape().is_fully_defined(), shape_err
shapes.append(p.get_shape())
if self.queue is None:
self.queue = tf.FIFOQueue(
3000, [x.dtype for x in self.input_placehdrs],
shapes=shapes,
name='input_queue')
for shp in self.queue.shapes:
assert shp.is_fully_defined(), shape_err
self.thread = EnqueueThread(
trainer, self.queue, self.ds, placehdrs_nobatch)
trainer.config.callbacks.append(StartProcOrThread(self.thread))
def _get_input_tensors(self):
ret = self.queue.dequeue_many(self.batch_size, name='input_deque')
if isinstance(ret, tf.Tensor): # only one input
ret = [ret]
assert len(ret) == len(self.input_placehdrs)
for qv, v in zip(ret, self.input_placehdrs):
shp = v.get_shape().as_list()
shp[0] = self.batch_size
qv.set_shape(shp)
return ret
class DummyConstantInput(FeedfreeInput):
""" Input some constant variables. Only for debugging performance issues """
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment