Commit 7a7295ee authored by Yuxin Wu

PrintData when reset & support names in QueueInput

parent 1175aade
......@@ -71,3 +71,6 @@ class ProxyDataFlow(DataFlow):
def size(self):
    """Return whatever the wrapped dataflow ``self.ds`` reports as its size."""
    return self.ds.size()
def get_data(self):
    """Delegate to the wrapped dataflow ``self.ds`` and return its ``get_data()`` result."""
    return self.ds.get_data()
......@@ -5,6 +5,7 @@
from __future__ import division
import numpy as np
from copy import copy
import itertools
from termcolor import colored
from collections import deque, defaultdict
from six.moves import range, map
......@@ -623,7 +624,6 @@ class PrintData(ProxyDataFlow):
super(PrintData, self).__init__(ds)
self.num = num
self.label = label
self.print_info()
def _analyze_input_data(self, el, k, depth=1):
"""
......@@ -670,30 +670,8 @@ class PrintData(ProxyDataFlow):
"""
Dump gathered debugging information to stdout.
"""
def cutoff(gen, num=1):
"""
Stop a generator after n iterations.
Args:
gen (PyGenObject): arbitrary generator
num (int, optional): number of maximal iterations
Yields:
element from generator object
"""
c = 0
for el in gen:
yield el
c += 1
if c == num:
break
ds = self.ds
ds.reset_state()
msg = [""]
for i, dummy in enumerate(cutoff(ds.get_data(), self.num)):
for i, dummy in enumerate(itertools.islice(self.ds.get_data(), self.num)):
if isinstance(dummy, list):
msg.append("datapoint %i<%i with %i components consists of" % (i, self.num, len(dummy)))
for k, entry in enumerate(dummy):
......@@ -701,7 +679,9 @@ class PrintData(ProxyDataFlow):
label = "" if self.label is "" else " (" + self.label + ")"
logger.info(colored("DataFlow Info%s:" % label, 'cyan') + '\n'.join(msg))
# reset again after print
self.ds.reset_state()
def get_data(self):
    """Return the result of ``self.ds.get_data()`` unchanged (no printing happens here)."""
    return self.ds.get_data()
def reset_state(self):
    """Reset via the parent class, then call ``self.print_info()``.

    The info is printed on reset rather than at construction time.
    """
    super(PrintData, self).reset_state()
    self.print_info()
......@@ -119,10 +119,10 @@ def QueueInputTrainer(config, input_queue=None):
input_queue (tf.QueueBase): an input queue. Defaults to the
:class:`QueueInput` default.
"""
if config.dataflow is not None:
config.data = QueueInput(config.dataflow, input_queue)
else:
if config.data is not None:
assert isinstance(config.data, QueueInput), config.data
else:
config.data = QueueInput(config.dataflow, input_queue)
# debug
# from tensorpack.train.input_source import StagingInputWrapper, DummyConstantInput
......
......@@ -161,7 +161,7 @@ class FeedfreeInput(InputSource):
e.g. by queue or other operations. """
def reset_state(self):
    """Do nothing; this input keeps no state that needs resetting."""
    # TODO no state to reset
    pass
def next_feed(self):
......@@ -212,17 +212,19 @@ class QueueInput(FeedfreeInput):
And the model receives dequeued tensors.
"""
def __init__(self, ds, queue=None, names=None):
    """
    Args:
        ds(DataFlow): the input DataFlow.
        queue (tf.QueueBase): A :class:`tf.QueueBase` whose type
            should match the corresponding InputDesc of the model.
            Defaults to a FIFO queue of size 50.
        names(list[str]): list of input names corresponding to the dataflow.
            Defaults to None, in which case all of the model's
            placeholders are fed from the dataflow.
    """
    assert isinstance(ds, DataFlow), ds
    self.queue = queue
    self.ds = ds
    # Stored as given; only consulted later when the queue is set up.
    self._names = names
def size(self):
    """Return the size reported by the underlying dataflow ``self.ds``."""
    return self.ds.size()
......@@ -231,13 +233,17 @@ class QueueInput(FeedfreeInput):
def setup(self, model):
    """Create the input queue (if none was given) and the enqueue thread.

    Args:
        model: object providing ``get_reused_placehdrs()``; the returned
            placeholders are stored in ``self.input_placehdrs``.
    """
    logger.info("Setting up the queue for CPU prefetching ...")
    self.input_placehdrs = model.get_reused_placehdrs()
    # Without explicit names every placeholder is fed through the queue;
    # otherwise only the placeholders selected by name are.
    if self._names is None:
        self._queue_feedpoint = self.input_placehdrs
    else:
        self._queue_feedpoint = get_placeholders_by_names(self.input_placehdrs, self._names)
    assert len(self._queue_feedpoint) > 0, \
        "QueueInput has to be used with some inputs!"
    if self.queue is None:
        # Default queue: FIFO with capacity 50, dtypes taken from the feed points.
        self.queue = tf.FIFOQueue(
            50, [x.dtype for x in self._queue_feedpoint],
            name='input_queue')
    self.thread = EnqueueThread(self.queue, self.ds, self._queue_feedpoint)
def setup_training(self, trainer):
super(QueueInput, self).setup_training(trainer)
......@@ -250,10 +256,13 @@ class QueueInput(FeedfreeInput):
ret = self.queue.dequeue(name='input_deque')
if isinstance(ret, tf.Tensor): # only one input
ret = [ret]
assert len(ret) == len(self.input_placehdrs)
for qv, v in zip(ret, self.input_placehdrs):
assert len(ret) == len(self._queue_feedpoint)
for qv, v in zip(ret, self._queue_feedpoint):
qv.set_shape(v.get_shape())
if self._names is None:
return ret
else:
return get_tensors_inputs(self.input_placehdrs, ret, self._names)
class BatchQueueInput(FeedfreeInput):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment