Commit 3bc0bed2 authored by Yuxin Wu

some small change in input_data

parent c5de2ef9
...@@ -5,14 +5,16 @@ ...@@ -5,14 +5,16 @@
import tensorflow as tf import tensorflow as tf
from collections import namedtuple from collections import namedtuple
import tqdm
import six import six
import copy
from six.moves import zip, range from six.moves import zip, range
from ..utils import logger, get_tqdm from ..utils import logger, get_tqdm_kwargs, get_tqdm
from ..dataflow import DataFlow from ..dataflow import DataFlow
from ..tfutils.common import get_op_tensor_name from ..tfutils.common import get_op_tensor_name
from ..tfutils import TowerContext from ..tfutils import TowerContext
from ..train.input_data import TensorInput from ..train.input_data import TensorInput, FeedInput
from ..predict import PredictorTowerBuilder from ..predict import PredictorTowerBuilder
from .base import Triggerable from .base import Triggerable
...@@ -78,8 +80,9 @@ class InferenceRunner(Triggerable): ...@@ -78,8 +80,9 @@ class InferenceRunner(Triggerable):
input_tensor_names(list): list of tensors to feed the dataflow to. input_tensor_names(list): list of tensors to feed the dataflow to.
Defaults to all the input placeholders. Defaults to all the input placeholders.
""" """
assert isinstance(ds, DataFlow), ds if isinstance(ds, DataFlow):
self.ds = ds self.ds = FeedInput(ds)
assert isinstance(self.ds, FeedInput), self.ds
if not isinstance(infs, list): if not isinstance(infs, list):
self.infs = [infs] self.infs = [infs]
else: else:
...@@ -132,14 +135,13 @@ class InferenceRunner(Triggerable): ...@@ -132,14 +135,13 @@ class InferenceRunner(Triggerable):
inf.before_inference() inf.before_inference()
self.ds.reset_state() self.ds.reset_state()
with get_tqdm(total=self.ds.size()) as pbar: for _ in tqdm.trange(self.ds.size(), **get_tqdm_kwargs()):
for dp in self.ds.get_data(): dp = self.ds.next_feed()
outputs = self.predictor(dp) outputs = self.predictor(dp)
for inf, tensormap in zip(self.infs, self.inf_to_tensors): for inf, tensormap in zip(self.infs, self.inf_to_tensors):
inf_output = [(outputs if k.isOutput else dp)[k.index] inf_output = [(outputs if k.isOutput else dp)[k.index]
for k in tensormap] for k in tensormap]
inf.datapoint(inf_output) inf.datapoint(inf_output)
pbar.update()
self._write_summary_after_inference() self._write_summary_after_inference()
def _write_summary_after_inference(self): def _write_summary_after_inference(self):
...@@ -195,7 +197,7 @@ class FeedfreeInferenceRunner(Triggerable): ...@@ -195,7 +197,7 @@ class FeedfreeInferenceRunner(Triggerable):
self._input_data.setup(self.trainer.model) self._input_data.setup(self.trainer.model)
# only 1 prediction tower will be used for inference # only 1 prediction tower will be used for inference
self._input_tensors = self._input_data.get_input_tensors() self._input_tensors = self._input_data.get_input_tensors()
model_placehdrs = self.trainer.model.get_reused_placehdrs() model_placehdrs = copy.copy(self.trainer.model.get_reused_placehdrs())
if self._input_names is not None: if self._input_names is not None:
raise NotImplementedError("Random code. Not tested.") raise NotImplementedError("Random code. Not tested.")
assert len(self._input_names) == len(self._input_tensors), \ assert len(self._input_names) == len(self._input_tensors), \
......
...@@ -11,7 +11,7 @@ from ..utils import logger ...@@ -11,7 +11,7 @@ from ..utils import logger
from ..utils.develop import deprecated from ..utils.develop import deprecated
from ..utils.argtools import memoized from ..utils.argtools import memoized
from ..utils.naming import SUMMARY_BACKUP_KEYS from ..utils.naming import SUMMARY_BACKUP_KEYS
from ..tfutils import get_tensors_by_names, TowerContext from ..tfutils import get_tensors_by_names, TowerContext, get_op_tensor_name
from ..tfutils.collection import freeze_collection from ..tfutils.collection import freeze_collection
__all__ = ['PredictorBase', 'AsyncPredictorBase', __all__ = ['PredictorBase', 'AsyncPredictorBase',
...@@ -192,6 +192,19 @@ class PredictorTowerBuilder(object): ...@@ -192,6 +192,19 @@ class PredictorTowerBuilder(object):
TowerContext(towername, is_training=False): TowerContext(towername, is_training=False):
self._fn(tower) self._fn(tower)
@staticmethod
def get_tensors_maybe_in_tower(placeholder_names, names, k, prefix=''):
def maybe_inside_tower(name):
name = get_op_tensor_name(name)[0]
if name in placeholder_names:
return name
else:
# if the name is not a placeholder, use its name in each tower
return TowerContext.get_predict_tower_name(k, prefix) + '/' + name
names = list(map(maybe_inside_tower, names))
tensors = get_tensors_by_names(names)
return tensors
def build_prediction_graph(build_tower_fn, towers=[0], prefix=''): def build_prediction_graph(build_tower_fn, towers=[0], prefix=''):
""" """
......
...@@ -4,8 +4,8 @@ ...@@ -4,8 +4,8 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from ..utils import logger from ..utils import logger
from ..tfutils import get_tensors_by_names, TowerContext, get_op_tensor_name from ..tfutils import get_tensors_by_names, TowerContext
from .base import OnlinePredictor, build_prediction_graph from .base import OnlinePredictor, build_prediction_graph, PredictorTowerBuilder
__all__ = ['MultiTowerOfflinePredictor', __all__ = ['MultiTowerOfflinePredictor',
'DataParallelOfflinePredictor'] 'DataParallelOfflinePredictor']
...@@ -33,26 +33,13 @@ class MultiTowerOfflinePredictor(OnlinePredictor): ...@@ -33,26 +33,13 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
self.sess = config.session_creator.create_session() self.sess = config.session_creator.create_session()
config.session_init.init(self.sess) config.session_init.init(self.sess)
get_tensor_fn = MultiTowerOfflinePredictor.get_tensors_maybe_in_tower get_tensor_fn = PredictorTowerBuilder.get_tensors_maybe_in_tower
for k in towers: for k in towers:
input_tensors = get_tensor_fn(placeholder_names, config.input_names, k) input_tensors = get_tensor_fn(placeholder_names, config.input_names, k)
output_tensors = get_tensor_fn(placeholder_names, config.output_names, k) output_tensors = get_tensor_fn(placeholder_names, config.output_names, k)
self.predictors.append(OnlinePredictor( self.predictors.append(OnlinePredictor(
input_tensors, output_tensors, config.return_input, self.sess)) input_tensors, output_tensors, config.return_input, self.sess))
@staticmethod
def get_tensors_maybe_in_tower(placeholder_names, names, k):
def maybe_inside_tower(name):
name = get_op_tensor_name(name)[0]
if name in placeholder_names:
return name
else:
# if the name is not a placeholder, use its name in each tower
return TowerContext.get_predict_tower_name(k) + '/' + name
names = map(maybe_inside_tower, names)
tensors = get_tensors_by_names(names)
return tensors
def _do_call(self, dp): def _do_call(self, dp):
# use the first tower for compatible PredictorBase interface # use the first tower for compatible PredictorBase interface
return self.predictors[0]._do_call(dp) return self.predictors[0]._do_call(dp)
......
...@@ -76,9 +76,9 @@ class TowerContext(object): ...@@ -76,9 +76,9 @@ class TowerContext(object):
def get_predict_tower_name(towerid=0, prefix=''): def get_predict_tower_name(towerid=0, prefix=''):
""" """
Args: Args:
prefix(str): an alphanumeric prefix.
towerid(int): an integer, the id of this predict tower, usually towerid(int): an integer, the id of this predict tower, usually
used to choose the GPU id. used to choose the GPU id.
prefix(str): an alphanumeric prefix.
Returns: Returns:
str: the final tower name used to create a predict tower. str: the final tower name used to create a predict tower.
Currently it is ``PREDICT_TOWER + prefix + towerid``. Currently it is ``PREDICT_TOWER + prefix + towerid``.
......
...@@ -23,12 +23,27 @@ __all__ = ['InputData', 'FeedfreeInput', ...@@ -23,12 +23,27 @@ __all__ = ['InputData', 'FeedfreeInput',
class InputData(object): class InputData(object):
""" Base class for the abstract InputData. """ """ Base class for the abstract InputData. """
@abstractmethod
def get_input_tensors(self):
"""
Returns:
list: A list of tensors corresponding to the inputs of the model.
Always create and return a list of new input tensors when called.
"""
def setup(self, model): def setup(self, model):
pass pass
def setup_training(self, trainer): def setup_training(self, trainer):
self.setup(trainer.model) self.setup(trainer.model)
@abstractmethod
def reset_state(self):
pass
def next_feed(self):
return []
class FeedInput(InputData): class FeedInput(InputData):
""" Input by iterating over a DataFlow and feed datapoints. """ """ Input by iterating over a DataFlow and feed datapoints. """
...@@ -49,30 +64,25 @@ class FeedInput(InputData): ...@@ -49,30 +64,25 @@ class FeedInput(InputData):
rds.reset_state() rds.reset_state()
self.data_producer = rds.get_data() self.data_producer = rds.get_data()
def next_feed(self): def reset_state(self):
data = next(self.data_producer) rds = RepeatedData(self.ds, -1)
feed = dict(zip(self.input_placehdrs, data)) rds.reset_state()
self._last_feed = feed self.data_producer = rds.get_data()
return feed
def last_feed(self): def get_input_tensors(self):
return self._last_feed return self.input_placehdrs
def next_feed(self):
return next(self.data_producer)
class FeedfreeInput(InputData): class FeedfreeInput(InputData):
""" Abstract base for input without feed, """ Abstract base for input without feed,
e.g. by queue or other operations. """ e.g. by queue or other operations. """
@abstractmethod def reset_state(self):
def get_input_tensors(self): # TODO cannot reset
""" pass
Returns:
list: A list of tensors corresponding to the inputs of the model.
Always create and return a list of new input tensors when called.
"""
def get_client_threads(self):
return []
class EnqueueThread(ShareSessionThread): class EnqueueThread(ShareSessionThread):
...@@ -234,6 +244,9 @@ class DummyConstantInput(FeedfreeInput): ...@@ -234,6 +244,9 @@ class DummyConstantInput(FeedfreeInput):
self.shapes = shapes self.shapes = shapes
logger.warn("Using dummy input for debug!") logger.warn("Using dummy input for debug!")
def setup(self, model):
self.input_placehdrs = model.get_reused_placehdrs()
def get_input_tensors(self): def get_input_tensors(self):
placehdrs = self.input_placehdrs placehdrs = self.input_placehdrs
assert len(self.shapes) == len(placehdrs) assert len(self.shapes) == len(placehdrs)
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
import tensorflow as tf import tensorflow as tf
from ..predict import (OnlinePredictor, from ..predict import (OnlinePredictor,
PredictorTowerBuilder, MultiTowerOfflinePredictor) PredictorTowerBuilder)
__all__ = ['PredictorFactory'] __all__ = ['PredictorFactory']
...@@ -39,7 +39,7 @@ class PredictorFactory(object): ...@@ -39,7 +39,7 @@ class PredictorFactory(object):
self._tower_builder.build(tower) self._tower_builder.build(tower)
placeholder_names = set([k.name for k in self.model.get_inputs_desc()]) placeholder_names = set([k.name for k in self.model.get_inputs_desc()])
get_tensor_fn = MultiTowerOfflinePredictor.get_tensors_maybe_in_tower get_tensor_fn = PredictorTowerBuilder.get_tensors_maybe_in_tower
in_tensors = get_tensor_fn(placeholder_names, input_names, tower) in_tensors = get_tensor_fn(placeholder_names, input_names, tower)
out_tensors = get_tensor_fn(placeholder_names, output_names, tower) out_tensors = get_tensor_fn(placeholder_names, output_names, tower)
return OnlinePredictor(in_tensors, out_tensors) return OnlinePredictor(in_tensors, out_tensors)
...@@ -28,7 +28,8 @@ class SimpleTrainer(Trainer): ...@@ -28,7 +28,8 @@ class SimpleTrainer(Trainer):
def run_step(self): def run_step(self):
""" Feed data into the graph and run the updates. """ """ Feed data into the graph and run the updates. """
feed = self._input_method.next_feed() dp = self._input_method.next_feed()
feed = dict(zip(self.inputs, dp))
self.hooked_sess.run(self.train_op, feed_dict=feed) self.hooked_sess.run(self.train_op, feed_dict=feed)
def _setup(self): def _setup(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment