Commit 9fff46d5 authored by Yuxin Wu's avatar Yuxin Wu

remove DataParallelFeedInput, as dataparallel inference now relies on QueueInput

parent 12cd6e6c
...@@ -23,7 +23,7 @@ from ..utils.develop import log_deprecated ...@@ -23,7 +23,7 @@ from ..utils.develop import log_deprecated
from ..callbacks.base import Callback from ..callbacks.base import Callback
from ..callbacks.graph import RunOp from ..callbacks.graph import RunOp
__all__ = ['PlaceholderInput', 'FeedInput', 'DataParallelFeedInput', __all__ = ['PlaceholderInput', 'FeedInput',
'FeedfreeInput', 'FeedfreeInput',
'QueueInput', 'BatchQueueInput', 'QueueInput', 'BatchQueueInput',
'DummyConstantInput', 'TensorInput', 'DummyConstantInput', 'TensorInput',
...@@ -99,64 +99,6 @@ class FeedInput(InputSource): ...@@ -99,64 +99,6 @@ class FeedInput(InputSource):
return [self._cb] return [self._cb]
class DataParallelFeedInput(FeedInput):
"""
Input by feeding k datapoints to k copies of placeholders located on k towers.
"""
class _DataParallelFeedCallback(Callback):
def __init__(self, ds, placeholders_per_tower):
self._ds = ds
self._itr = self._ds.get_data()
self._placehdrs_per_tower = placeholders_per_tower
self._nr_tower = len(self._placehdrs_per_tower)
def _reset(self):
self._ds.reset_state()
self._itr = self._ds.get_data()
def _before_run(self, _):
cnt = self._nr_tower
feed = {}
for t in range(cnt):
dp = next(self._itr)
f = dict(zip(self._placehdrs_per_tower[t], dp))
feed.update(f)
return tf.train.SessionRunArgs(fetches=[], feed_dict=feed)
def __init__(self, ds, tower_names):
super(DataParallelFeedInput, self).__init__(ds)
self._tower_names = tower_names
self._nr_tower = len(tower_names)
def _setup(self, inputs):
self._placehdrs_per_tower = []
for tname in self._tower_names:
# build a list of placeholders for each tower
self._placehdrs_per_tower.append(
[v.build_placeholder(prefix=tname + '/') for v in inputs])
self._cb = self._DataParallelFeedCallback(self._iter_ds, self._placehdrs_per_tower)
def _get_input_tensors(self):
# return placeholders for each tower
ctx = get_current_tower_context()
return self._placehdrs_per_tower[ctx.index]
def next_feed(self, cnt=1):
"""
Args:
cnt: how many towers to feed to.
"""
cnt = int(cnt)
assert cnt < self._nr_tower
feed = {}
for t in range(cnt):
dp = next(self._cb._itr)
f = dict(zip(self._placehdrs_per_tower[t], dp))
feed.update(f)
return feed
class FeedfreeInput(InputSource): class FeedfreeInput(InputSource):
""" Abstract base for input without feed, """ Abstract base for input without feed,
e.g. by queue or other operations. """ e.g. by queue or other operations. """
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment