Commit 9a711e72 authored by Yuxin Wu

[Trainerv2] use v2 inference interface in v1 trainer

parent e121701a
......@@ -108,7 +108,7 @@ class InferenceRunner(InferenceRunnerBase):
infs (list): a list of :class:`Inferencer` instances.
tower_name (str): the name scope of the tower to build. Need to set a
different one if multiple InferenceRunner are used.
gpu (int): the device to use
device (int): the device to use
"""
if isinstance(input, DataFlow):
input = FeedInput(input, infinite=False)
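A hedged usage sketch of the renamed argument, assuming the constructor keyword was renamed from ``gpu`` to ``device`` to match the docstring change above; ``dataset_val`` and the ``ScalarStats`` inferencer are illustrative placeholders, not part of this commit:

from tensorpack import InferenceRunner, ScalarStats

# Run validation inference on GPU 1; device=-1 would select the CPU,
# matching the device handling shown in _setup_graph below.
callbacks = [
    InferenceRunner(dataset_val, [ScalarStats('cost')], device=1),
]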
......@@ -124,32 +124,17 @@ class InferenceRunner(InferenceRunnerBase):
return InferencerToHook(inf, fetches)
def _setup_graph(self):
if self.trainer._API_VERSION == 1:
# old Trainer API
assert self.trainer.model is not None
# Use predict_tower from the train config: either a GPU id or -1 (CPU)
if self.trainer._config.predict_tower is not None:
if self.trainer._API_VERSION == 1 and self.trainer._config.predict_tower is not None:
device = self.trainer._config.predict_tower[0]
else:
device = self._device
device = '/gpu:{}'.format(device) if device >= 0 else '/cpu:0'
input_callbacks = self._input_source.setup(self.trainer.model.get_inputs_desc())
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
self._tower_handle = self.trainer.predictor_factory.build(
self._tower_name, device, self._input_source)
else:
# new Trainer API
from ..trainv2 import TowerTrainer
assert isinstance(self.trainer, TowerTrainer), self.trainer
assert self.trainer.tower_func is not None, "You must set tower_func of the trainer to use InferenceRunner!"
input_callbacks = self._input_source.setup(self.trainer.inputs_desc)
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
SimplePredictBuilder(
ns_name=self._tower_name,
vs_name=self.trainer._main_tower_vs_name, device=0).build(
vs_name=self.trainer._main_tower_vs_name, device=device).build(
self._input_source, self.trainer.tower_func)
self._tower_handle = self.trainer.tower_func.towers[-1]
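Because the removed and added lines are interleaved above, here is an approximate reconstruction of ``InferenceRunner._setup_graph`` after this change, assembled only from the hunk itself (a reading aid, not copied from the repository; exactly which of the old assert statements survive is not fully visible in this view):

def _setup_graph(self):
    if self.trainer._API_VERSION == 1 and self.trainer._config.predict_tower is not None:
        # old Trainer API: honor the predict_tower given in TrainConfig
        device = self.trainer._config.predict_tower[0]
    else:
        device = self._device
    assert self.trainer.tower_func is not None, \
        "You must set tower_func of the trainer to use InferenceRunner!"
    input_callbacks = self._input_source.setup(self.trainer.inputs_desc)
    with tf.variable_scope(tf.get_variable_scope(), reuse=True):
        SimplePredictBuilder(
            ns_name=self._tower_name,
            vs_name=self.trainer._main_tower_vs_name, device=device).build(
                self._input_source, self.trainer.tower_func)
    self._tower_handle = self.trainer.tower_func.towers[-1]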
......@@ -202,21 +187,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
def _setup_graph(self):
self._handles = []
if self.trainer._API_VERSION == 1:
# old Trainer API
input_callbacks = self._input_source.setup(self.trainer.model.get_inputs_desc())
# build each predict tower
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
for idx, t in enumerate(self._gpus):
tower_name = self._tower_names[idx]
device = '/gpu:{}'.format(t)
self._handles.append(
self.trainer.predictor_factory.build(
tower_name, device, self._input_source))
else:
# new Trainer API
from ..trainv2 import TowerTrainer
assert isinstance(self.trainer, TowerTrainer), self.trainer
assert self.trainer.tower_func is not None, "You must set tower_func of the trainer to use InferenceRunner!"
input_callbacks = self._input_source.setup(self.trainer.inputs_desc)
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
......
......@@ -60,9 +60,6 @@ class TowerContext(object):
(self.is_training and len(self._vs_name) > 0) or \
(not self.is_training and len(self._vs_name) > 0 and not self._initial_vs_reuse)
# TODO clarify the interface on name/vs_name/ns_name.
# TODO in inference, vs_name may need to be different from ns_name.
# How to deal with this?
@property
def name(self):
return self._name
......@@ -151,7 +148,6 @@ class TowerContext(object):
def get_current_tower_context():
global _CurrentTowerContext
return _CurrentTowerContext
......
......@@ -18,8 +18,11 @@ from ..tfutils import get_global_step_value
from ..tfutils.model_utils import describe_trainable_vars
from ..tfutils.sesscreate import ReuseSessionCreator
from ..tfutils.sessinit import JustCurrentSession
from ..tfutils.tower import TowerFuncWrapper
from ..graph_builder.predictor_factory import PredictorFactory
from ..input_source import PlaceholderInput
from ..graph_builder.predictor_factory import SimplePredictBuilder
from ..predict.base import OnlinePredictor
from ..callbacks.steps import MaintainStepCounter
__all__ = ['Trainer', 'StopTraining']
......@@ -117,6 +120,16 @@ class Trainer(object):
assert isinstance(config, TrainConfig), type(config)
self._config = config
self.model = config.model
if self.model is not None:
def f(*inputs):
self.model.build_graph(inputs)
"""
Only to mimic the new trainer interface for inference.
"""
self.inputs_desc = self.model.get_inputs_desc()
self.tower_func = TowerFuncWrapper(f, self.inputs_desc)
self._callbacks = []
self._monitors = []
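The shim added above is what lets the v1 ``Trainer`` present the v2 tower interface: it exposes ``inputs_desc`` and a ``TowerFuncWrapper`` around ``model.build_graph``, so the inference code needs only one code path. A minimal sketch of the same pattern in isolation (``MyModel`` is a hypothetical ``ModelDesc`` subclass; ``TowerFuncWrapper`` and its ``.towers`` registry are used exactly as in the diff):

model = MyModel()                  # any ModelDesc with get_inputs_desc() / build_graph()

def f(*inputs):
    # adapt ModelDesc.build_graph to the positional tower-function signature
    model.build_graph(inputs)

inputs_desc = model.get_inputs_desc()
tower_func = TowerFuncWrapper(f, inputs_desc)
# Each later call of tower_func (e.g. by SimplePredictBuilder) is recorded,
# so its tensors can be looked up afterwards through tower_func.towers.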
......@@ -268,8 +281,7 @@ class Trainer(object):
def get_predictor(self, input_names, output_names, tower=0):
"""
Returns a callable predictor built under ``is_training=False`` tower context.
Note that this method is only valid when this trainer has a ``ModelDesc``.
Returns a callable predictor built under ``TowerContext(is_training=False)``.
Args:
input_names (list), output_names(list): list of names
......@@ -278,19 +290,27 @@ class Trainer(object):
Returns:
an :class:`OnlinePredictor`.
"""
return self.predictor_factory.get_predictor(input_names, output_names, tower)
device = tower
assert self.tower_func is not None, "Must set tower_func on the trainer to use get_predictor()!"
tower_name = 'tower-pred-{}'.format(device) if device >= 0 else 'tower-pred-cpu'
@property
def predictor_factory(self):
assert self.model is not None, \
"Predictor can only be built one Trainer has ModelDesc!"
if not hasattr(self, '_predictor_factory'):
self._predictor_factory = PredictorFactory(
self.model, self.vs_name_for_predictor)
return self._predictor_factory
try:
tower = self.tower_func.towers[tower_name]
except KeyError:
input = PlaceholderInput()
input.setup(self.inputs_desc)
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
SimplePredictBuilder(
ns_name=tower_name, vs_name=self._main_tower_vs_name,
device=device).build(input, self.tower_func)
tower = self.tower_func.towers[tower_name]
input_tensors = tower.get_tensors(input_names)
output_tensors = tower.get_tensors(output_names)
return OnlinePredictor(input_tensors, output_tensors)
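A hedged usage sketch of the rewritten ``get_predictor``; the tensor names 'image' and 'prob' and the input batch are hypothetical, while the call signature and the ``OnlinePredictor`` return value come from the diff above:

# Build (or reuse) the 'tower-pred-0' predict tower and get a callable back.
pred = trainer.get_predictor(['image'], ['prob'], tower=0)
probs, = pred(image_batch)   # one array is returned per requested output name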
@property
def vs_name_for_predictor(self):
def _main_tower_vs_name(self):
# The variable-scope name a predictor should be built under.
# For internal use only. Should eventually let the graph builder return it.
return ""
......
......@@ -95,5 +95,5 @@ class DistributedTrainerReplicated(Trainer):
self._config.session_creator = get_distributed_session_creator(self.server)
@property
def vs_name_for_predictor(self):
def _main_tower_vs_name(self):
return "tower0"
......@@ -268,6 +268,7 @@ class TowerTrainer(Trainer):
input = PlaceholderInput()
input.setup(self.inputs_desc)
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
SimplePredictBuilder(
ns_name=tower_name, vs_name=self._main_tower_vs_name,
device=device).build(input, self.tower_func)
......