Commit d9b96535 authored by Yuxin Wu

Use hooks for feed_dict in FeedInput

parent 1e7fa5f9
......@@ -44,7 +44,9 @@ class Callback(object):
self._steps_per_epoch = trainer.config.steps_per_epoch
self.trainer = trainer
self.graph = tf.get_default_graph()
with tf.name_scope(type(self).__name__):
scope_name = type(self).__name__
scope_name = scope_name.replace('_', '')
with tf.name_scope(scope_name):
self._setup_graph()
def _setup_graph(self):
......
......@@ -22,6 +22,7 @@ from ..predict import PredictorTowerBuilder
from .base import Callback
from .inference import Inferencer
from .hooks import CallbackToHook
__all__ = ['InferenceRunner', 'FeedfreeInferenceRunner',
'DataParallelInferenceRunner']
......@@ -85,8 +86,6 @@ class InferenceRunnerBase(Callback):
def _setup_graph(self):
self._input_source.setup(self.trainer.model.get_inputs_desc())
assert len(self._input_source.get_callbacks()) == 0, \
"InferenceRunner doesn't support any InputSource which requires callbacks!"
# Use predict_tower in train config. either gpuid or -1
self._predict_tower_id = self.trainer.config.predict_tower[0]
......@@ -97,6 +96,8 @@ class InferenceRunnerBase(Callback):
PredictorTowerBuilder(fn, self._prefix).build(self._predict_tower_id)
self._hooks = [self._build_hook(inf) for inf in self.infs]
cbs = self._input_source.get_callbacks()
self._hooks.extend([CallbackToHook(cb) for cb in cbs])
def _before_train(self):
    # Append user-supplied extra hooks after the hooks built in
    # _setup_graph (inferencer hooks + input-source callback hooks).
    self._hooks.extend(self._extra_hooks)
......@@ -118,8 +119,7 @@ class InferenceRunnerBase(Callback):
# iterate over the data, and run the hooked session
self._input_source.reset_state()
for _ in tqdm.trange(self._size, **get_tqdm_kwargs()):
feed = self._input_source.next_feed()
self._hooked_sess.run(fetches=[], feed_dict=feed)
self._hooked_sess.run(fetches=[])
summary_inferencer(self.trainer, self.infs)
......@@ -170,19 +170,17 @@ class FeedfreeInferenceRunner(InferenceRunnerBase):
placeholder_names = [k.name + ':0' for k in self.trainer.model.get_inputs_desc()]
ret = []
for name in out_names:
if name not in placeholder_names:
assert name not in placeholder_names, "Currently inferencer don't support fetching placeholders!"
ret.append(self._get_tensors_maybe_in_tower([name])[0])
else: # requesting an input
idx = placeholder_names.index(name)
ret.append(self._input_tensors[idx])
return InferencerToHook(inf, ret)
# TODO completely broken now!
# TODO some scripts to test
class DataParallelInferenceRunner(InferenceRunnerBase):
"""
Not tested. Don't use.
Broken. Don't use.
"""
# TODO some scripts to test
def __init__(self, input, infs, gpus):
"""
Args:
......@@ -200,7 +198,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
model = self.trainer.model
self._input_source.setup(model.get_inputs_desc())
assert len(self._input_source.get_callbacks()) == 0, \
"InferenceRunner doesn't support any InputSource which requires callbacks!"
"DataParallelInferenceRunner doesn't support any InputSource which requires callbacks!"
# build graph
def build_tower(k):
......
......@@ -80,17 +80,6 @@ class InputSource(object):
def _reset_state(self):
pass
def next_feed(self):
    """
    Returns:
        a feed_dict of {Tensor: data}, to be used to run the steps
    """
    feed = self._next_feed()
    return feed
@abstractmethod
def _next_feed(self):
    # Subclass hook for next_feed(): produce the next feed_dict.
    pass
def size(self):
"""
Returns:
......@@ -122,15 +111,28 @@ class ProxyInputSource(InputSource):
def _size(self):
    """Delegate the size query to the wrapped InputSource."""
    sz = self._input.size()
    return sz
def _next_feed(self):
    """Delegate feed generation to the wrapped InputSource."""
    feed = self._input.next_feed()
    return feed
def _reset_state(self):
    """Forward the reset to the wrapped InputSource."""
    self._input.reset_state()
class FeedInput(InputSource):
""" Input by iterating over a DataFlow and feed datapoints. """
class _FeedCallback(Callback):
    """
    A callback which feeds one datapoint from a DataFlow into the
    given placeholders before every hooked-session run.
    """

    def __init__(self, ds, placeholders):
        """
        Args:
            ds: a DataFlow (presumably repeated indefinitely by the caller).
            placeholders: placeholder tensors, one per component of a datapoint.
        """
        self._ds = ds
        self._itr = self._ds.get_data()
        self._placeholders = placeholders

    def _before_run(self, _):
        # Pull the next datapoint and turn it into a feed_dict for this run.
        dp = next(self._itr)
        assert len(dp) == len(self._placeholders), "[FeedInput] datapoints and inputs are of different length!"
        feed = dict(zip(self._placeholders, dp))
        return tf.train.SessionRunArgs(fetches=[], feed_dict=feed)

    def _reset(self):
        self._ds.reset_state()
        # Bug fix: the old iterator was created from the pre-reset state of
        # the DataFlow, so it is stale after reset_state(). Re-create it so
        # subsequent _before_run calls draw from the freshly-reset DataFlow.
        self._itr = self._ds.get_data()
def __init__(self, ds):
"""
Args:
......@@ -138,28 +140,27 @@ class FeedInput(InputSource):
"""
assert isinstance(ds, DataFlow), ds
self.ds = ds
self._repeat_ds = RepeatedData(self.ds, -1)
def _size(self):
    """Size of the underlying (un-repeated) DataFlow."""
    return self.ds.size()
def _setup(self, inputs):
    """Build one placeholder per input spec, create the feeding callback,
    and reset the input state."""
    placehdrs = []
    for v in inputs:
        placehdrs.append(v.build_placeholder_reuse())
    self._all_placehdrs = placehdrs
    self._cb = self._FeedCallback(self._repeat_ds, self._all_placehdrs)
    self.reset_state()
def _reset_state(self):
rds = RepeatedData(self.ds, -1)
rds.reset_state()
self.data_producer = rds.get_data()
self._cb._reset()
def _get_input_tensors(self):
    # The placeholders built in _setup() are themselves the input tensors.
    return self._all_placehdrs
def _next_feed(self):
    """Map the next datapoint onto the placeholders as a feed_dict."""
    dp = next(self.data_producer)
    assert len(dp) == len(self._all_placehdrs), "[FeedInput] datapoints and inputs are of different length!"
    return {p: v for p, v in zip(self._all_placehdrs, dp)}
def _get_callbacks(self):
    # Expose the feeding callback so the trainer registers it.
    return [self._cb]
# TODO completely broken now!
class DataParallelFeedInput(FeedInput):
"""
Input by feeding k datapoints to k copies of placeholders located on k towers.
......@@ -182,7 +183,7 @@ class DataParallelFeedInput(FeedInput):
ctx = get_current_tower_context()
return self._placehdrs_per_tower[ctx.index]
def _next_feed(self, cnt=None):
def next_feed(self, cnt=None):
"""
Args:
cnt: how many towers to feed to. Defaults to the total number of towers
......@@ -204,9 +205,6 @@ class FeedfreeInput(InputSource):
def _reset_state(self):
pass
def _next_feed(self):
return {}
# TODO enqueu_many? https://github.com/tensorflow/tensorflow/issues/7817#issuecomment-282053155
class EnqueueThread(ShareSessionThread):
......
......@@ -31,14 +31,14 @@ class SimpleTrainer(Trainer):
def run_step(self):
""" Feed data into the graph and run the updates. """
feed = self._input_source.next_feed()
self.hooked_sess.run(self.train_op, feed_dict=feed)
self.hooked_sess.run(self.train_op)
def _setup(self):
model = self.model
self._input_source.setup(model.get_inputs_desc())
cbs = self._input_source.get_callbacks()
assert len(cbs) == 0, "Feedinput has no callbacks!"
for cb in cbs:
self.register_callback(cb)
self.inputs = self._input_source.get_input_tensors()
with TowerContext('', is_training=True):
model.build_graph(self.inputs)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment