Commit fa0d4dc6 authored by Yuxin Wu's avatar Yuxin Wu

Add dataflow_to_dataset (#397)

parent 985236c3
...@@ -14,7 +14,7 @@ from six.moves import range, zip ...@@ -14,7 +14,7 @@ from six.moves import range, zip
import threading import threading
from .input_source_base import InputSource from .input_source_base import InputSource
from ..dataflow import DataFlow, RepeatedData, DataFlowTerminated from ..dataflow import DataFlow, MapData, RepeatedData, DataFlowTerminated
from ..tfutils.summary import add_moving_summary from ..tfutils.summary import add_moving_summary
from ..tfutils.common import get_op_tensor_name from ..tfutils.common import get_op_tensor_name
from ..tfutils.tower import get_current_tower_context from ..tfutils.tower import get_current_tower_context
...@@ -437,7 +437,30 @@ class TFDatasetInput(FeedfreeInput): ...@@ -437,7 +437,30 @@ class TFDatasetInput(FeedfreeInput):
self._init_op.run() self._init_op.run()
def _get_input_tensors(self): def _get_input_tensors(self):
return self._iterator.get_next() desc_shapes = [k.shape for k in self._desc]
ret = self._iterator.get_next()
assert len(ret) == len(desc_shapes)
for t, shp in zip(ret, desc_shapes):
t.set_shape(shp)
return ret
@staticmethod
def dataflow_to_dataset(df, types):
"""
Wrap a dataflow to tf.data.Dataset.
Will reset df.
Args:
df (DataFlow)
types([tf.DType])
"""
assert isinstance(df, DataFlow), df
assert isinstance(types, (list, tuple)), types
df = MapData(df, lambda dp: tuple(dp))
df.reset_state()
ds = tf.data.Dataset.from_generator(
df.get_data, tuple(types))
return ds
class StagingInput(FeedfreeInput): class StagingInput(FeedfreeInput):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment