Commit 9a711e72 authored by Yuxin Wu

[Trainerv2] use v2 inference interface in v1 trainer

parent e121701a
@@ -108,7 +108,7 @@ class InferenceRunner(InferenceRunnerBase):
             infs (list): a list of :class:`Inferencer` instances.
             tower_name (str): the name scope of the tower to build. Need to set a
                 different one if multiple InferenceRunner are used.
-            gpu (int): the device to use
+            device (int): the device to use
         """
         if isinstance(input, DataFlow):
             input = FeedInput(input, infinite=False)
@@ -124,34 +124,19 @@ class InferenceRunner(InferenceRunnerBase):
         return InferencerToHook(inf, fetches)

     def _setup_graph(self):
-        if self.trainer._API_VERSION == 1:
-            # old Trainer API
-            assert self.trainer.model is not None
-            # Use predict_tower in train config. either gpuid or -1
-            if self.trainer._config.predict_tower is not None:
-                device = self.trainer._config.predict_tower[0]
-            else:
-                device = self._device
-            device = '/gpu:{}'.format(device) if device >= 0 else '/cpu:0'
-            input_callbacks = self._input_source.setup(self.trainer.model.get_inputs_desc())
-            with tf.variable_scope(tf.get_variable_scope(), reuse=True):
-                self._tower_handle = self.trainer.predictor_factory.build(
-                    self._tower_name, device, self._input_source)
-        else:
-            # new Trainer API
-            from ..trainv2 import TowerTrainer
-            assert isinstance(self.trainer, TowerTrainer), self.trainer
-            assert self.trainer.tower_func is not None, "You must set tower_func of the trainer to use InferenceRunner!"
-            input_callbacks = self._input_source.setup(self.trainer.inputs_desc)
-            with tf.variable_scope(tf.get_variable_scope(), reuse=True):
-                SimplePredictBuilder(
-                    ns_name=self._tower_name,
-                    vs_name=self.trainer._main_tower_vs_name, device=0).build(
-                        self._input_source, self.trainer.tower_func)
-                self._tower_handle = self.trainer.tower_func.towers[-1]
+        if self.trainer._API_VERSION == 1 and self.trainer._config.predict_tower is not None:
+            device = self.trainer._config.predict_tower[0]
+        else:
+            device = self._device
+        assert self.trainer.tower_func is not None, "You must set tower_func of the trainer to use InferenceRunner!"
+        input_callbacks = self._input_source.setup(self.trainer.inputs_desc)
+        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
+            SimplePredictBuilder(
+                ns_name=self._tower_name,
+                vs_name=self.trainer._main_tower_vs_name, device=device).build(
+                    self._input_source, self.trainer.tower_func)
+            self._tower_handle = self.trainer.tower_func.towers[-1]

         self._hooks = [self._build_hook(inf) for inf in self.infs]
         # trigger_{step,epoch}, {before,after}_epoch is ignored.
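With this change InferenceRunner builds its prediction tower through the trainer's tower_func on both the v1 and v2 code paths. A minimal usage sketch under the v1 API (MyModel, the dataflow helpers and the 'cost' tensor name are placeholders, not part of this commit):

    from tensorpack import TrainConfig, SimpleTrainer, InferenceRunner, ScalarStats

    config = TrainConfig(
        model=MyModel(),                  # a ModelDesc subclass (hypothetical)
        dataflow=get_train_dataflow(),    # hypothetical helper
        callbacks=[
            # Runs the inferencers on the validation dataflow after each epoch.
            # The prediction tower is now built from the trainer's tower_func.
            InferenceRunner(get_val_dataflow(), ScalarStats('cost')),
        ],
    )
    SimpleTrainer(config).train()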
@@ -202,31 +187,17 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
     def _setup_graph(self):
         self._handles = []
-        if self.trainer._API_VERSION == 1:
-            # old Trainer API
-            input_callbacks = self._input_source.setup(self.trainer.model.get_inputs_desc())
-            # build each predict tower
-            with tf.variable_scope(tf.get_variable_scope(), reuse=True):
-                for idx, t in enumerate(self._gpus):
-                    tower_name = self._tower_names[idx]
-                    device = '/gpu:{}'.format(t)
-                    self._handles.append(
-                        self.trainer.predictor_factory.build(
-                            tower_name, device, self._input_source))
-        else:
-            # new Trainer API
-            from ..trainv2 import TowerTrainer
-            assert isinstance(self.trainer, TowerTrainer), self.trainer
-            assert self.trainer.tower_func is not None, "You must set tower_func of the trainer to use InferenceRunner!"
-            input_callbacks = self._input_source.setup(self.trainer.inputs_desc)
-            with tf.variable_scope(tf.get_variable_scope(), reuse=True):
-                for idx, t in enumerate(self._gpus):
-                    tower_name = self._tower_names[idx]
-                    SimplePredictBuilder(
-                        ns_name=tower_name,
-                        vs_name=self.trainer._main_tower_vs_name, device=t).build(
-                            self._input_source, self.trainer.tower_func)
-                    self._handles.append(self.trainer.tower_func.towers[-1])
+        assert self.trainer.tower_func is not None, "You must set tower_func of the trainer to use InferenceRunner!"
+        input_callbacks = self._input_source.setup(self.trainer.inputs_desc)
+        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
+            for idx, t in enumerate(self._gpus):
+                tower_name = self._tower_names[idx]
+                SimplePredictBuilder(
+                    ns_name=tower_name,
+                    vs_name=self.trainer._main_tower_vs_name, device=t).build(
+                        self._input_source, self.trainer.tower_func)
+                self._handles.append(self.trainer.tower_func.towers[-1])

         # setup callbacks and hooks
         self._input_callbacks = Callbacks(input_callbacks)
...
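As with InferenceRunner above, DataParallelInferenceRunner now always goes through the trainer's tower_func, building one prediction tower per GPU under variable reuse. A rough usage sketch (assuming the tensorpack API of this era; get_val_dataflow and the 'cost' tensor name are placeholders):

    from tensorpack.callbacks import DataParallelInferenceRunner, ScalarStats

    callbacks = [
        # One prediction tower is built per entry of gpus=[0, 1]; each tower
        # reuses the training variables through the trainer's tower_func.
        DataParallelInferenceRunner(get_val_dataflow(), ScalarStats('cost'), gpus=[0, 1]),
    ]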
@@ -60,9 +60,6 @@ class TowerContext(object):
             (self.is_training and len(self._vs_name) > 0) or \
             (not self.is_training and len(self._vs_name) > 0 and not self._initial_vs_reuse)

-        # TODO clarify the interface on name/vs_name/ns_name.
-        # TODO in inference, vs_name may need to be different from ns_name.i
-        # How to deal with this?
     @property
     def name(self):
         return self._name
@@ -151,7 +148,6 @@ class TowerContext(object):
 def get_current_tower_context():
-    global _CurrentTowerContext
     return _CurrentTowerContext
...
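The dropped `global` statement was redundant: reading a module-level name inside a function needs no declaration, only rebinding it does. A minimal illustration (`_set_tower_context` is a hypothetical helper, not tensorpack API):

    _CurrentTowerContext = None

    def get_current_tower_context():
        return _CurrentTowerContext       # read-only access, no `global` needed

    def _set_tower_context(ctx):
        global _CurrentTowerContext       # rebinding does require `global`
        _CurrentTowerContext = ctx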
@@ -18,8 +18,11 @@ from ..tfutils import get_global_step_value
 from ..tfutils.model_utils import describe_trainable_vars
 from ..tfutils.sesscreate import ReuseSessionCreator
 from ..tfutils.sessinit import JustCurrentSession
+from ..tfutils.tower import TowerFuncWrapper

-from ..graph_builder.predictor_factory import PredictorFactory
+from ..input_source import PlaceholderInput
+from ..graph_builder.predictor_factory import SimplePredictBuilder
+from ..predict.base import OnlinePredictor
 from ..callbacks.steps import MaintainStepCounter

 __all__ = ['Trainer', 'StopTraining']
@@ -117,6 +120,16 @@ class Trainer(object):
         assert isinstance(config, TrainConfig), type(config)
         self._config = config
         self.model = config.model
+        if self.model is not None:
+            def f(*inputs):
+                self.model.build_graph(inputs)
+            """
+            Only to mimic new trainer interafce on inference.
+            """
+            self.inputs_desc = self.model.get_inputs_desc()
+            self.tower_func = TowerFuncWrapper(f, self.inputs_desc)

         self._callbacks = []
         self._monitors = []
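The adapter above lets the v1 Trainer (driven by a ModelDesc) expose the same tower_func / inputs_desc interface that the v2 inference code expects. Informally, it amounts to the following (MyModel is a hypothetical ModelDesc):

    from tensorpack.tfutils.tower import TowerFuncWrapper

    model = MyModel()                      # any ModelDesc subclass

    def f(*inputs):
        # v1 models receive the input tensors as a single tuple
        model.build_graph(inputs)

    inputs_desc = model.get_inputs_desc()
    tower_func = TowerFuncWrapper(f, inputs_desc)
    # Each call of tower_func builds one more tower and records a handle in
    # tower_func.towers, which is what InferenceRunner and get_predictor() read.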
@@ -268,8 +281,7 @@ class Trainer(object):
     def get_predictor(self, input_names, output_names, tower=0):
         """
-        Returns a callable predictor built under ``is_training=False`` tower context.
-        Note that this method is only valid when this trainer has a ``ModelDesc``.
+        Returns a callable predictor built under ``TowerContext(is_training=False)``.

         Args:
             input_names (list), output_names(list): list of names
@@ -278,19 +290,27 @@ class Trainer(object):
         Returns:
             an :class:`OnlinePredictor`.
         """
-        return self.predictor_factory.get_predictor(input_names, output_names, tower)
-
-    @property
-    def predictor_factory(self):
-        assert self.model is not None, \
-            "Predictor can only be built one Trainer has ModelDesc!"
-        if not hasattr(self, '_predictor_factory'):
-            self._predictor_factory = PredictorFactory(
-                self.model, self.vs_name_for_predictor)
-        return self._predictor_factory
+        device = tower
+        assert self.tower_func is not None, "Must set tower_func on the trainer to use get_predictor()!"
+        tower_name = 'tower-pred-{}'.format(device) if device >= 0 else 'tower-pred-cpu'
+        try:
+            tower = self.tower_func.towers[tower_name]
+        except KeyError:
+            input = PlaceholderInput()
+            input.setup(self.inputs_desc)
+
+            with tf.variable_scope(tf.get_variable_scope(), reuse=True):
+                SimplePredictBuilder(
+                    ns_name=tower_name, vs_name=self._main_tower_vs_name,
+                    device=device).build(input, self.tower_func)
+            tower = self.tower_func.towers[tower_name]
+
+        input_tensors = tower.get_tensors(input_names)
+        output_tensors = tower.get_tensors(output_names)
+        return OnlinePredictor(input_tensors, output_tensors)

     @property
-    def vs_name_for_predictor(self):
+    def _main_tower_vs_name(self):
         # The vs name a predictor should be built under.
         # for internal use only. Should let graphbuilder return it.
         return ""
...
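get_predictor() now builds (and caches) an inference tower directly from tower_func instead of going through PredictorFactory; a second call with the same tower id finds the existing 'tower-pred-0' tower via the try/except and skips rebuilding. A hypothetical call site ('input' and 'prob' are placeholder tensor names from the user's model):

    pred = trainer.get_predictor(['input'], ['prob'], tower=0)
    # OnlinePredictor is callable on feed values, in the order of input_names:
    outputs = pred(some_numpy_batch)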
@@ -95,5 +95,5 @@ class DistributedTrainerReplicated(Trainer):
         self._config.session_creator = get_distributed_session_creator(self.server)

     @property
-    def vs_name_for_predictor(self):
+    def _main_tower_vs_name(self):
         return "tower0"
@@ -268,9 +268,10 @@ class TowerTrainer(Trainer):
             input = PlaceholderInput()
             input.setup(self.inputs_desc)
-            SimplePredictBuilder(
-                ns_name=tower_name, vs_name=self._main_tower_vs_name,
-                device=device).build(input, self.tower_func)
+            with tf.variable_scope(tf.get_variable_scope(), reuse=True):
+                SimplePredictBuilder(
+                    ns_name=tower_name, vs_name=self._main_tower_vs_name,
+                    device=device).build(input, self.tower_func)
             tower = self.tower_func.towers[tower_name]

         input_tensors = tower.get_tensors(input_names)
         output_tensors = tower.get_tensors(output_names)
...
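The extra variable_scope wrapper matters because the prediction tower must share the variables already created by the training tower rather than create new ones. The reuse pattern in isolation (TF 1.x; 'model' and 'w' are illustrative names only):

    import tensorflow as tf

    with tf.variable_scope('model'):
        w_train = tf.get_variable('w', shape=[3])      # created by the training tower

    # Re-enter the current scope with reuse=True; child scopes inherit reuse,
    # so get_variable returns the existing variable instead of failing or
    # creating a duplicate.
    with tf.variable_scope(tf.get_variable_scope(), reuse=True):
        with tf.variable_scope('model'):
            w_pred = tf.get_variable('w', shape=[3])   # reused, not re-created

    assert w_train is w_pred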