Commit d9b96535 authored by Yuxin Wu

Use hooks for feed_dict in FeedInput

parent 1e7fa5f9
...@@ -44,7 +44,9 @@ class Callback(object): ...@@ -44,7 +44,9 @@ class Callback(object):
self._steps_per_epoch = trainer.config.steps_per_epoch self._steps_per_epoch = trainer.config.steps_per_epoch
self.trainer = trainer self.trainer = trainer
self.graph = tf.get_default_graph() self.graph = tf.get_default_graph()
with tf.name_scope(type(self).__name__): scope_name = type(self).__name__
scope_name = scope_name.replace('_', '')
with tf.name_scope(scope_name):
self._setup_graph() self._setup_graph()
def _setup_graph(self): def _setup_graph(self):
......
...@@ -22,6 +22,7 @@ from ..predict import PredictorTowerBuilder ...@@ -22,6 +22,7 @@ from ..predict import PredictorTowerBuilder
from .base import Callback from .base import Callback
from .inference import Inferencer from .inference import Inferencer
from .hooks import CallbackToHook
__all__ = ['InferenceRunner', 'FeedfreeInferenceRunner', __all__ = ['InferenceRunner', 'FeedfreeInferenceRunner',
'DataParallelInferenceRunner'] 'DataParallelInferenceRunner']
...@@ -85,8 +86,6 @@ class InferenceRunnerBase(Callback): ...@@ -85,8 +86,6 @@ class InferenceRunnerBase(Callback):
def _setup_graph(self): def _setup_graph(self):
self._input_source.setup(self.trainer.model.get_inputs_desc()) self._input_source.setup(self.trainer.model.get_inputs_desc())
assert len(self._input_source.get_callbacks()) == 0, \
"InferenceRunner doesn't support any InputSource which requires callbacks!"
# Use predict_tower in train config. either gpuid or -1 # Use predict_tower in train config. either gpuid or -1
self._predict_tower_id = self.trainer.config.predict_tower[0] self._predict_tower_id = self.trainer.config.predict_tower[0]
...@@ -97,6 +96,8 @@ class InferenceRunnerBase(Callback): ...@@ -97,6 +96,8 @@ class InferenceRunnerBase(Callback):
PredictorTowerBuilder(fn, self._prefix).build(self._predict_tower_id) PredictorTowerBuilder(fn, self._prefix).build(self._predict_tower_id)
self._hooks = [self._build_hook(inf) for inf in self.infs] self._hooks = [self._build_hook(inf) for inf in self.infs]
cbs = self._input_source.get_callbacks()
self._hooks.extend([CallbackToHook(cb) for cb in cbs])
def _before_train(self): def _before_train(self):
self._hooks.extend(self._extra_hooks) self._hooks.extend(self._extra_hooks)
...@@ -118,8 +119,7 @@ class InferenceRunnerBase(Callback): ...@@ -118,8 +119,7 @@ class InferenceRunnerBase(Callback):
# iterate over the data, and run the hooked session # iterate over the data, and run the hooked session
self._input_source.reset_state() self._input_source.reset_state()
for _ in tqdm.trange(self._size, **get_tqdm_kwargs()): for _ in tqdm.trange(self._size, **get_tqdm_kwargs()):
feed = self._input_source.next_feed() self._hooked_sess.run(fetches=[])
self._hooked_sess.run(fetches=[], feed_dict=feed)
summary_inferencer(self.trainer, self.infs) summary_inferencer(self.trainer, self.infs)
...@@ -170,19 +170,17 @@ class FeedfreeInferenceRunner(InferenceRunnerBase): ...@@ -170,19 +170,17 @@ class FeedfreeInferenceRunner(InferenceRunnerBase):
placeholder_names = [k.name + ':0' for k in self.trainer.model.get_inputs_desc()] placeholder_names = [k.name + ':0' for k in self.trainer.model.get_inputs_desc()]
ret = [] ret = []
for name in out_names: for name in out_names:
if name not in placeholder_names: assert name not in placeholder_names, "Currently inferencer don't support fetching placeholders!"
ret.append(self._get_tensors_maybe_in_tower([name])[0]) ret.append(self._get_tensors_maybe_in_tower([name])[0])
else: # requesting an input
idx = placeholder_names.index(name)
ret.append(self._input_tensors[idx])
return InferencerToHook(inf, ret) return InferencerToHook(inf, ret)
# TODO completely broken now!
# TODO some scripts to test
class DataParallelInferenceRunner(InferenceRunnerBase): class DataParallelInferenceRunner(InferenceRunnerBase):
""" """
Not tested. Don't use. Broken. Don't use.
""" """
# TODO some scripts to test
def __init__(self, input, infs, gpus): def __init__(self, input, infs, gpus):
""" """
Args: Args:
...@@ -200,7 +198,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase): ...@@ -200,7 +198,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
model = self.trainer.model model = self.trainer.model
self._input_source.setup(model.get_inputs_desc()) self._input_source.setup(model.get_inputs_desc())
assert len(self._input_source.get_callbacks()) == 0, \ assert len(self._input_source.get_callbacks()) == 0, \
"InferenceRunner doesn't support any InputSource which requires callbacks!" "DataParallelInferenceRunner doesn't support any InputSource which requires callbacks!"
# build graph # build graph
def build_tower(k): def build_tower(k):
......
...@@ -80,17 +80,6 @@ class InputSource(object): ...@@ -80,17 +80,6 @@ class InputSource(object):
def _reset_state(self): def _reset_state(self):
pass pass
def next_feed(self):
"""
Returns:
a feed_dict of {Tensor: data}, to be used to run the steps
"""
return self._next_feed()
@abstractmethod
def _next_feed(self):
pass
def size(self): def size(self):
""" """
Returns: Returns:
...@@ -122,15 +111,28 @@ class ProxyInputSource(InputSource): ...@@ -122,15 +111,28 @@ class ProxyInputSource(InputSource):
def _size(self): def _size(self):
return self._input.size() return self._input.size()
def _next_feed(self):
return self._input.next_feed()
def _reset_state(self): def _reset_state(self):
self._input.reset_state() self._input.reset_state()
class FeedInput(InputSource): class FeedInput(InputSource):
""" Input by iterating over a DataFlow and feed datapoints. """ """ Input by iterating over a DataFlow and feed datapoints. """
class _FeedCallback(Callback):
    """Feed one datapoint from a DataFlow into placeholders before every run.

    Args:
        ds: the DataFlow to draw datapoints from. It is only used through
            ``get_data()`` and ``reset_state()``.
        placeholders: the placeholders to feed, in the same order as the
            components of each datapoint.
    """

    def __init__(self, ds, placeholders):
        self._ds = ds
        self._placeholders = placeholders
        # Kept from the original so the callback can still be used when
        # _reset() is never called; _reset() replaces this iterator.
        self._itr = self._ds.get_data()

    def _before_run(self, _):
        """Return run args whose feed_dict maps each placeholder to the next datapoint's component."""
        dp = next(self._itr)
        assert len(dp) == len(self._placeholders), "[FeedInput] datapoints and inputs are of different length!"
        feed = dict(zip(self._placeholders, dp))
        # Nothing extra is fetched; the hook only contributes the feed_dict.
        return tf.train.SessionRunArgs(fetches=[], feed_dict=feed)

    def _reset(self):
        """Reset the DataFlow and restart iteration from a fresh iterator."""
        self._ds.reset_state()
        # Re-create the iterator after resetting. Otherwise the iterator
        # obtained before the reset keeps being consumed, and a reset does
        # not actually restart the data stream.
        self._itr = self._ds.get_data()
def __init__(self, ds): def __init__(self, ds):
""" """
Args: Args:
...@@ -138,28 +140,27 @@ class FeedInput(InputSource): ...@@ -138,28 +140,27 @@ class FeedInput(InputSource):
""" """
assert isinstance(ds, DataFlow), ds assert isinstance(ds, DataFlow), ds
self.ds = ds self.ds = ds
self._repeat_ds = RepeatedData(self.ds, -1)
def _size(self): def _size(self):
return self.ds.size() return self.ds.size()
def _setup(self, inputs): def _setup(self, inputs):
self._all_placehdrs = [v.build_placeholder_reuse() for v in inputs] self._all_placehdrs = [v.build_placeholder_reuse() for v in inputs]
self._cb = self._FeedCallback(self._repeat_ds, self._all_placehdrs)
self.reset_state() self.reset_state()
def _reset_state(self): def _reset_state(self):
rds = RepeatedData(self.ds, -1) self._cb._reset()
rds.reset_state()
self.data_producer = rds.get_data()
def _get_input_tensors(self): def _get_input_tensors(self):
return self._all_placehdrs return self._all_placehdrs
def _next_feed(self): def _get_callbacks(self):
dp = next(self.data_producer) return [self._cb]
assert len(dp) == len(self._all_placehdrs), "[FeedInput] datapoints and inputs are of different length!"
return dict(zip(self._all_placehdrs, dp))
# TODO completely broken now!
class DataParallelFeedInput(FeedInput): class DataParallelFeedInput(FeedInput):
""" """
Input by feeding k datapoints to k copies of placeholders located on k towers. Input by feeding k datapoints to k copies of placeholders located on k towers.
...@@ -182,7 +183,7 @@ class DataParallelFeedInput(FeedInput): ...@@ -182,7 +183,7 @@ class DataParallelFeedInput(FeedInput):
ctx = get_current_tower_context() ctx = get_current_tower_context()
return self._placehdrs_per_tower[ctx.index] return self._placehdrs_per_tower[ctx.index]
def _next_feed(self, cnt=None): def next_feed(self, cnt=None):
""" """
Args: Args:
cnt: how many towers to feed to. Defaults to the total number of towers cnt: how many towers to feed to. Defaults to the total number of towers
...@@ -204,9 +205,6 @@ class FeedfreeInput(InputSource): ...@@ -204,9 +205,6 @@ class FeedfreeInput(InputSource):
def _reset_state(self): def _reset_state(self):
pass pass
def _next_feed(self):
return {}
# TODO enqueu_many? https://github.com/tensorflow/tensorflow/issues/7817#issuecomment-282053155 # TODO enqueu_many? https://github.com/tensorflow/tensorflow/issues/7817#issuecomment-282053155
class EnqueueThread(ShareSessionThread): class EnqueueThread(ShareSessionThread):
......
...@@ -31,14 +31,14 @@ class SimpleTrainer(Trainer): ...@@ -31,14 +31,14 @@ class SimpleTrainer(Trainer):
def run_step(self): def run_step(self):
""" Feed data into the graph and run the updates. """ """ Feed data into the graph and run the updates. """
feed = self._input_source.next_feed() self.hooked_sess.run(self.train_op)
self.hooked_sess.run(self.train_op, feed_dict=feed)
def _setup(self): def _setup(self):
model = self.model model = self.model
self._input_source.setup(model.get_inputs_desc()) self._input_source.setup(model.get_inputs_desc())
cbs = self._input_source.get_callbacks() cbs = self._input_source.get_callbacks()
assert len(cbs) == 0, "Feedinput has no callbacks!" for cb in cbs:
self.register_callback(cb)
self.inputs = self._input_source.get_input_tensors() self.inputs = self._input_source.get_input_tensors()
with TowerContext('', is_training=True): with TowerContext('', is_training=True):
model.build_graph(self.inputs) model.build_graph(self.inputs)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment