Commit 00c47fa0 authored by Yuxin Wu's avatar Yuxin Wu

add batchqueueinput

parent c86cd15a
...@@ -67,10 +67,10 @@ class LMDBData(RNGDataFlow): ...@@ -67,10 +67,10 @@ class LMDBData(RNGDataFlow):
lmdb_path (str): a directory or a file. lmdb_path (str): a directory or a file.
shuffle (bool): shuffle the keys or not. shuffle (bool): shuffle the keys or not.
keys (list of str or str): list of str as the keys, used only when shuffle is True. keys (list of str or str): list of str as the keys, used only when shuffle is True.
It can also be a format string e.g. `'{:0>8d}'` which will be It can also be a format string e.g. ``{:0>8d}`` which will be
formatted with the indices from 0 to `total_size - 1`. formatted with the indices from 0 to *total_size - 1*.
If not provided, it will then look in the database for `__keys__` which If not provided, it will then look in the database for ``__keys__`` which
:func:`dump_dataflow_to_lmdb` used to store the list of keys. :func:`dump_dataflow_to_lmdb` used to store the list of keys.
If still not found, it will iterate over the database to find If still not found, it will iterate over the database to find
all the keys. all the keys.
...@@ -177,7 +177,7 @@ def CaffeLMDB(lmdb_path, shuffle=True, keys=None): ...@@ -177,7 +177,7 @@ def CaffeLMDB(lmdb_path, shuffle=True, keys=None):
a :class:`LMDBDataDecoder` instance. a :class:`LMDBDataDecoder` instance.
Example: Example:
ds = CaffeLMDB("/tmp/validation", keys='{:0>8d}') ``ds = CaffeLMDB("/tmp/validation", keys='{:0>8d}')``
""" """
cpb = get_caffe_pb() cpb = get_caffe_pb()
......
...@@ -105,26 +105,22 @@ class SimpleFeedfreeTrainer( ...@@ -105,26 +105,22 @@ class SimpleFeedfreeTrainer(
# self.train_op = tf.group(*self.dequed_inputs) # self.train_op = tf.group(*self.dequed_inputs)
class QueueInputTrainer(SimpleFeedfreeTrainer): def QueueInputTrainer(config, input_queue=None, predict_tower=None):
""" """
A trainer which automatically wraps ``config.dataflow`` by a A wrapper trainer which automatically wraps ``config.dataflow`` by a
:class:`QueueInput`. :class:`QueueInput`.
""" It is an equivalent of ``SimpleFeedfreeTrainer(config)`` with ``config.data = QueueInput(dataflow)``.
def __init__(self, config, input_queue=None, predict_tower=None):
"""
Single tower Trainer, takes input from a queue
Args: Args:
config(TrainConfig): a `TrainConfig` instance. config.dataflow must exist. config(TrainConfig): a `TrainConfig` instance. config.dataflow must exist.
input_queue(tf.QueueBase): an input queue. Defaults to the input_queue(tf.QueueBase): an input queue. Defaults to the
:class:`QueueInput` default. :class:`QueueInput` default.
""" """
config.data = QueueInput(config.dataflow, input_queue) config.data = QueueInput(config.dataflow, input_queue)
if predict_tower is not None: if predict_tower is not None:
logger.warn("[Deprecated] Argument `predict_tower` is deprecated for trainer. " logger.warn("[Deprecated] Argument `predict_tower` is deprecated for trainer. "
"Use TrainConfig(predict_tower=...) instead!") "Use TrainConfig(predict_tower=...) instead!")
config.predict_tower = predict_tower config.predict_tower = predict_tower
assert len(config.tower) == 1, \ assert len(config.tower) == 1, \
"QueueInputTrainer doesn't support multigpu! Use Sync/AsyncMultiGPUTrainer instead." "QueueInputTrainer doesn't support multigpu! Use Sync/AsyncMultiGPUTrainer instead."
super(QueueInputTrainer, self).__init__(config) return SimpleFeedfreeTrainer(config)
...@@ -10,11 +10,13 @@ import six ...@@ -10,11 +10,13 @@ import six
from ..dataflow import DataFlow, RepeatedData from ..dataflow import DataFlow, RepeatedData
from ..tfutils.summary import add_moving_summary from ..tfutils.summary import add_moving_summary
from ..tfutils import get_op_tensor_name
from ..utils import logger from ..utils import logger
from ..callbacks.concurrency import StartProcOrThread from ..callbacks.concurrency import StartProcOrThread
__all__ = ['InputData', 'QueueInput', 'FeedfreeInput', 'TensorInput', __all__ = ['InputData', 'FeedfreeInput',
'DummyConstantInput'] 'QueueInput', 'BatchQueueInput',
'TensorInput', 'DummyConstantInput']
@six.add_metaclass(ABCMeta) @six.add_metaclass(ABCMeta)
...@@ -90,9 +92,9 @@ class EnqueueThread(threading.Thread): ...@@ -90,9 +92,9 @@ class EnqueueThread(threading.Thread):
self.size_op, tf.float32, name='input_queue_size')) self.size_op, tf.float32, name='input_queue_size'))
def run(self): def run(self):
self.dataflow.reset_state() try:
with self.sess.as_default(): self.dataflow.reset_state()
try: with self.sess.as_default():
while True: while True:
for dp in self.dataflow.get_data(): for dp in self.dataflow.get_data():
if self.coord.should_stop(): if self.coord.should_stop():
...@@ -100,22 +102,23 @@ class EnqueueThread(threading.Thread): ...@@ -100,22 +102,23 @@ class EnqueueThread(threading.Thread):
feed = dict(zip(self.placehdrs, dp)) feed = dict(zip(self.placehdrs, dp))
# print 'qsize:', self.sess.run([self.op, self.size_op], feed_dict=feed)[1] # print 'qsize:', self.sess.run([self.op, self.size_op], feed_dict=feed)[1]
self.op.run(feed_dict=feed) self.op.run(feed_dict=feed)
except tf.errors.CancelledError: except tf.errors.CancelledError:
pass
except Exception:
logger.exception("Exception in EnqueueThread:")
finally:
self.coord.request_stop()
try:
self.sess.run(self.close_op)
except RuntimeError: # session already closed
pass pass
except Exception: logger.info("Enqueue Thread Exited.")
logger.exception("Exception in EnqueueThread:")
finally:
self.coord.request_stop()
try:
self.sess.run(self.close_op)
except RuntimeError: # session already closed
pass
logger.info("Enqueue Thread Exited.")
class QueueInput(FeedfreeInput): class QueueInput(FeedfreeInput):
""" Input by enqueueing datapoints from a DataFlow to a TF queue, and dequeue """ Enqueue datapoints from a DataFlow to a TF queue.
tensors to the graph. """ And the model receives dequeued tensors.
"""
def __init__(self, ds, queue=None): def __init__(self, ds, queue=None):
""" """
...@@ -144,6 +147,7 @@ class QueueInput(FeedfreeInput): ...@@ -144,6 +147,7 @@ class QueueInput(FeedfreeInput):
def _get_input_tensors(self): def _get_input_tensors(self):
ret = self.queue.dequeue(name='input_deque') ret = self.queue.dequeue(name='input_deque')
print(ret)
if isinstance(ret, tf.Tensor): # only one input if isinstance(ret, tf.Tensor): # only one input
ret = [ret] ret = [ret]
assert len(ret) == len(self.input_placehdrs) assert len(ret) == len(self.input_placehdrs)
...@@ -158,6 +162,70 @@ class QueueInput(FeedfreeInput): ...@@ -158,6 +162,70 @@ class QueueInput(FeedfreeInput):
return ret return ret
class BatchQueueInput(FeedfreeInput):
    """ Enqueue datapoints from a DataFlow to a TF queue.
    And the model receives batches formed by concatenating
    dequeued tensors.
    """

    def __init__(self, ds, batch_size, queue=None):
        """
        Args:
            ds(DataFlow): the input DataFlow.
            batch_size(int): the batch size.
            queue (tf.QueueBase): Defaults to a FIFO queue of size 3000.
        """
        assert isinstance(ds, DataFlow), ds
        # The default queue is created lazily in `_setup`, because its dtypes
        # and shapes come from the model's input placeholders.
        self.queue = queue
        self.ds = ds
        self.batch_size = int(batch_size)

    def size(self):
        """
        Returns:
            int: ``ds.size()`` floor-divided by the batch size.
        """
        return self.ds.size() // self.batch_size

    def _setup(self, trainer):
        """
        Create the per-datapoint placeholders, the queue (if none was given)
        and the enqueue thread, and register the thread as a callback.

        Args:
            trainer: its ``model.get_input_vars()`` provides the batched input
                placeholders, and the thread-starting callback is appended to
                its ``config.callbacks``.
        """
        self.input_placehdrs = trainer.model.get_input_vars()
        assert len(self.input_placehdrs) > 0, \
            "BatchQueueInput can only be used with input placeholders!"

        # prepare placeholders without the first dimension
        placehdrs_nobatch = []
        for p in self.input_placehdrs:
            placehdrs_nobatch.append(tf.placeholder(
                dtype=p.dtype, shape=p.get_shape().as_list()[1:],
                name=get_op_tensor_name(p.name)[0] + '-nobatch'))

        # dequeue_many requires fully-defined shapes
        # Parenthesized so that both literals form one message; without the
        # parentheses the second literal is a separate no-op statement.
        shape_err = ("Use of BatchQueueInput requires input variables to have fully-defined "
                     "shapes except for the batch dimension")
        shapes = []
        for p in placehdrs_nobatch:
            assert p.get_shape().is_fully_defined(), shape_err
            shapes.append(p.get_shape())

        if self.queue is None:
            self.queue = tf.FIFOQueue(
                3000, [x.dtype for x in self.input_placehdrs],
                shapes=shapes,
                name='input_queue')
        # Also checks a user-supplied queue, whose shapes were not built here.
        for shp in self.queue.shapes:
            assert shp.is_fully_defined(), shape_err

        # The thread feeds single datapoints through the no-batch placeholders.
        self.thread = EnqueueThread(
            trainer, self.queue, self.ds, placehdrs_nobatch)
        trainer.config.callbacks.append(StartProcOrThread(self.thread))

    def _get_input_tensors(self):
        """
        Returns:
            list: tensors dequeued ``batch_size`` at a time, one per input
            placeholder, each with its static shape set to the placeholder's
            shape with the first dimension replaced by ``batch_size``.
        """
        ret = self.queue.dequeue_many(self.batch_size, name='input_deque')
        if isinstance(ret, tf.Tensor):  # only one input
            ret = [ret]
        assert len(ret) == len(self.input_placehdrs)
        for qv, v in zip(ret, self.input_placehdrs):
            shp = v.get_shape().as_list()
            shp[0] = self.batch_size
            qv.set_shape(shp)
        return ret
class DummyConstantInput(FeedfreeInput): class DummyConstantInput(FeedfreeInput):
""" Input some constant variables. Only for debugging performance issues """ """ Input some constant variables. Only for debugging performance issues """
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment