Commit e46e6bca authored by Yuxin Wu

make DataParallelFeedInput runnable again

parent d9b96535
@@ -175,16 +175,18 @@ class FeedfreeInferenceRunner(InferenceRunnerBase):
         return InferencerToHook(inf, ret)


-# TODO completely broken now!
+# TODO some scripts to test
 class DataParallelInferenceRunner(InferenceRunnerBase):
     """
-    Broken. Don't use.
+    Inference by feeding datapoints in a data-parallel way to multiple GPUs.
+    Doesn't support remapped InputSource for now.
     """
+
     def __init__(self, input, infs, gpus):
         """
         Args:
             input (DataParallelFeedInput or DataFlow)
             gpus (list[int]): list of GPU id
         """
         if isinstance(input, DataFlow):
             tower_names = [TowerContext.get_predict_tower_name(k) for k in range(len(gpus))]
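For context, a hypothetical usage sketch of the constructor shown above (not part of this commit; the import paths and the FakeData shapes are assumptions):

```python
from tensorpack.callbacks import DataParallelInferenceRunner, ScalarStats
from tensorpack.dataflow import FakeData

df = FakeData([[4, 3]], size=100)       # dummy DataFlow with one component
runner = DataParallelInferenceRunner(
    df,                                 # a plain DataFlow gets wrapped per tower
    [ScalarStats('cost')],              # 'cost' is a made-up tensor name
    gpus=[0, 1])                        # one inference tower per listed GPU
```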
@@ -197,8 +199,6 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
     def _setup_graph(self):
         model = self.trainer.model
         self._input_source.setup(model.get_inputs_desc())
-        assert len(self._input_source.get_callbacks()) == 0, \
-            "DataParallelInferenceRunner doesn't support any InputSource which requires callbacks!"

         # build graph
         def build_tower(k):
@@ -214,6 +214,8 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
         # setup feeds and hooks
         self._hooks_parallel = [self._build_hook_parallel(inf) for inf in self.infs]
         self._hooks = [self._build_hook(inf) for inf in self.infs]
+        cbs = self._input_source.get_callbacks()
+        self._hooks_parallel.extend([CallbackToHook(cb) for cb in cbs])

     def _duplicate_names_across_towers(self, names):
         ret = []
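CallbackToHook above adapts an InputSource callback into a session hook. A rough standalone sketch of that adapter pattern (the class name and details here are assumptions, not tensorpack's actual implementation):

```python
import tensorflow as tf

class CallbackAsHook(tf.train.SessionRunHook):
    """Drive a tensorpack-style callback from a tf.train.MonitoredSession."""
    def __init__(self, cb):
        self._cb = cb

    def before_run(self, run_context):
        # the callback may return SessionRunArgs carrying a feed_dict,
        # which is how _DataParallelFeedCallback (below) injects its feeds
        return self._cb._before_run(run_context)
```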
@@ -262,15 +264,19 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
         nr_tower = len(self._gpus)
         with tqdm.tqdm(total=total, **get_tqdm_kwargs()) as pbar:
             while total >= nr_tower:
-                feed = self._input_source.next_feed()
-                self._parallel_hooked_sess.run(fetches=[], feed_dict=feed)
+                self._parallel_hooked_sess.run(fetches=[])
                 pbar.update(nr_tower)
                 total -= nr_tower
             # take care of the rest
-            while total > 0:
-                # TODO XXX doesn't support remap
-                feed = self._input_source._next_feed(cnt=1)
-                self._hooked_sess.run(fetches=[], feed_dict=feed)
-                pbar.update(1)
-                total -= 1
+            try:
+                while total > 0:
+                    # TODO XXX doesn't support remap
+                    feed = self._input_source.next_feed(cnt=1)
+                    self._hooked_sess.run(fetches=[], feed_dict=feed)
+                    pbar.update(1)
+                    total -= 1
+            except AttributeError:
+                logger.error(
+                    "[DataParallelInferenceRunner] doesn't support InputSource wrappers very well!")
+                logger.error("[DataParallelInferenceRunner] Skipping the rest of the datapoints ...")
         summary_inferencer(self.trainer, self.infs)
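The control flow of `_run` above, as a framework-free sketch (`run_parallel`/`run_single` are hypothetical stand-ins for the two hooked sessions): datapoints are consumed nr_tower at a time, and the remainder goes through a single tower with an explicit feed.

```python
def drain(total, nr_tower, run_parallel, run_single):
    # main loop: one step feeds all towers at once (feeding is done by
    # the InputSource callbacks attached to the parallel session)
    while total >= nr_tower:
        run_parallel()
        total -= nr_tower
    # remainder: fewer datapoints than towers are left, so feed them
    # one at a time through a single tower
    while total > 0:
        run_single()
        total -= 1
```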
@@ -27,7 +27,7 @@ class MovingAverageSummary(Callback):

     def _setup_graph(self):
         ops = tf.get_collection(self._collection)
-        logger.info("Maintain moving averages of {} ops.".format(len(ops)))
+        logger.info("Maintain moving averages of {} tensors.".format(len(ops)))
         self.ema_op = tf.group(*ops, name='summary_moving_averages')
         self._fetch = tf.train.SessionRunArgs(fetches=self.ema_op)

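For reference, a standalone TF1-style sketch of the mechanism MovingAverageSummary relies on (the collection name here is made up): EMA update ops are gathered from a collection, grouped into one op, and fetched on every step.

```python
import tensorflow as tf

COLLECTION = 'MOVING_SUMMARY_OPS'               # assumed name, for illustration

err = tf.Variable(0.0, name='train_error', trainable=False)
ema = tf.train.ExponentialMovingAverage(decay=0.95)
tf.add_to_collection(COLLECTION, ema.apply([err]))   # op that updates the average

ops = tf.get_collection(COLLECTION)
ema_op = tf.group(*ops, name='summary_moving_averages')
fetch = tf.train.SessionRunArgs(fetches=ema_op)      # requested on every run()
```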
@@ -132,6 +132,7 @@ class FeedInput(InputSource):

         def _reset(self):
             self._ds.reset_state()
             self._itr = self._ds.get_data()
+
     def __init__(self, ds):
         """
......@@ -160,11 +161,31 @@ class FeedInput(InputSource):
return [self._cb]
# TODO completely broken now!
class DataParallelFeedInput(FeedInput):
"""
Input by feeding k datapoints to k copies of placeholders located on k towers.
"""
class _DataParallelFeedCallback(Callback):
def __init__(self, ds, placeholders_per_tower):
self._ds = ds
self._itr = self._ds.get_data()
self._placehdrs_per_tower = placeholders_per_tower
self._nr_tower = len(self._placehdrs_per_tower)
def _reset(self):
self._ds.reset_state()
self._itr = self._ds.get_data()
def _before_run(self, _):
cnt = self._nr_tower
feed = {}
for t in range(cnt):
dp = next(self._itr)
f = dict(zip(self._placehdrs_per_tower[t], dp))
feed.update(f)
return tf.train.SessionRunArgs(fetches=[], feed_dict=feed)
def __init__(self, ds, tower_names):
super(DataParallelFeedInput, self).__init__(ds)
self._tower_names = tower_names
@@ -176,6 +197,7 @@ class DataParallelFeedInput(FeedInput):
             # build a list of placeholders for each tower
             self._placehdrs_per_tower.append(
                 [v.build_placeholder(prefix=tname + '/') for v in inputs])
+        self._cb = self._DataParallelFeedCallback(self._repeat_ds, self._placehdrs_per_tower)
         self.reset_state()

     def _get_input_tensors(self):
@@ -183,16 +205,16 @@ class DataParallelFeedInput(FeedInput):
         ctx = get_current_tower_context()
         return self._placehdrs_per_tower[ctx.index]

-    def next_feed(self, cnt=None):
+    def next_feed(self, cnt=1):
         """
         Args:
-            cnt: how many towers to feed to. Defaults to the total number of towers
+            cnt: how many towers to feed to.
         """
-        if cnt is None:
-            cnt = self._nr_tower
+        cnt = int(cnt)
+        assert cnt < self._nr_tower
         feed = {}
         for t in range(cnt):
-            dp = next(self.data_producer)
+            dp = next(self._cb._itr)
             f = dict(zip(self._placehdrs_per_tower[t], dp))
             feed.update(f)
         return feed
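The feed construction shared by `_DataParallelFeedCallback._before_run` and `next_feed()`, as a self-contained sketch: tower t's placeholders are zipped with the t-th datapoint from the iterator, and everything is merged into one feed_dict so a single session.run() feeds all towers.

```python
def build_feed(placehdrs_per_tower, itr, cnt):
    feed = {}
    for t in range(cnt):
        dp = next(itr)   # one datapoint (a list of arrays) per tower
        feed.update(dict(zip(placehdrs_per_tower[t], dp)))
    return feed

# e.g. two towers, each holding (image, label) placeholders:
#   feed = build_feed([[img0, lbl0], [img1, lbl1]], ds.get_data(), cnt=2)
```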