Commit 93a177bf authored by Yuxin Wu's avatar Yuxin Wu

[Trainerv2] Add get_predictor support

parent e0b13533
......@@ -143,6 +143,7 @@ class InferenceRunner(InferenceRunnerBase):
# new Trainer API
from ..trainv2 import TowerTrainer
assert isinstance(self.trainer, TowerTrainer), self.trainer
assert self.trainer.tower_func is not None, "You must set tower_func of the trainer to use InferenceRunner!"
input_callbacks = self._input_source.setup(self.trainer.inputs_desc)
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
......@@ -214,6 +215,9 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
tower_name, device, self._input_source))
else:
# new Trainer API
from ..trainv2 import TowerTrainer
assert isinstance(self.trainer, TowerTrainer), self.trainer
assert self.trainer.tower_func is not None, "You must set tower_func of the trainer to use InferenceRunner!"
input_callbacks = self._input_source.setup(self.trainer.inputs_desc)
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
for idx, t in enumerate(self._gpus):
......
......@@ -42,6 +42,14 @@ class SimplePredictBuilder(GraphBuilder):
yield
def build(self, input, tower_fn):
"""
Args:
input (InputSource): must have been setup
tower_fn ( [tf.Tensors] ->): callable that takes input tensors.
Returns:
The return value of tower_fn called under the proper context.
"""
assert input.setup_done()
logger.info("Building predictor tower '{}' on device {} ...".format(
self._ns_name, self._device))
......
......@@ -115,9 +115,10 @@ class OnlinePredictor(PredictorBase):
fetches=output_tensors,
feed_list=input_tensors)
else:
log_once(
"TF>=1.2 is recommended for better performance of predictor!", 'warn')
self._callable = None
else:
log_once(
"TF>=1.2 is recommended for better performance of predictor!", 'warn')
def _do_call_old(self, dp):
feed = dict(zip(self.input_tensors, dp))
......
......@@ -166,7 +166,7 @@ class TowerFuncWrapper(object):
self._tower_fn = tower_fn
self._inputs_desc = inputs_desc
self._towers = []
self._handles = []
def __new__(cls, tower_fn, inputs_desc):
# to avoid double-wrapping a function
......@@ -180,19 +180,33 @@ class TowerFuncWrapper(object):
assert ctx is not None, "Function must be called under TowerContext!"
output = self._tower_fn(*args)
handle = TowerTensorHandle(ctx, args, output, self._inputs_desc)
self._towers.append(handle)
self._handles.append(handle)
return output
@property
def towers(self):
# TODO another wrapper around towerhandlelist
return self._towers
return TowerTensorHandles(self._handles)
@property
def inputs_desc(self):
return self._inputs_desc
class TowerTensorHandles(object):
"""
Wrap a list of :class:`TowerTensorHandle`,
to support access to them by index or names.
"""
def __init__(self, handles):
self._handles = handles
self._name_to_handle = {k.ns_name: k for k in handles}
def __getitem__(self, name_or_index):
if isinstance(name_or_index, int):
return self._handles[name_or_index]
return self._name_to_handle[name_or_index]
class TowerTensorHandle(object):
"""
When a function is called multiple times under each tower,
......@@ -281,14 +295,3 @@ class TowerTensorHandle(object):
The output returned by the tower function.
"""
return self._output
# should move to somewhere else.
# def get_predictor(self, input_names, output_names):
# """
# Get a predictor with tensors inside this tower.
# """
# input_tensors = self.get_tensors(input_names)
# output_tensors = self.get_tensors(output_names)
# # TODO sort out the import order
# from ..predict.base import OnlinePredictor # noqa
# return OnlinePredictor(input_tensors, output_tensors)
......@@ -11,7 +11,6 @@ from abc import abstractmethod, ABCMeta
from ..utils import logger
from ..utils.argtools import call_only_once, memoized
from ..input_source import FeedfreeInput
from ..callbacks import Callback, Callbacks
from ..callbacks.monitor import Monitors, TrainingMonitor
from ..tfutils.model_utils import describe_trainable_vars
......@@ -21,10 +20,14 @@ from ..tfutils.tower import TowerFuncWrapper, get_current_tower_context
from ..tfutils.gradproc import FilterNoneGrad
from ..callbacks.steps import MaintainStepCounter
from ..graph_builder.predictor_factory import SimplePredictBuilder
from ..input_source import FeedfreeInput, PlaceholderInput
from ..predict.base import OnlinePredictor
import tensorpack.train as old_train # noqa
from ..train.base import StopTraining, TrainLoop
__all__ = ['Trainer', 'SingleCostTrainer']
__all__ = ['Trainer', 'SingleCostTrainer', 'TowerTrainer']
class Trainer(object):
......@@ -190,7 +193,8 @@ class Trainer(object):
# create the old trainer when called with TrainConfig
def __new__(cls, *args, **kwargs):
if isinstance(args[0], old_train.TrainConfig) or 'config' in kwargs:
if (len(args) > 0 and isinstance(args[0], old_train.TrainConfig)) \
or 'config' in kwargs:
name = cls.__name__
old_trainer = getattr(old_train, name)
return old_trainer(*args, **kwargs)
......@@ -237,6 +241,7 @@ class TowerTrainer(Trainer):
Args:
tower_func (TowerFuncWrapper)
"""
assert isinstance(tower_func, TowerFuncWrapper), tower_func
self.tower_func = tower_func
@property
......@@ -247,6 +252,34 @@ class TowerTrainer(Trainer):
"""
return self.tower_func.inputs_desc
def get_predictor(self, input_names, output_names, device=0):
"""
Returns a callable predictor built under ``TowerContext(is_training=False)``.
Args:
input_names (list), output_names(list): list of names
device (int): build the predictor on device '/gpu:{device}' or use -1 for '/cpu:0'.
Returns:
an :class:`OnlinePredictor`.
"""
assert self.tower_func is not None, "Must set tower_func on the trainer to use get_predictor()!"
tower_name = 'tower-pred-{}'.format(device) if device >= 0 else 'tower-pred-cpu'
try:
tower = self.tower_func.towers[tower_name]
except KeyError:
input = PlaceholderInput()
input.setup(self.inputs_desc)
SimplePredictBuilder(
ns_name=tower_name, vs_name='',
device=device).build(input, self.tower_func)
tower = self.tower_func.towers[tower_name]
input_tensors = tower.get_tensors(input_names)
output_tensors = tower.get_tensors(output_names)
return OnlinePredictor(input_tensors, output_tensors)
@six.add_metaclass(ABCMeta)
class SingleCostTrainer(TowerTrainer):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment