Commit 76fa8e38 authored by Yuxin Wu

Simplify inference_runner: 1. move input_names mapping to InputSource 2. add DataParallelFeedInput

parent 48f6c267
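The central interface change: `InputSource.next_feed()` now returns a ready-to-use feed_dict (placeholder -> value) instead of a raw datapoint, with the `input_names`-to-placeholder mapping done inside the source (see `FeedInput.setup`/`next_feed` below). A minimal, framework-free sketch of that contract — the `Mock*` classes and the sample data are illustrative stand-ins, not the classes touched by this commit:

from itertools import cycle


class MockPlaceholder(object):
    """Stand-in for a tf.placeholder; only its name matters for the mapping."""
    def __init__(self, name):
        self.name = name + ':0'


class MockFeedInput(object):
    """Mimics FeedInput after this commit: the source itself resolves
    input_names to placeholders, and next_feed() returns a feed_dict."""
    def __init__(self, datapoints, all_placehdrs, input_names=None):
        self._data = cycle(datapoints)  # plays the role of RepeatedData(ds, -1)
        if input_names is None:
            self._placehdrs_to_feed = all_placehdrs
        else:
            by_name = {p.name: p for p in all_placehdrs}
            self._placehdrs_to_feed = [by_name[n + ':0'] for n in input_names]

    def next_feed(self):
        dp = next(self._data)
        return dict(zip(self._placehdrs_to_feed, dp))


image, label = MockPlaceholder('image'), MockPlaceholder('label')
# Feed only 'image'; callers just pass the returned dict as feed_dict.
inp = MockFeedInput([[0.1], [0.2]], [image, label], input_names=['image'])
print(inp.next_feed())  # {<'image:0' placeholder>: 0.1}

With this contract, `SimpleTrainer.run_step` and `InferenceRunnerBase._trigger` in the diff below reduce to passing `self._input_source.next_feed()` straight to `feed_dict`.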
@@ -10,14 +10,14 @@ from tensorflow.python.training.monitored_session \
 from abc import ABCMeta, abstractmethod
 import tqdm
 import six
-import copy
-from six.moves import zip
+from six.moves import range
 from ..utils import logger, get_tqdm_kwargs
 from ..dataflow import DataFlow
-from ..tfutils.common import get_op_tensor_name, get_tensors_by_names
+from ..tfutils.common import get_tensors_by_names
 from ..tfutils.tower import TowerContext
-from ..train.input_source import TensorInput, FeedInput
+from ..train.input_source import TensorInput, FeedInput, DataParallelFeedInput
+from ..train.utils import get_tensors_inputs
 from ..predict import PredictorTowerBuilder
 from .base import Callback
@@ -60,11 +60,12 @@ class InferenceRunnerBase(Callback):
         """
         Args:
             input (InputSource): the input to use. Must have ``size()``.
-            infs (list): list of :class:`Inferencer` to run.
-            input_names (list): must be a subset of the names in InputDesc.
+            infs (list[Inferencer]): list of :class:`Inferencer` to run.
+            input_names (list[str]): list of names to match ``input``, must be a subset of the names in
+                InputDesc of the model. Defaults to be all the inputs of the model.
             prefix(str): an prefix used to build the tower. Must be set
                 differently if more than one :class:`InferenceRunner` are used.
-            extra_hooks (list): extra ``SessionRunHook`` to run with the evaluation.
+            extra_hooks (list[SessionRunHook]): extra :class:`SessionRunHook` to run with the evaluation.
         """
         self._input_source = input
         if not isinstance(infs, list):
@@ -87,33 +88,17 @@ class InferenceRunnerBase(Callback):
             extra_hooks = []
         self._extra_hooks = extra_hooks
-    def _setup_input_names(self):
-        # just use all the placeholders, if input_name is None
-        if self.input_names is None:
-            inputs = self.trainer.model.get_reused_placehdrs()
-            self.input_names = [x.name for x in inputs]
-        # TODO sparse. even if it works here, sparse still is unavailable
-        # because get_tensor_by_name doesn't work for sparse
-        # def get_name(x):
-        #     if isinstance(x, tf.SparseTensor):
-        #         return x.op.name.split('/')[0]
-        #     return x.name
     def _setup_graph(self):
         self._input_source.setup(self.trainer.model)
-        self._setup_input_names()
         # Use predict_tower in train config. either gpuid or -1
         self._predict_tower_id = self.trainer.config.predict_tower[0]
+        in_tensors = self._find_input_tensors()
+        assert isinstance(in_tensors, list), in_tensors
         def fn(_):
-            in_tensors = self._find_input_tensors()
-            assert isinstance(in_tensors, list), in_tensors
             self.trainer.model.build_graph(in_tensors)
         PredictorTowerBuilder(fn, self._prefix).build(self._predict_tower_id)
-        self._feed_tensors = self._find_feed_tensors()
         self._hooks = [self._build_hook(inf) for inf in self.infs]
     def _before_train(self):
@@ -128,10 +113,6 @@ class InferenceRunnerBase(Callback):
     def _find_input_tensors(self):
         pass
-    @abstractmethod
-    def _find_feed_tensors(self):
-        pass
     @abstractmethod
     def _build_hook(self, inf):
         pass
@@ -143,8 +124,7 @@ class InferenceRunnerBase(Callback):
         # iterate over the data, and run the hooked session
         self._input_source.reset_state()
         for _ in tqdm.trange(self._input_source.size(), **get_tqdm_kwargs()):
-            dp = self._input_source.next_feed()
-            feed = dict(zip(self._feed_tensors, dp))
+            feed = self._input_source.next_feed()
             self._hooked_sess.run(fetches=[], feed_dict=feed)
         summary_inferencer(self.trainer, self.infs)
@@ -160,19 +140,15 @@ class InferenceRunner(InferenceRunnerBase):
         Args:
             ds (DataFlow): the DataFlow to run inferencer on.
             infs (list): a list of `Inferencer` instances.
-            input_names(list): list of tensors to feed the dataflow to.
-                Defaults to all the input placeholders.
+            input_names (list[str]): same as in :class:`InferenceRunnerBase`.
         """
         assert isinstance(ds, DataFlow), ds
-        input = FeedInput(ds)
+        input = FeedInput(ds, input_names)
         super(InferenceRunner, self).__init__(
             input, infs, input_names, prefix='', extra_hooks=extra_hooks)
     def _find_input_tensors(self):
-        return self.trainer.model.get_reused_placehdrs()
-    def _find_feed_tensors(self):
-        return self._get_tensors_maybe_in_tower(self.input_names)
+        return self._input_source.get_input_tensors()
     def _build_hook(self, inf):
         out_names = inf.get_output_tensors()
@@ -191,7 +167,7 @@ class FeedfreeInferenceRunner(InferenceRunnerBase):
         Args:
             input (TensorInput): the input to use. Must have ``size()``.
             infs (list): list of :class:`Inferencer` to run.
-            input_names (list): must be a subset of the names in InputDesc.
+            input_names (list[str]): same as in :class:`InferenceRunnerBase`.
             prefix(str): an prefix used to build the tower. Must be set
                 differently if more than one :class:`FeedfreeInferenceRunner` are used.
         """
@@ -199,36 +175,14 @@ class FeedfreeInferenceRunner(InferenceRunnerBase):
         super(FeedfreeInferenceRunner, self).__init__(
             input, infs, input_names, prefix=prefix, extra_hooks=extra_hooks)
-    def _setup_input_names(self):
-        super(FeedfreeInferenceRunner, self)._setup_input_names()
-        placeholder_names = set([k.name for k in self.trainer.model.get_inputs_desc()])
-        for n in self.input_names:
-            opname = get_op_tensor_name(n)[0]
-            assert opname in placeholder_names, \
-                "[FeedfreeInferenceRunner] name {} is not a model input!".format(n)
     def _find_input_tensors(self):
-        # TODO move mapping to InputSource
         tensors = self._input_source.get_input_tensors()
-        assert len(self.input_names) == len(tensors), \
-            "[FeedfreeInferenceRunner] Input names must match the " \
-            "length of the input data, but {} != {}".format(len(self.input_names), len(tensors))
-        # use placeholders for the unused inputs, use TensorInput for the used inpupts
-        ret = copy.copy(self.trainer.model.get_reused_placehdrs())
-        for name, tensor in zip(self.input_names, tensors):
-            tname = get_op_tensor_name(name)[1]
-            for idx, hdr in enumerate(ret):
-                if hdr.name == tname:
-                    ret[idx] = tensor
-                    break
-            else:
-                assert tname in set([k.name for k in ret]), \
-                    "Input name {} is not among model inputs: {}!".format(tname, ret)
-        self._input_tensors = ret
-        return ret
-    def _find_feed_tensors(self):
-        return []
+        placeholders = self.trainer.model.get_reused_placehdrs()
+        if self.input_names is None:
+            return tensors
+        else:
+            return get_tensors_inputs(placeholders, tensors, self.input_names)
     def _build_hook(self, inf):
         out_names = inf.get_output_tensors()  # all is tensorname
@@ -243,22 +197,24 @@ class FeedfreeInferenceRunner(InferenceRunnerBase):
         return InferencerToHook(inf, ret)
-class DataParallelInferenceRunner(InferenceRunner):
+class DataParallelInferenceRunner(InferenceRunnerBase):
     def __init__(self, ds, infs, gpus, input_names=None):
-        super(DataParallelInferenceRunner, self).__init__(ds, infs, input_names)
+        self._tower_names = [TowerContext.get_predict_tower_name(k)
+                             for k in range(len(gpus))]
+        input = DataParallelFeedInput(
+            ds, self._tower_names, input_names=input_names)
+        super(DataParallelInferenceRunner, self).__init__(
+            input, infs, input_names)
         self._gpus = gpus
     def _setup_graph(self):
         model = self.trainer.model
         self._input_source.setup(model)
-        self._setup_input_names()
         # build graph
         def build_tower(k):
-            towername = TowerContext.get_predict_tower_name(k)
             # inputs (placeholders) for this tower only
-            input_tensors = model.build_placeholders(
-                prefix=towername + '/')
+            input_tensors = self._input_source.get_input_tensors()
             model.build_graph(input_tensors)
         builder = PredictorTowerBuilder(build_tower, prefix=self._prefix)
@@ -267,7 +223,6 @@ class DataParallelInferenceRunner(InferenceRunner):
             builder.build(t)
         # setup feeds and hooks
-        self._feed_tensors = self._find_feed_tensors()
         self._hooks_parallel = [self._build_hook_parallel(inf) for inf in self.infs]
         self._hooks = [self._build_hook(inf) for inf in self.infs]
@@ -278,10 +233,6 @@ class DataParallelInferenceRunner(InferenceRunner):
            '/' + n for n in names])
         return ret
-    def _find_feed_tensors(self):
-        names = self._duplicate_names_across_towers(self.input_names)
-        return get_tensors_by_names(names)
     class InferencerToHookDataParallel(InferencerToHook):
         def __init__(self, inf, fetches, size):
             super(DataParallelInferenceRunner.InferencerToHookDataParallel, self).__init__(inf, fetches)
@@ -322,16 +273,13 @@ class DataParallelInferenceRunner(InferenceRunner):
         nr_tower = len(self._gpus)
         with tqdm.tqdm(total=total, **get_tqdm_kwargs()) as pbar:
             while total >= nr_tower:
-                dps = []
-                for k in self._gpus:
-                    dps.extend(self._input_source.next_feed())
-                feed = dict(zip(self._feed_tensors, dps))
+                feed = self._input_source.next_feed()
                 self._parallel_hooked_sess.run(fetches=[], feed_dict=feed)
                 pbar.update(nr_tower)
                 total -= nr_tower
             # take care of the rest
             while total > 0:
-                dp = self._input_source.next_feed()
-                feed = dict(zip(self._feed_tensors[:len(dp)], dp))
+                feed = self._input_source.next_feed(cnt=1)
                 self._hooked_sess.run(fetches=[], feed_dict=feed)
+                pbar.update(1)
         summary_inferencer(self.trainer, self.infs)
@@ -12,8 +12,9 @@ except ImportError:
 from itertools import chain
 from abc import ABCMeta, abstractmethod
 import six
-from six.moves import range
+from six.moves import range, zip
+from .utils import get_placeholders_by_names
 from ..dataflow import DataFlow, RepeatedData
 from ..tfutils.summary import add_moving_summary
 from ..tfutils import get_op_tensor_name
@@ -24,7 +25,7 @@ from ..utils.concurrency import ShareSessionThread
 from ..callbacks.concurrency import StartProcOrThread
 from ..callbacks.base import Callback
-__all__ = ['InputSource', 'FeedfreeInput',
+__all__ = ['InputSource', 'FeedfreeInput', 'DataParallelFeedInput',
            'QueueInput', 'BatchQueueInput',
            'ZMQInput',
            'DummyConstantInput', 'TensorInput', 'StagingInputWrapper']
@@ -38,8 +39,9 @@ class InputSource(object):
     def get_input_tensors(self):
         """
         Returns:
-            list: A list of tensors corresponding to the inputs of the model.
-                Always create and return a list of new input tensors when called.
+            list: A list of tensors corresponding to the inputs of the model,
+                used as input of :func:`build_graph`.
+                For non-placeholder tensors, should always create and return new tensors when called.
         """
     def setup(self, model):
@@ -53,27 +55,37 @@ class InputSource(object):
         pass
     def next_feed(self):
-        return []
+        """
+        Returns:
+            a feed_dict of {Tensor: data}, to be used to run the steps
+        """
+        return {}
 class FeedInput(InputSource):
     """ Input by iterating over a DataFlow and feed datapoints. """
-    def __init__(self, ds):
+    def __init__(self, ds, input_names=None):
         """
         Args:
             ds (DataFlow): the input DataFlow.
+            input_names (list[str]): input names this DataFlow maps to
         """
         assert isinstance(ds, DataFlow), ds
         self.ds = ds
+        self._input_names = input_names
     def size(self):
         return self.ds.size()
     def setup(self, model):
-        self.input_placehdrs = model.get_reused_placehdrs()
-        rds = RepeatedData(self.ds, -1)
-        rds.reset_state()
-        self.data_producer = rds.get_data()
+        self._all_placehdrs = model.get_reused_placehdrs()
+        if self._input_names is None:
+            self._placehdrs_to_feed = self._all_placehdrs
+        else:
+            self._placehdrs_to_feed = get_placeholders_by_names(
+                self._all_placehdrs, self._input_names)
+        self.reset_state()
     def reset_state(self):
         rds = RepeatedData(self.ds, -1)
@@ -81,10 +93,61 @@ class FeedInput(InputSource):
         self.data_producer = rds.get_data()
     def get_input_tensors(self):
-        return self.input_placehdrs
+        return self._all_placehdrs
     def next_feed(self):
-        return next(self.data_producer)
+        dp = next(self.data_producer)
+        return dict(zip(self._placehdrs_to_feed, dp))
+class DataParallelFeedInput(FeedInput):
+    """
+    Input by feeding k datapoints to k copies of placeholders located on k towers.
+    """
+    def __init__(self, ds, tower_names, input_names=None):
+        super(DataParallelFeedInput, self).__init__(ds, input_names)
+        self._tower_names = tower_names
+        self._nr_tower = len(tower_names)
+    def setup(self, model):
+        self._placehdrs_per_tower = []
+        self._feed_placehdrs_per_tower = []
+        for tname in self._tower_names:
+            # build a list of placeholders for each tower
+            self._placehdrs_per_tower.append(
+                model.build_placeholders(
+                    prefix=tname + '/'))
+        # apply input mapping and store results in feed_placehdrs_per_tower
+        if self._input_names is None:
+            self._feed_placehdrs_per_tower = self._placehdrs_per_tower
+        else:
+            for phdrs, tname in zip(
+                    self._placehdrs_per_tower, self._tower_names):
+                input_names = [tname + '/' + n for n in self._input_names]
+                # input_names to be used for this specific tower
+                self._feed_placehdrs_per_tower.append(
+                    get_placeholders_by_names(phdrs, input_names))
+        self.reset_state()
+    def get_input_tensors(self):
+        # return placeholders for each tower
+        ctx = get_current_tower_context()
+        return self._placehdrs_per_tower[ctx.index]
+    def next_feed(self, cnt=None):
+        """
+        Args:
+            cnt: how many towers to feed to. Defaults to the total number of towers
+        """
+        if cnt is None:
+            cnt = self._nr_tower
+        feed = {}
+        for t in range(cnt):
+            dp = next(self.data_producer)
+            f = dict(zip(self._feed_placehdrs_per_tower[t], dp))
+            feed.update(f)
+        return feed
 class FeedfreeInput(InputSource):
...
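For reference, a rough sketch of how `DataParallelFeedInput.next_feed(cnt)` assembles one merged feed across towers — plain strings stand in for the per-tower placeholders, and the tower and input names here are only illustrative:

from itertools import cycle


def make_tower_placehdrs(tower_name, input_names):
    # plays the role of model.build_placeholders(prefix=tower_name + '/'):
    # every tower gets its own name-prefixed copy of the model inputs
    return ['{}/{}:0'.format(tower_name, n) for n in input_names]


def next_feed(data_producer, feed_placehdrs_per_tower, cnt):
    # one datapoint per tower, merged into a single feed_dict
    feed = {}
    for t in range(cnt):
        dp = next(data_producer)
        feed.update(dict(zip(feed_placehdrs_per_tower[t], dp)))
    return feed


towers = ['towerp0', 'towerp1']  # hypothetical predict-tower names
phdrs = [make_tower_placehdrs(t, ['image', 'label']) for t in towers]
data = cycle([[0.1, 0], [0.2, 1], [0.3, 0]])
print(next_feed(data, phdrs, cnt=len(towers)))  # feeds both towers at once
print(next_feed(data, phdrs, cnt=1))            # leftover data: first tower only

This is why `DataParallelInferenceRunner._trigger` above can call `next_feed()` for full rounds of towers and `next_feed(cnt=1)` for the remaining datapoints.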
@@ -3,8 +3,6 @@
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>
-from six.moves import zip
 from .base import Trainer
 from ..utils import logger
@@ -33,8 +31,7 @@ class SimpleTrainer(Trainer):
     def run_step(self):
         """ Feed data into the graph and run the updates. """
-        dp = self._input_source.next_feed()
-        feed = dict(zip(self.inputs, dp))
+        feed = self._input_source.next_feed()
        self.hooked_sess.run(self.train_op, feed_dict=feed)
     def _setup(self):
...
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: utils.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

import copy
from six.moves import zip

from ..tfutils.common import get_op_tensor_name
from ..utils import logger

__all__ = ['get_tensors_inputs', 'get_placeholders_by_names']


def get_tensors_inputs(placeholders, tensors, names):
    """
    Quite often we want to `build_graph()` with normal tensors
    (rather than placeholders).

    Args:
        placeholders (list[Tensor]):
        tensors (list[Tensor]): list of tf.Tensor
        names (list[str]): names matching the tensors

    Returns:
        list[Tensor]: inputs to used with build_graph(),
            with the corresponding placeholders replaced by tensors.
    """
    assert len(tensors) == len(names), \
        "Input tensors {} and input names {} have different length!".format(
            tensors, names)
    ret = copy.copy(placeholders)
    placeholder_names = [p.name for p in placeholders]
    for name, tensor in zip(names, tensors):
        tensorname = get_op_tensor_name(name)[1]
        try:
            idx = placeholder_names.index(tensorname)
        except ValueError:
            logger.error("Name {} is not a model input!".format(tensorname))
            raise
        ret[idx] = tensor
    return ret


def get_placeholders_by_names(placeholders, names):
    """
    Returns:
        list[Tensor]: a sublist of placeholders, matching names
    """
    placeholder_names = [p.name for p in placeholders]
    ret = []
    for name in names:
        tensorname = get_op_tensor_name(name)[1]
        try:
            idx = placeholder_names.index(tensorname)
        except ValueError:
            logger.error("Name {} is not a model input!".format(tensorname))
            raise
        ret.append(placeholders[idx])
    return ret
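A small usage example for the two helpers above (assumes TF 1.x and that this revision of tensorpack is importable; the placeholder names and shapes are made up for illustration):

import tensorflow as tf
from tensorpack.train.utils import get_tensors_inputs, get_placeholders_by_names

image = tf.placeholder(tf.float32, shape=[None, 28, 28], name='image')
label = tf.placeholder(tf.int32, shape=[None], name='label')
placeholders = [image, label]

# Suppose 'image' is now produced by the graph (e.g. dequeued) instead of fed:
image_tensor = tf.zeros([32, 28, 28], name='image_from_queue')

# The 'image' placeholder is swapped for the tensor; 'label' stays a placeholder.
inputs = get_tensors_inputs(placeholders, [image_tensor], ['image'])
assert inputs[0] is image_tensor and inputs[1] is label

# Pick the subset of placeholders that a DataFlow maps to, by name.
assert get_placeholders_by_names(placeholders, ['label'])[0] is label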