Commit 76fa8e38 authored by Yuxin Wu

Simplify inference_runner: 1. move input_names mapping to InputSource 2. add DataParallelFeedInput

parent 48f6c267
This diff is collapsed.
...@@ -12,8 +12,9 @@ except ImportError: ...@@ -12,8 +12,9 @@ except ImportError:
from itertools import chain from itertools import chain
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
import six import six
from six.moves import range from six.moves import range, zip
from .utils import get_placeholders_by_names
from ..dataflow import DataFlow, RepeatedData from ..dataflow import DataFlow, RepeatedData
from ..tfutils.summary import add_moving_summary from ..tfutils.summary import add_moving_summary
from ..tfutils import get_op_tensor_name from ..tfutils import get_op_tensor_name
...@@ -24,7 +25,7 @@ from ..utils.concurrency import ShareSessionThread ...@@ -24,7 +25,7 @@ from ..utils.concurrency import ShareSessionThread
from ..callbacks.concurrency import StartProcOrThread from ..callbacks.concurrency import StartProcOrThread
from ..callbacks.base import Callback from ..callbacks.base import Callback
__all__ = ['InputSource', 'FeedfreeInput', __all__ = ['InputSource', 'FeedfreeInput', 'DataParallelFeedInput',
'QueueInput', 'BatchQueueInput', 'QueueInput', 'BatchQueueInput',
'ZMQInput', 'ZMQInput',
'DummyConstantInput', 'TensorInput', 'StagingInputWrapper'] 'DummyConstantInput', 'TensorInput', 'StagingInputWrapper']
...@@ -38,8 +39,9 @@ class InputSource(object): ...@@ -38,8 +39,9 @@ class InputSource(object):
def get_input_tensors(self): def get_input_tensors(self):
""" """
Returns: Returns:
list: A list of tensors corresponding to the inputs of the model. list: A list of tensors corresponding to the inputs of the model,
Always create and return a list of new input tensors when called. used as input of :func:`build_graph`.
For non-placeholder tensors, should always create and return new tensors when called.
""" """
def setup(self, model): def setup(self, model):
...@@ -53,27 +55,37 @@ class InputSource(object): ...@@ -53,27 +55,37 @@ class InputSource(object):
pass pass
def next_feed(self): def next_feed(self):
return [] """
Returns:
a feed_dict of {Tensor: data}, to be used to run the steps
"""
return {}
class FeedInput(InputSource): class FeedInput(InputSource):
""" Input by iterating over a DataFlow and feed datapoints. """ """ Input by iterating over a DataFlow and feed datapoints. """
def __init__(self, ds): def __init__(self, ds, input_names=None):
""" """
Args: Args:
ds (DataFlow): the input DataFlow. ds (DataFlow): the input DataFlow.
input_names (list[str]): input names this DataFlow maps to
""" """
assert isinstance(ds, DataFlow), ds assert isinstance(ds, DataFlow), ds
self.ds = ds self.ds = ds
self._input_names = input_names
def size(self): def size(self):
return self.ds.size() return self.ds.size()
def setup(self, model): def setup(self, model):
self.input_placehdrs = model.get_reused_placehdrs() self._all_placehdrs = model.get_reused_placehdrs()
rds = RepeatedData(self.ds, -1) if self._input_names is None:
rds.reset_state() self._placehdrs_to_feed = self._all_placehdrs
self.data_producer = rds.get_data() else:
self._placehdrs_to_feed = get_placeholders_by_names(
self._all_placehdrs, self._input_names)
self.reset_state()
def reset_state(self): def reset_state(self):
rds = RepeatedData(self.ds, -1) rds = RepeatedData(self.ds, -1)
...@@ -81,10 +93,61 @@ class FeedInput(InputSource): ...@@ -81,10 +93,61 @@ class FeedInput(InputSource):
self.data_producer = rds.get_data() self.data_producer = rds.get_data()
def get_input_tensors(self): def get_input_tensors(self):
return self.input_placehdrs return self._all_placehdrs
def next_feed(self): def next_feed(self):
return next(self.data_producer) dp = next(self.data_producer)
return dict(zip(self._placehdrs_to_feed, dp))
class DataParallelFeedInput(FeedInput):
    """
    Feed-based input for data-parallel towers: every tower owns its own copy
    of the placeholders, and each call to :meth:`next_feed` draws one
    datapoint per tower from the DataFlow.
    """

    def __init__(self, ds, tower_names, input_names=None):
        """
        Args:
            ds (DataFlow): the input DataFlow.
            tower_names (list[str]): one name per tower; each name is used as
                the prefix of that tower's placeholders.
            input_names (list[str]): input names this DataFlow maps to.
                If None, every placeholder of a tower is fed.
        """
        super(DataParallelFeedInput, self).__init__(ds, input_names)
        self._tower_names = tower_names
        self._nr_tower = len(tower_names)

    def setup(self, model):
        """
        Build one set of placeholders per tower, work out which of them are
        actually fed, then reset the data producer.
        """
        # One list of placeholders per tower, created under the "<tower>/" prefix.
        self._placehdrs_per_tower = [
            model.build_placeholders(prefix=tower + '/')
            for tower in self._tower_names
        ]

        if self._input_names is None:
            # No input mapping: feed all placeholders of each tower.
            self._feed_placehdrs_per_tower = self._placehdrs_per_tower
        else:
            # Apply the input mapping tower by tower; the names to look up
            # carry the same "<tower>/" prefix as the placeholders.
            self._feed_placehdrs_per_tower = [
                get_placeholders_by_names(
                    tower_phdrs,
                    [tower + '/' + n for n in self._input_names])
                for tower_phdrs, tower in zip(
                    self._placehdrs_per_tower, self._tower_names)
            ]
        self.reset_state()

    def get_input_tensors(self):
        """
        Returns:
            list: the placeholders belonging to the tower of the current
            tower context (selected by the context's ``index``).
        """
        tower_index = get_current_tower_context().index
        return self._placehdrs_per_tower[tower_index]

    def next_feed(self, cnt=None):
        """
        Args:
            cnt: how many towers to feed to. Defaults to the total number of towers

        Returns:
            a feed_dict of {Tensor: data} covering the first ``cnt`` towers,
            built from ``cnt`` consecutive datapoints.
        """
        nr_to_feed = self._nr_tower if cnt is None else cnt
        feed = {}
        for tower_index in range(nr_to_feed):
            # One fresh datapoint per tower, paired with that tower's placeholders.
            datapoint = next(self.data_producer)
            feed.update(
                zip(self._feed_placehdrs_per_tower[tower_index], datapoint))
        return feed
class FeedfreeInput(InputSource): class FeedfreeInput(InputSource):
......
...@@ -3,8 +3,6 @@ ...@@ -3,8 +3,6 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
from six.moves import zip
from .base import Trainer from .base import Trainer
from ..utils import logger from ..utils import logger
...@@ -33,8 +31,7 @@ class SimpleTrainer(Trainer): ...@@ -33,8 +31,7 @@ class SimpleTrainer(Trainer):
def run_step(self): def run_step(self):
""" Feed data into the graph and run the updates. """ """ Feed data into the graph and run the updates. """
dp = self._input_source.next_feed() feed = self._input_source.next_feed()
feed = dict(zip(self.inputs, dp))
self.hooked_sess.run(self.train_op, feed_dict=feed) self.hooked_sess.run(self.train_op, feed_dict=feed)
def _setup(self): def _setup(self):
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: utils.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import copy
from six.moves import zip
from ..tfutils.common import get_op_tensor_name
from ..utils import logger
__all__ = ['get_tensors_inputs', 'get_placeholders_by_names']
def get_tensors_inputs(placeholders, tensors, names):
    """
    Build the input list for `build_graph()` when some inputs come from
    normal tensors instead of placeholders.

    Args:
        placeholders (list[Tensor]): the model's input placeholders.
        tensors (list[Tensor]): list of tf.Tensor to substitute in.
        names (list[str]): names matching the tensors, one per tensor.

    Returns:
        list[Tensor]: a shallow copy of ``placeholders`` in which every
        placeholder matched by ``names`` is replaced by the corresponding
        tensor. The input list itself is not modified.

    Raises:
        ValueError: if a name does not match any placeholder (an error is
            logged before the exception propagates).
    """
    assert len(tensors) == len(names), \
        "Input tensors {} and input names {} have different length!".format(
            tensors, names)

    inputs = copy.copy(placeholders)
    known_names = [phdr.name for phdr in placeholders]

    for tensor, name in zip(tensors, names):
        # Normalize to the tensor-name form before comparing with p.name.
        tensorname = get_op_tensor_name(name)[1]
        try:
            position = known_names.index(tensorname)
        except ValueError:
            logger.error("Name {} is not a model input!".format(tensorname))
            raise
        else:
            inputs[position] = tensor
    return inputs
def get_placeholders_by_names(placeholders, names):
    """
    Select placeholders by name.

    Args:
        placeholders (list[Tensor]): the candidate placeholders.
        names (list[str]): names of the placeholders to pick.

    Returns:
        list[Tensor]: a sublist of placeholders, matching names, in the
        order given by ``names``.

    Raises:
        ValueError: if a name does not match any placeholder (an error is
            logged before the exception propagates).
    """
    known_names = [phdr.name for phdr in placeholders]

    def _find(name):
        # Normalize to the tensor-name form before comparing with p.name.
        tensorname = get_op_tensor_name(name)[1]
        try:
            position = known_names.index(tensorname)
        except ValueError:
            logger.error("Name {} is not a model input!".format(tensorname))
            raise
        return placeholders[position]

    return [_find(name) for name in names]
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment