Commit 3bc0bed2 authored by Yuxin Wu

some small change in input_data

parent c5de2ef9
......@@ -5,14 +5,16 @@
import tensorflow as tf
from collections import namedtuple
import tqdm
import six
import copy
from six.moves import zip, range
from ..utils import logger, get_tqdm
from ..utils import logger, get_tqdm_kwargs, get_tqdm
from ..dataflow import DataFlow
from ..tfutils.common import get_op_tensor_name
from ..tfutils import TowerContext
from ..train.input_data import TensorInput
from ..train.input_data import TensorInput, FeedInput
from ..predict import PredictorTowerBuilder
from .base import Triggerable
......@@ -78,8 +80,9 @@ class InferenceRunner(Triggerable):
input_tensor_names(list): names of the input tensors that datapoints from the dataflow are fed to.
Defaults to all the input placeholders.
"""
assert isinstance(ds, DataFlow), ds
self.ds = ds
if isinstance(ds, DataFlow):
self.ds = FeedInput(ds)
assert isinstance(self.ds, FeedInput), self.ds
if not isinstance(infs, list):
self.infs = [infs]
else:
......@@ -132,14 +135,13 @@ class InferenceRunner(Triggerable):
inf.before_inference()
self.ds.reset_state()
with get_tqdm(total=self.ds.size()) as pbar:
for dp in self.ds.get_data():
for _ in tqdm.trange(self.ds.size(), **get_tqdm_kwargs()):
dp = self.ds.next_feed()
outputs = self.predictor(dp)
for inf, tensormap in zip(self.infs, self.inf_to_tensors):
inf_output = [(outputs if k.isOutput else dp)[k.index]
for k in tensormap]
inf.datapoint(inf_output)
pbar.update()
self._write_summary_after_inference()
def _write_summary_after_inference(self):
......@@ -195,7 +197,7 @@ class FeedfreeInferenceRunner(Triggerable):
self._input_data.setup(self.trainer.model)
# only 1 prediction tower will be used for inference
self._input_tensors = self._input_data.get_input_tensors()
model_placehdrs = self.trainer.model.get_reused_placehdrs()
model_placehdrs = copy.copy(self.trainer.model.get_reused_placehdrs())
if self._input_names is not None:
raise NotImplementedError("Random code. Not tested.")
assert len(self._input_names) == len(self._input_tensors), \
......
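With this change, InferenceRunner accepts a plain DataFlow and wraps it into a FeedInput internally, then drives inference through next_feed(). A hypothetical usage sketch (the dataset and the ScalarStats inferencer are assumptions for illustration, not part of this diff):

```python
# Sketch: constructing an InferenceRunner after this change.
# A DataFlow is accepted directly and wrapped into FeedInput internally.
from tensorpack.dataflow import BatchData, dataset            # assumed import paths
from tensorpack.callbacks import InferenceRunner, ScalarStats

dataset_test = BatchData(dataset.Mnist('test'), 128, remainder=True)
inference_cb = InferenceRunner(dataset_test, [ScalarStats('cost')])
```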
......@@ -11,7 +11,7 @@ from ..utils import logger
from ..utils.develop import deprecated
from ..utils.argtools import memoized
from ..utils.naming import SUMMARY_BACKUP_KEYS
from ..tfutils import get_tensors_by_names, TowerContext
from ..tfutils import get_tensors_by_names, TowerContext, get_op_tensor_name
from ..tfutils.collection import freeze_collection
__all__ = ['PredictorBase', 'AsyncPredictorBase',
......@@ -192,6 +192,19 @@ class PredictorTowerBuilder(object):
TowerContext(towername, is_training=False):
self._fn(tower)
@staticmethod
def get_tensors_maybe_in_tower(placeholder_names, names, k, prefix=''):
def maybe_inside_tower(name):
name = get_op_tensor_name(name)[0]
if name in placeholder_names:
return name
else:
# if the name is not a placeholder, use its name in each tower
return TowerContext.get_predict_tower_name(k, prefix) + '/' + name
names = list(map(maybe_inside_tower, names))
tensors = get_tensors_by_names(names)
return tensors
def build_prediction_graph(build_tower_fn, towers=[0], prefix=''):
"""
......
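The new static helper get_tensors_maybe_in_tower keeps placeholder names global and prefixes every other tensor name with the k-th predict tower's scope before looking the tensors up. A minimal standalone sketch of that string mapping (the 'towerp0' scope name and the tensor names are assumptions; the real helper additionally fetches the tensors from the graph):

```python
# String-level illustration of the name resolution above (no TF graph involved).
placeholder_names = {'input'}     # assumed model placeholder name
tower_scope = 'towerp0'           # assumed predict-tower name for k=0

def resolve(name):
    # placeholders are shared across towers, so their global name is kept;
    # any other tensor is looked up under the k-th predict tower's scope
    return name if name in placeholder_names else tower_scope + '/' + name

print(resolve('input'))   # -> 'input'
print(resolve('prob'))    # -> 'towerp0/prob'
```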
......@@ -4,8 +4,8 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from ..utils import logger
from ..tfutils import get_tensors_by_names, TowerContext, get_op_tensor_name
from .base import OnlinePredictor, build_prediction_graph
from ..tfutils import get_tensors_by_names, TowerContext
from .base import OnlinePredictor, build_prediction_graph, PredictorTowerBuilder
__all__ = ['MultiTowerOfflinePredictor',
'DataParallelOfflinePredictor']
......@@ -33,26 +33,13 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
self.sess = config.session_creator.create_session()
config.session_init.init(self.sess)
get_tensor_fn = MultiTowerOfflinePredictor.get_tensors_maybe_in_tower
get_tensor_fn = PredictorTowerBuilder.get_tensors_maybe_in_tower
for k in towers:
input_tensors = get_tensor_fn(placeholder_names, config.input_names, k)
output_tensors = get_tensor_fn(placeholder_names, config.output_names, k)
self.predictors.append(OnlinePredictor(
input_tensors, output_tensors, config.return_input, self.sess))
@staticmethod
def get_tensors_maybe_in_tower(placeholder_names, names, k):
def maybe_inside_tower(name):
name = get_op_tensor_name(name)[0]
if name in placeholder_names:
return name
else:
# if the name is not a placeholder, use its name in each tower
return TowerContext.get_predict_tower_name(k) + '/' + name
names = map(maybe_inside_tower, names)
tensors = get_tensors_by_names(names)
return tensors
def _do_call(self, dp):
# use the first tower for compatible PredictorBase interface
return self.predictors[0]._do_call(dp)
......
......@@ -76,9 +76,9 @@ class TowerContext(object):
def get_predict_tower_name(towerid=0, prefix=''):
"""
Args:
prefix(str): an alphanumeric prefix.
towerid(int): the id of this predict tower, usually used to choose the GPU id.
prefix(str): an alphanumeric prefix.
Returns:
str: the final tower name used to create a predict tower.
Currently it is ``PREDICT_TOWER + prefix + towerid``.
......
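For reference, the naming rule documented in that docstring, written out as a tiny sketch (the value of PREDICT_TOWER is not shown in this diff, so 'towerp' below is an assumption):

```python
# Hypothetical restatement of the documented rule: PREDICT_TOWER + prefix + towerid
PREDICT_TOWER = 'towerp'   # assumed value, not shown in this diff

def predict_tower_name(towerid=0, prefix=''):
    return PREDICT_TOWER + prefix + str(towerid)

print(predict_tower_name(0))          # -> 'towerp0'
print(predict_tower_name(1, 'aux'))   # -> 'towerpaux1'
```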
......@@ -23,12 +23,27 @@ __all__ = ['InputData', 'FeedfreeInput',
class InputData(object):
""" Base class for the abstract InputData. """
@abstractmethod
def get_input_tensors(self):
"""
Returns:
list: A list of tensors corresponding to the inputs of the model.
Always create and return a list of new input tensors when called.
"""
def setup(self, model):
pass
def setup_training(self, trainer):
self.setup(trainer.model)
@abstractmethod
def reset_state(self):
pass
def next_feed(self):
return []
class FeedInput(InputData):
""" Input by iterating over a DataFlow and feed datapoints. """
......@@ -49,30 +64,25 @@ class FeedInput(InputData):
rds.reset_state()
self.data_producer = rds.get_data()
def next_feed(self):
data = next(self.data_producer)
feed = dict(zip(self.input_placehdrs, data))
self._last_feed = feed
return feed
def reset_state(self):
rds = RepeatedData(self.ds, -1)
rds.reset_state()
self.data_producer = rds.get_data()
def last_feed(self):
return self._last_feed
def get_input_tensors(self):
return self.input_placehdrs
def next_feed(self):
return next(self.data_producer)
class FeedfreeInput(InputData):
""" Abstract base for input without feed,
e.g. by queue or other operations. """
@abstractmethod
def get_input_tensors(self):
"""
Returns:
list: A list of tensors corresponding to the inputs of the model.
Always create and return a list of new input tensors when called.
"""
def get_client_threads(self):
return []
def reset_state(self):
# TODO cannot reset
pass
class EnqueueThread(ShareSessionThread):
......@@ -234,6 +244,9 @@ class DummyConstantInput(FeedfreeInput):
self.shapes = shapes
logger.warn("Using dummy input for debug!")
def setup(self, model):
self.input_placehdrs = model.get_reused_placehdrs()
def get_input_tensors(self):
placehdrs = self.input_placehdrs
assert len(self.shapes) == len(placehdrs)
......
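The refactored base class now fixes the contract an input source must implement: setup(model), get_input_tensors(), reset_state(), and next_feed() returning a raw datapoint. A minimal hypothetical subclass, as a sketch (ListInput is made up; only the module path shown in this diff is assumed to be importable):

```python
from tensorpack.train.input_data import InputData   # module path as shown in this diff

class ListInput(InputData):
    """Hypothetical input source feeding datapoints from an in-memory list."""
    def __init__(self, datapoints):
        self._datapoints = datapoints
        self._idx = 0

    def setup(self, model):
        # reuse the model's placeholders, the same way FeedInput does
        self.input_placehdrs = model.get_reused_placehdrs()

    def get_input_tensors(self):
        return self.input_placehdrs

    def reset_state(self):
        self._idx = 0

    def next_feed(self):
        # return one raw datapoint; the trainer zips it with get_input_tensors()
        dp = self._datapoints[self._idx % len(self._datapoints)]
        self._idx += 1
        return dp
```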
......@@ -5,7 +5,7 @@
import tensorflow as tf
from ..predict import (OnlinePredictor,
PredictorTowerBuilder, MultiTowerOfflinePredictor)
PredictorTowerBuilder)
__all__ = ['PredictorFactory']
......@@ -39,7 +39,7 @@ class PredictorFactory(object):
self._tower_builder.build(tower)
placeholder_names = set([k.name for k in self.model.get_inputs_desc()])
get_tensor_fn = MultiTowerOfflinePredictor.get_tensors_maybe_in_tower
get_tensor_fn = PredictorTowerBuilder.get_tensors_maybe_in_tower
in_tensors = get_tensor_fn(placeholder_names, input_names, tower)
out_tensors = get_tensor_fn(placeholder_names, output_names, tower)
return OnlinePredictor(in_tensors, out_tensors)
......@@ -28,7 +28,8 @@ class SimpleTrainer(Trainer):
def run_step(self):
""" Feed data into the graph and run the updates. """
feed = self._input_method.next_feed()
dp = self._input_method.next_feed()
feed = dict(zip(self.inputs, dp))
self.hooked_sess.run(self.train_op, feed_dict=feed)
def _setup(self):
......
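The trainer-side counterpart of the FeedInput change: next_feed() now returns the raw datapoint, and run_step builds the feed dict by zipping it with the input tensors. A self-contained toy sketch of that pattern in plain TF1 (the graph and values are made up for illustration):

```python
import tensorflow as tf

# Toy graph standing in for the model's placeholders and train_op.
x = tf.placeholder(tf.float32, [None, 2], name='x')
y = tf.placeholder(tf.float32, [None, 1], name='y')
w = tf.get_variable('w', shape=[2, 1])
loss = tf.reduce_mean(tf.square(tf.matmul(x, w) - y))
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

inputs = [x, y]                  # what get_input_tensors() would return
dp = [[[1., 2.]], [[0.5]]]       # one raw datapoint, as next_feed() now returns
feed = dict(zip(inputs, dp))     # the trainer builds the placeholder -> value mapping

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(train_op, feed_dict=feed)
```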