Commit 9fff46d5 authored by Yuxin Wu's avatar Yuxin Wu

remove DataParallelFeedInput, as dataparallel inference now relies on QueueInput

parent 12cd6e6c
......@@ -23,7 +23,7 @@ from ..utils.develop import log_deprecated
from ..callbacks.base import Callback
from ..callbacks.graph import RunOp
__all__ = ['PlaceholderInput', 'FeedInput', 'DataParallelFeedInput',
__all__ = ['PlaceholderInput', 'FeedInput',
'FeedfreeInput',
'QueueInput', 'BatchQueueInput',
'DummyConstantInput', 'TensorInput',
......@@ -99,64 +99,6 @@ class FeedInput(InputSource):
return [self._cb]
class DataParallelFeedInput(FeedInput):
"""
Input by feeding k datapoints to k copies of placeholders located on k towers.
"""
class _DataParallelFeedCallback(Callback):
def __init__(self, ds, placeholders_per_tower):
self._ds = ds
self._itr = self._ds.get_data()
self._placehdrs_per_tower = placeholders_per_tower
self._nr_tower = len(self._placehdrs_per_tower)
def _reset(self):
self._ds.reset_state()
self._itr = self._ds.get_data()
def _before_run(self, _):
cnt = self._nr_tower
feed = {}
for t in range(cnt):
dp = next(self._itr)
f = dict(zip(self._placehdrs_per_tower[t], dp))
feed.update(f)
return tf.train.SessionRunArgs(fetches=[], feed_dict=feed)
def __init__(self, ds, tower_names):
super(DataParallelFeedInput, self).__init__(ds)
self._tower_names = tower_names
self._nr_tower = len(tower_names)
def _setup(self, inputs):
self._placehdrs_per_tower = []
for tname in self._tower_names:
# build a list of placeholders for each tower
self._placehdrs_per_tower.append(
[v.build_placeholder(prefix=tname + '/') for v in inputs])
self._cb = self._DataParallelFeedCallback(self._iter_ds, self._placehdrs_per_tower)
def _get_input_tensors(self):
# return placeholders for each tower
ctx = get_current_tower_context()
return self._placehdrs_per_tower[ctx.index]
def next_feed(self, cnt=1):
"""
Args:
cnt: how many towers to feed to.
"""
cnt = int(cnt)
assert cnt < self._nr_tower
feed = {}
for t in range(cnt):
dp = next(self._cb._itr)
f = dict(zip(self._placehdrs_per_tower[t], dp))
feed.update(f)
return feed
class FeedfreeInput(InputSource):
""" Abstract base for input without feed,
e.g. by queue or other operations. """
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment