Commit 59160bcf authored by Yuxin Wu

make zmq/dummy input subclass of TensorInput

parent ddebb23c
@@ -324,46 +324,6 @@ class BatchQueueInput(FeedfreeInput):
         return ret
-class DummyConstantInput(FeedfreeInput):
-    """ Input with some random tensor placed on GPU.
-        Useful for debugging performance issues """
-    def __init__(self, shapes):
-        """
-        Args:
-            shapes (list[list]): a list of fully-sepcified shapes.
-        """
-        self.shapes = shapes
-        logger.warn("Using dummy input for debug!")
-
-    def setup(self, model):
-        self.input_placehdrs = model.get_reused_placehdrs()
-
-    def setup_training(self, trainer):
-        super(DummyConstantInput, self).setup_training(trainer)
-        nr_tower = trainer.config.nr_tower
-        placehdrs = self.input_placehdrs
-        assert len(self.shapes) == len(placehdrs)
-        self.tensors = []
-        # don't share variables
-        for tower in range(nr_tower):
-            tlist = []
-            with tf.device('/gpu:{}'.format(tower)):
-                for idx, p in enumerate(placehdrs):
-                    tlist.append(tf.get_variable(
-                        'dummy-{}-{}'.format(p.op.name, tower), shape=self.shapes[idx],
-                        dtype=p.dtype, trainable=False))
-            self.tensors.append(tlist)
-
-    def get_input_tensors(self):
-        ctx = get_current_tower_context()
-        ret = self.tensors[ctx.index]
-        return ret
 class TensorInput(FeedfreeInput):
     """ Input from a list of tensors, e.g. a TF data reading pipeline. """
@@ -391,27 +351,53 @@ class TensorInput(FeedfreeInput):
         return self.get_tensor_fn()
-class ZMQInput(FeedfreeInput):
+class DummyConstantInput(TensorInput):
+    """ Input with some random tensor placed on GPU.
+        Useful for debugging performance issues """
+    def __init__(self, shapes):
+        """
+        Args:
+            shapes (list[list]): a list of fully-specified shapes.
+        """
+        self.shapes = shapes
+        logger.warn("Using dummy input for debug!")
+
+        def fn():
+            tlist = []
+            ctx = get_current_tower_context()
+            assert len(self.shapes) == len(self.input_placehdrs)
+            for idx, p in enumerate(self.input_placehdrs):
+                tlist.append(tf.get_variable(
+                    'dummy-{}-{}'.format(p.op.name, ctx.index), shape=self.shapes[idx],
+                    dtype=p.dtype, trainable=False))
+            return tlist
+        super(DummyConstantInput, self).__init__(fn)
+
+    def setup(self, model):
+        self.input_placehdrs = model.get_reused_placehdrs()
+
+
+# TODO doesn't support remapping
+class ZMQInput(TensorInput):
     def __init__(self, endpoint):
         self._endpoint = endpoint
-    def size(self):
-        raise NotImplementedError()
+        from tensorpack.user_ops import zmq_recv
+
+        def fn():
+            ret = zmq_recv(self._endpoint, [x.dtype for x in self.input_placehdrs])
+            if isinstance(ret, tf.Tensor):
+                ret = [ret]
+            assert len(ret) == len(self.input_placehdrs)
+            for qv, v in zip(ret, self.input_placehdrs):
+                qv.set_shape(v.get_shape())
+            return ret
+        super(ZMQInput, self).__init__(fn)
     def setup(self, model):
         self.input_placehdrs = model.get_reused_placehdrs()
         assert len(self.input_placehdrs) > 0, \
"ZMQInput has to be used with input placeholders!" "ZMQInput has to be used with InputDesc!"
-    def get_input_tensors(self):
-        from tensorpack.user_ops import zmq_recv
-        ret = zmq_recv(self._endpoint, [x.dtype for x in self.input_placehdrs])
-        if isinstance(ret, tf.Tensor):
-            ret = [ret]
-        assert len(ret) == len(self.input_placehdrs)
-        for qv, v in zip(ret, self.input_placehdrs):
-            qv.set_shape(v.get_shape())
-        return ret
 class StagingInputWrapper(FeedfreeInput):
...
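Note (not part of the commit): the refactor works because TensorInput already returns self.get_tensor_fn() from get_input_tensors(), as the context line above shows. After this change, DummyConstantInput and ZMQInput only build a closure in __init__ and hand it to super().__init__, instead of overriding get_input_tensors() and tracking per-tower state themselves. Below is a minimal, framework-free sketch of that callback pattern; the names TensorInputSketch, DummyConstantInputSketch, and fn are illustrative only and are not tensorpack APIs.

# Standalone sketch of the "subclass passes a tensor-building closure" pattern.
class TensorInputSketch(object):
    """Base class: stores a callable and invokes it when inputs are requested."""
    def __init__(self, get_tensor_fn):
        self.get_tensor_fn = get_tensor_fn

    def get_input_tensors(self):
        # Called once per tower in the real library; here we just call the closure.
        return self.get_tensor_fn()


class DummyConstantInputSketch(TensorInputSketch):
    """Subclass only prepares state and a closure; no per-tower bookkeeping."""
    def __init__(self, shapes):
        self.shapes = shapes

        def fn():
            # Stand-in for tf.get_variable(...): one placeholder record per shape.
            return [('dummy', tuple(s)) for s in self.shapes]
        super(DummyConstantInputSketch, self).__init__(fn)


if __name__ == '__main__':
    inp = DummyConstantInputSketch([[64, 224, 224, 3], [64]])
    print(inp.get_input_tensors())  # [('dummy', (64, 224, 224, 3)), ('dummy', (64,))]

Under this reading, adding a new feed-free input source amounts to writing one closure, which is the simplification the commit message describes.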