Commit 9f4600a0 authored by Yuxin Wu

Convert zmq input to dataset

parent 07e464d8
@@ -441,6 +441,31 @@ class ZMQInput(TensorInput):
             hwm=self._hwm,
             bind=self._bind)
 
+    def to_dataset(self, input_signature):
+        """
+        Convert to a TF dataset.
+
+        Args:
+            input_signature (list[InputSpec]):
+        Returns:
+            tf.data.Dataset
+        """
+        import zmq_ops
+        zmq_pull_socket = zmq_ops.ZMQPullSocket(
+            self._end_point, [x.dtype for x in input_signature],
+            hwm=self._hwm, bind=self._bind)
+
+        def mapper(_):
+            inputs = list(zmq_pull_socket.pull())
+            for v, sig in zip(inputs, input_signature):
+                v.set_shape(sig.shape)
+            return inputs
+
+        # Is there a better way to construct from stateful tensor?
+        dataset = tf.data.Dataset.from_tensors([1])  # just a placeholder
+        return dataset.map(mapper)
+
 
 class TFDatasetInput(FeedfreeInput):
     """
......
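A rough usage sketch of the new `to_dataset` method (not part of the commit, untested): it assumes the `zmq_ops` extension is built, that a producer process is pushing (image, label) pairs to the endpoint, and that `tf.TensorSpec` objects serve as the `input_signature`; the endpoint, shapes, and `hwm` value below are placeholders.

# Hypothetical example, not from this commit: read ZMQ-pushed tensors through tf.data.
import tensorflow as tf
from tensorpack.input_source import ZMQInput

input_signature = [
    tf.TensorSpec([None, 224, 224, 3], tf.float32),  # images (placeholder shape)
    tf.TensorSpec([None], tf.int64),                 # labels (placeholder shape)
]

zmq_input = ZMQInput('ipc:///tmp/zmq-train-pipe', hwm=50)
dataset = zmq_input.to_dataset(input_signature)

# Because the dataset is built from from_tensors([1]), one pass performs exactly
# one pull from the socket; repeat() turns it into a continuous stream.
dataset = dataset.repeat()

for images, labels in dataset.take(2):  # assumes eager execution / TF2
    print(images.shape, labels.shape)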
@@ -2,7 +2,7 @@
 # File: interface.py
 
 from ..compat import tfv1
-from ..input_source import DummyConstantInput, FeedfreeInput, FeedInput, InputSource, QueueInput, StagingInput
+from ..input_source import FeedInput, InputSource, QueueInput, StagingInput
 from ..utils import logger
 from ..compat import is_tfv2
 from .config import TrainConfig
@@ -34,12 +34,10 @@ def apply_default_prefetch(input_source_or_dataflow, trainer):
         input = input_source_or_dataflow
     if hasattr(trainer, 'devices'):
         towers = trainer.devices
-        if len(towers) > 1:
-            # seem to only improve on >1 GPUs
+        if len(towers) > 1:  # seem to only help on >1 GPUs
             assert not isinstance(trainer, SimpleTrainer)
 
-            if isinstance(input, FeedfreeInput) and \
-                    not isinstance(input, (StagingInput, DummyConstantInput)):
+            if isinstance(input, QueueInput):
                 logger.info("Automatically applying StagingInput on the DataFlow.")
                 input = StagingInput(input)
     return input
......
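With this change, `apply_default_prefetch` only auto-wraps a `QueueInput` in `StagingInput`; other feed-free sources such as `ZMQInput` or `TFDatasetInput` are no longer wrapped automatically, which is also why `DummyConstantInput` and `FeedfreeInput` drop out of the imports above. A hedged sketch of staging such an input explicitly for multi-GPU training (endpoint and values are illustrative):

# Illustrative only: explicitly stage a non-QueueInput source on multi-GPU setups,
# since the automatic rule now applies only to QueueInput.
from tensorpack.input_source import StagingInput, ZMQInput

input_source = ZMQInput('ipc:///tmp/zmq-train-pipe', hwm=50)
# StagingInput prefetches batches onto the GPUs; it mainly helps with >1 GPUs.
input_source = StagingInput(input_source)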