Commit 7a7295ee authored by Yuxin Wu's avatar Yuxin Wu

PrintData when reset & support names in QueueInput

parent 1175aade
...@@ -71,3 +71,6 @@ class ProxyDataFlow(DataFlow): ...@@ -71,3 +71,6 @@ class ProxyDataFlow(DataFlow):
def size(self): def size(self):
return self.ds.size() return self.ds.size()
def get_data(self):
return self.ds.get_data()
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
from __future__ import division from __future__ import division
import numpy as np import numpy as np
from copy import copy from copy import copy
import itertools
from termcolor import colored from termcolor import colored
from collections import deque, defaultdict from collections import deque, defaultdict
from six.moves import range, map from six.moves import range, map
...@@ -623,7 +624,6 @@ class PrintData(ProxyDataFlow): ...@@ -623,7 +624,6 @@ class PrintData(ProxyDataFlow):
super(PrintData, self).__init__(ds) super(PrintData, self).__init__(ds)
self.num = num self.num = num
self.label = label self.label = label
self.print_info()
def _analyze_input_data(self, el, k, depth=1): def _analyze_input_data(self, el, k, depth=1):
""" """
...@@ -670,30 +670,8 @@ class PrintData(ProxyDataFlow): ...@@ -670,30 +670,8 @@ class PrintData(ProxyDataFlow):
""" """
Dump gathered debugging information to stdout. Dump gathered debugging information to stdout.
""" """
def cutoff(gen, num=1):
"""
Stop a generator after n iterations.
Args:
gen (PyGenObject): arbitrary generator
num (int, optional): number of maximal iterations
Yields:
element from generator object
"""
c = 0
for el in gen:
yield el
c += 1
if c == num:
break
ds = self.ds
ds.reset_state()
msg = [""] msg = [""]
for i, dummy in enumerate(cutoff(ds.get_data(), self.num)): for i, dummy in enumerate(itertools.islice(self.ds.get_data(), self.num)):
if isinstance(dummy, list): if isinstance(dummy, list):
msg.append("datapoint %i<%i with %i components consists of" % (i, self.num, len(dummy))) msg.append("datapoint %i<%i with %i components consists of" % (i, self.num, len(dummy)))
for k, entry in enumerate(dummy): for k, entry in enumerate(dummy):
...@@ -701,7 +679,9 @@ class PrintData(ProxyDataFlow): ...@@ -701,7 +679,9 @@ class PrintData(ProxyDataFlow):
label = "" if self.label is "" else " (" + self.label + ")" label = "" if self.label is "" else " (" + self.label + ")"
logger.info(colored("DataFlow Info%s:" % label, 'cyan') + '\n'.join(msg)) logger.info(colored("DataFlow Info%s:" % label, 'cyan') + '\n'.join(msg))
# reset again after print
self.ds.reset_state() self.ds.reset_state()
def get_data(self): def reset_state(self):
return self.ds.get_data() super(PrintData, self).reset_state()
self.print_info()
...@@ -119,10 +119,10 @@ def QueueInputTrainer(config, input_queue=None): ...@@ -119,10 +119,10 @@ def QueueInputTrainer(config, input_queue=None):
input_queue (tf.QueueBase): an input queue. Defaults to the input_queue (tf.QueueBase): an input queue. Defaults to the
:class:`QueueInput` default. :class:`QueueInput` default.
""" """
if config.dataflow is not None: if config.data is not None:
config.data = QueueInput(config.dataflow, input_queue)
else:
assert isinstance(config.data, QueueInput), config.data assert isinstance(config.data, QueueInput), config.data
else:
config.data = QueueInput(config.dataflow, input_queue)
# debug # debug
# from tensorpack.train.input_source import StagingInputWrapper, DummyConstantInput # from tensorpack.train.input_source import StagingInputWrapper, DummyConstantInput
......
...@@ -161,7 +161,7 @@ class FeedfreeInput(InputSource): ...@@ -161,7 +161,7 @@ class FeedfreeInput(InputSource):
e.g. by queue or other operations. """ e.g. by queue or other operations. """
def reset_state(self): def reset_state(self):
# TODO cannot reset # TODO no state to reset
pass pass
def next_feed(self): def next_feed(self):
...@@ -212,17 +212,19 @@ class QueueInput(FeedfreeInput): ...@@ -212,17 +212,19 @@ class QueueInput(FeedfreeInput):
And the model receives dequeued tensors. And the model receives dequeued tensors.
""" """
def __init__(self, ds, queue=None): def __init__(self, ds, queue=None, names=None):
""" """
Args: Args:
ds(DataFlow): the input DataFlow. ds(DataFlow): the input DataFlow.
queue (tf.QueueBase): A :class:`tf.QueueBase` whose type queue (tf.QueueBase): A :class:`tf.QueueBase` whose type
should match the corresponding InputDesc of the model. should match the corresponding InputDesc of the model.
Defaults to a FIFO queue of size 50. Defaults to a FIFO queue of size 50.
names(list[str]): list of input names corresponding to the dataflow.
""" """
assert isinstance(ds, DataFlow), ds assert isinstance(ds, DataFlow), ds
self.queue = queue self.queue = queue
self.ds = ds self.ds = ds
self._names = names
def size(self): def size(self):
return self.ds.size() return self.ds.size()
...@@ -231,13 +233,17 @@ class QueueInput(FeedfreeInput): ...@@ -231,13 +233,17 @@ class QueueInput(FeedfreeInput):
def setup(self, model): def setup(self, model):
logger.info("Setting up the queue for CPU prefetching ...") logger.info("Setting up the queue for CPU prefetching ...")
self.input_placehdrs = model.get_reused_placehdrs() self.input_placehdrs = model.get_reused_placehdrs()
assert len(self.input_placehdrs) > 0, \ if self._names is None:
"QueueInput has to be used with some InputDesc!" self._queue_feedpoint = self.input_placehdrs
else:
self._queue_feedpoint = get_placeholders_by_names(self.input_placehdrs, self._names)
assert len(self._queue_feedpoint) > 0, \
"QueueInput has to be used with some inputs!"
if self.queue is None: if self.queue is None:
self.queue = tf.FIFOQueue( self.queue = tf.FIFOQueue(
50, [x.dtype for x in self.input_placehdrs], 50, [x.dtype for x in self._queue_feedpoint],
name='input_queue') name='input_queue')
self.thread = EnqueueThread(self.queue, self.ds, self.input_placehdrs) self.thread = EnqueueThread(self.queue, self.ds, self._queue_feedpoint)
def setup_training(self, trainer): def setup_training(self, trainer):
super(QueueInput, self).setup_training(trainer) super(QueueInput, self).setup_training(trainer)
...@@ -250,10 +256,13 @@ class QueueInput(FeedfreeInput): ...@@ -250,10 +256,13 @@ class QueueInput(FeedfreeInput):
ret = self.queue.dequeue(name='input_deque') ret = self.queue.dequeue(name='input_deque')
if isinstance(ret, tf.Tensor): # only one input if isinstance(ret, tf.Tensor): # only one input
ret = [ret] ret = [ret]
assert len(ret) == len(self.input_placehdrs) assert len(ret) == len(self._queue_feedpoint)
for qv, v in zip(ret, self.input_placehdrs): for qv, v in zip(ret, self._queue_feedpoint):
qv.set_shape(v.get_shape()) qv.set_shape(v.get_shape())
return ret if self._names is None:
return ret
else:
return get_tensors_inputs(self.input_placehdrs, ret, self._names)
class BatchQueueInput(FeedfreeInput): class BatchQueueInput(FeedfreeInput):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment