Commit 9f4600a0 authored by Yuxin Wu's avatar Yuxin Wu

Convert zmq input to dataset

parent 07e464d8
...@@ -441,6 +441,31 @@ class ZMQInput(TensorInput): ...@@ -441,6 +441,31 @@ class ZMQInput(TensorInput):
hwm=self._hwm, hwm=self._hwm,
bind=self._bind) bind=self._bind)
def to_dataset(self, input_signature):
"""
Convert to a TF dataset.
Args:
input_signature (list[InputSpec]):
Returns:
tf.data.Dataset
"""
import zmq_ops
zmq_pull_socket = zmq_ops.ZMQPullSocket(
self._end_point, [x.dtype for x in input_signature],
hwm=self._hwm, bind=self._bind)
def mapper(_):
inputs = list(zmq_pull_socket.pull())
for v, sig in zip(inputs, input_signature):
v.set_shape(sig.shape)
return inputs
# Is there a better way to construct from stateful tensor?
dataset = tf.data.Dataset.from_tensors([1]) # just a placeholder
return dataset.map(mapper)
class TFDatasetInput(FeedfreeInput): class TFDatasetInput(FeedfreeInput):
""" """
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# File: interface.py # File: interface.py
from ..compat import tfv1 from ..compat import tfv1
from ..input_source import DummyConstantInput, FeedfreeInput, FeedInput, InputSource, QueueInput, StagingInput from ..input_source import FeedInput, InputSource, QueueInput, StagingInput
from ..utils import logger from ..utils import logger
from ..compat import is_tfv2 from ..compat import is_tfv2
from .config import TrainConfig from .config import TrainConfig
...@@ -34,12 +34,10 @@ def apply_default_prefetch(input_source_or_dataflow, trainer): ...@@ -34,12 +34,10 @@ def apply_default_prefetch(input_source_or_dataflow, trainer):
input = input_source_or_dataflow input = input_source_or_dataflow
if hasattr(trainer, 'devices'): if hasattr(trainer, 'devices'):
towers = trainer.devices towers = trainer.devices
if len(towers) > 1: if len(towers) > 1: # seem to only help on >1 GPUs
# seem to only improve on >1 GPUs
assert not isinstance(trainer, SimpleTrainer) assert not isinstance(trainer, SimpleTrainer)
if isinstance(input, FeedfreeInput) and \ if isinstance(input, QueueInput):
not isinstance(input, (StagingInput, DummyConstantInput)):
logger.info("Automatically applying StagingInput on the DataFlow.") logger.info("Automatically applying StagingInput on the DataFlow.")
input = StagingInput(input) input = StagingInput(input)
return input return input
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment