Commit 76fa8e38 authored by Yuxin Wu's avatar Yuxin Wu

Simplify inference_runner: 1. move input_names mapping to InputSource 2. add DataParallelFeedInput

parent 48f6c267
This diff is collapsed.
......@@ -12,8 +12,9 @@ except ImportError:
from itertools import chain
from abc import ABCMeta, abstractmethod
import six
from six.moves import range
from six.moves import range, zip
from .utils import get_placeholders_by_names
from ..dataflow import DataFlow, RepeatedData
from ..tfutils.summary import add_moving_summary
from ..tfutils import get_op_tensor_name
......@@ -24,7 +25,7 @@ from ..utils.concurrency import ShareSessionThread
from ..callbacks.concurrency import StartProcOrThread
from ..callbacks.base import Callback
__all__ = ['InputSource', 'FeedfreeInput',
__all__ = ['InputSource', 'FeedfreeInput', 'DataParallelFeedInput',
'QueueInput', 'BatchQueueInput',
'ZMQInput',
'DummyConstantInput', 'TensorInput', 'StagingInputWrapper']
......@@ -38,8 +39,9 @@ class InputSource(object):
def get_input_tensors(self):
"""
Returns:
list: A list of tensors corresponding to the inputs of the model.
Always create and return a list of new input tensors when called.
list: A list of tensors corresponding to the inputs of the model,
used as input of :func:`build_graph`.
For non-placeholder tensors, should always create and return new tensors when called.
"""
def setup(self, model):
......@@ -53,27 +55,37 @@ class InputSource(object):
pass
def next_feed(self):
return []
"""
Returns:
a feed_dict of {Tensor: data}, to be used to run the steps
"""
return {}
class FeedInput(InputSource):
""" Input by iterating over a DataFlow and feed datapoints. """
def __init__(self, ds):
def __init__(self, ds, input_names=None):
"""
Args:
ds (DataFlow): the input DataFlow.
input_names (list[str]): input names this DataFlow maps to
"""
assert isinstance(ds, DataFlow), ds
self.ds = ds
self._input_names = input_names
def size(self):
return self.ds.size()
def setup(self, model):
self.input_placehdrs = model.get_reused_placehdrs()
rds = RepeatedData(self.ds, -1)
rds.reset_state()
self.data_producer = rds.get_data()
self._all_placehdrs = model.get_reused_placehdrs()
if self._input_names is None:
self._placehdrs_to_feed = self._all_placehdrs
else:
self._placehdrs_to_feed = get_placeholders_by_names(
self._all_placehdrs, self._input_names)
self.reset_state()
def reset_state(self):
rds = RepeatedData(self.ds, -1)
......@@ -81,10 +93,61 @@ class FeedInput(InputSource):
self.data_producer = rds.get_data()
def get_input_tensors(self):
return self.input_placehdrs
return self._all_placehdrs
def next_feed(self):
return next(self.data_producer)
dp = next(self.data_producer)
return dict(zip(self._placehdrs_to_feed, dp))
class DataParallelFeedInput(FeedInput):
"""
Input by feeding k datapoints to k copies of placeholders located on k towers.
"""
def __init__(self, ds, tower_names, input_names=None):
super(DataParallelFeedInput, self).__init__(ds, input_names)
self._tower_names = tower_names
self._nr_tower = len(tower_names)
def setup(self, model):
self._placehdrs_per_tower = []
self._feed_placehdrs_per_tower = []
for tname in self._tower_names:
# build a list of placeholders for each tower
self._placehdrs_per_tower.append(
model.build_placeholders(
prefix=tname + '/'))
# apply input mapping and store results in feed_placehdrs_per_tower
if self._input_names is None:
self._feed_placehdrs_per_tower = self._placehdrs_per_tower
else:
for phdrs, tname in zip(
self._placehdrs_per_tower, self._tower_names):
input_names = [tname + '/' + n for n in self._input_names]
# input_names to be used for this specific tower
self._feed_placehdrs_per_tower.append(
get_placeholders_by_names(phdrs, input_names))
self.reset_state()
def get_input_tensors(self):
# return placeholders for each tower
ctx = get_current_tower_context()
return self._placehdrs_per_tower[ctx.index]
def next_feed(self, cnt=None):
"""
Args:
cnt: how many towers to feed to. Defaults to the total number of towers
"""
if cnt is None:
cnt = self._nr_tower
feed = {}
for t in range(cnt):
dp = next(self.data_producer)
f = dict(zip(self._feed_placehdrs_per_tower[t], dp))
feed.update(f)
return feed
class FeedfreeInput(InputSource):
......
......@@ -3,8 +3,6 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
from six.moves import zip
from .base import Trainer
from ..utils import logger
......@@ -33,8 +31,7 @@ class SimpleTrainer(Trainer):
def run_step(self):
""" Feed data into the graph and run the updates. """
dp = self._input_source.next_feed()
feed = dict(zip(self.inputs, dp))
feed = self._input_source.next_feed()
self.hooked_sess.run(self.train_op, feed_dict=feed)
def _setup(self):
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: utils.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import copy
from six.moves import zip
from ..tfutils.common import get_op_tensor_name
from ..utils import logger
__all__ = ['get_tensors_inputs', 'get_placeholders_by_names']
def get_tensors_inputs(placeholders, tensors, names):
"""
Quite often we want to `build_graph()` with normal tensors
(rather than placeholders).
Args:
placeholders (list[Tensor]):
tensors (list[Tensor]): list of tf.Tensor
names (list[str]): names matching the tensors
Returns:
list[Tensor]: inputs to used with build_graph(),
with the corresponding placeholders replaced by tensors.
"""
assert len(tensors) == len(names), \
"Input tensors {} and input names {} have different length!".format(
tensors, names)
ret = copy.copy(placeholders)
placeholder_names = [p.name for p in placeholders]
for name, tensor in zip(names, tensors):
tensorname = get_op_tensor_name(name)[1]
try:
idx = placeholder_names.index(tensorname)
except ValueError:
logger.error("Name {} is not a model input!".format(tensorname))
raise
ret[idx] = tensor
return ret
def get_placeholders_by_names(placeholders, names):
"""
Returns:
list[Tensor]: a sublist of placeholders, matching names
"""
placeholder_names = [p.name for p in placeholders]
ret = []
for name in names:
tensorname = get_op_tensor_name(name)[1]
try:
idx = placeholder_names.index(tensorname)
except ValueError:
logger.error("Name {} is not a model input!".format(tensorname))
raise
ret.append(placeholders[idx])
return ret
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment