Commit 3700a803 authored by Yuxin Wu's avatar Yuxin Wu

Make FeedInput & QueueInput support dict-based dataflow (#768)

parent 209da29e
...@@ -36,6 +36,18 @@ def _get_reset_callback(df): ...@@ -36,6 +36,18 @@ def _get_reset_callback(df):
return CallbackFactory(setup_graph=lambda _: df.reset_state()) return CallbackFactory(setup_graph=lambda _: df.reset_state())
def _make_feeds(placeholders, datapoint):
assert len(datapoint) == len(placeholders), \
"Size of datapoint and placeholders are different: {} != {}".format(
len(datapoint), len(placeholders))
if isinstance(datapoint, (list, tuple)):
return dict(zip(placeholders, datapoint))
elif isinstance(datapoint, dict):
ret = {p: datapoint[p.op.name] for p in placeholders}
return ret
class PlaceholderInput(InputSource): class PlaceholderInput(InputSource):
""" """
Just produce placeholders as input tensors. Just produce placeholders as input tensors.
...@@ -69,7 +81,7 @@ class FeedInput(InputSource): ...@@ -69,7 +81,7 @@ class FeedInput(InputSource):
def _before_run(self, _): def _before_run(self, _):
dp = next(self._itr) dp = next(self._itr)
assert len(dp) == len(self._placeholders), "[FeedInput] datapoints and inputs are of different length!" assert len(dp) == len(self._placeholders), "[FeedInput] datapoints and inputs are of different length!"
feed = dict(zip(self._placeholders, dp)) feed = _make_feeds(self._placeholders, dp)
return tf.train.SessionRunArgs(fetches=[], feed_dict=feed) return tf.train.SessionRunArgs(fetches=[], feed_dict=feed)
def _reset(self): def _reset(self):
...@@ -142,7 +154,7 @@ class EnqueueThread(ShareSessionThread): ...@@ -142,7 +154,7 @@ class EnqueueThread(ShareSessionThread):
self._running.wait() self._running.wait()
dp = next(self._itr) dp = next(self._itr)
feed = dict(zip(self.placehdrs, dp)) feed = _make_feeds(self.placehdrs, dp)
# _, sz = sess.run([self.op, self._sz], feed_dict=feed) # _, sz = sess.run([self.op, self._sz], feed_dict=feed)
self.op.run(feed_dict=feed) self.op.run(feed_dict=feed)
except (tf.errors.CancelledError, tf.errors.OutOfRangeError, DataFlowTerminated): except (tf.errors.CancelledError, tf.errors.OutOfRangeError, DataFlowTerminated):
...@@ -473,12 +485,13 @@ class TFDatasetInput(FeedfreeInput): ...@@ -473,12 +485,13 @@ class TFDatasetInput(FeedfreeInput):
dataset, if the dataflow iterator can terminate. dataset, if the dataflow iterator can terminate.
Args: Args:
df (DataFlow) df (DataFlow): a dataflow which produces lists
types([tf.DType]) types([tf.DType])
Returns: Returns:
(tf.data.Dataset) (tf.data.Dataset)
""" """
# TODO theoretically it can support dict
assert isinstance(df, DataFlow), df assert isinstance(df, DataFlow), df
assert isinstance(types, (list, tuple)), types assert isinstance(types, (list, tuple)), types
df = MapData(df, lambda dp: tuple(dp)) df = MapData(df, lambda dp: tuple(dp))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment