Commit da98e447 authored by Yuxin Wu

[Trainerv2] Let InferenceRunner run with new Trainer

parent e5ff50e7
@@ -19,6 +19,8 @@
 from ..dataflow.base import DataFlow
 from ..input_source import (
     InputSource, FeedInput, QueueInput)
+from ..graph_builder.predictor_factory import SimplePredictBuilder
+# from ..trainv2 import SingleCostTrainer
 from .base import Callback
 from .group import Callbacks
@@ -121,16 +123,30 @@ class InferenceRunner(InferenceRunnerBase):
         return InferencerToHook(inf, fetches)
 
     def _setup_graph(self):
-        assert self.trainer.model is not None
-        # Use predict_tower in train config: either a GPU id or -1
-        tower_id = self.trainer._config.predict_tower[0]
-        device = '/gpu:{}'.format(tower_id) if tower_id >= 0 else '/cpu:0'
-
-        input_callbacks = self._input_source.setup(self.trainer.model.get_inputs_desc())
-        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
-            self._tower_handle = self.trainer.predictor_factory.build(
-                self._tower_name, device, self._input_source)
+        if hasattr(self.trainer, 'model'):
+            # old Trainer API
+            assert self.trainer.model is not None
+            # Use predict_tower in train config: either a GPU id or -1
+            tower_id = self.trainer._config.predict_tower[0]
+            device = '/gpu:{}'.format(tower_id) if tower_id >= 0 else '/cpu:0'
+
+            input_callbacks = self._input_source.setup(self.trainer.model.get_inputs_desc())
+            with tf.variable_scope(tf.get_variable_scope(), reuse=True):
+                self._tower_handle = self.trainer.predictor_factory.build(
+                    self._tower_name, device, self._input_source)
+        else:
+            # new Trainer API; only works for single-cost trainers
+            # assert isinstance(self.trainer, SingleCostTrainer), self.trainer
+            input_callbacks = self._input_source.setup(self.trainer.inputs_desc)
+            with tf.variable_scope(tf.get_variable_scope(), reuse=True):
+                SimplePredictBuilder(
+                    ns_name=self._tower_name,
+                    vs_name='', device=0).build(  # TODO fix vs_name and maybe device
+                        self._input_source, self.trainer.get_cost_fn)
+                self._tower_handle = self.trainer.get_cost_fn.towers[-1]
 
         self._hooks = [self._build_hook(inf) for inf in self.infs]
         # trigger_{step,epoch}, {before,after}_epoch is ignored.
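For context on how the new branch gets exercised: under the v2 API, `SingleCostTrainer.setup_graph()` (changed further down in this commit) stores `inputs_desc` and the wrapped `get_cost_fn` on the trainer, and the v2 trainer has no `model` attribute, so `InferenceRunner` falls into the `else` path. A rough usage sketch; the trainer subclass and dataflow names are placeholders, only the attribute names come from this diff:

```python
trainer = SomeSingleCostTrainer()              # hypothetical SingleCostTrainer subclass
trainer.setup_graph(inputs_desc, train_input,  # stores trainer.inputs_desc and
                    get_cost_fn, get_opt_fn)   # trainer.get_cost_fn (see below)

# No `model` attribute on the v2 trainer, so InferenceRunner takes the new path:
# it builds one predict tower via SimplePredictBuilder and reads the handle
# from trainer.get_cost_fn.towers[-1].
runner = InferenceRunner(val_df, [ScalarStats('cost')])
```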
@@ -180,20 +196,32 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
         self._gpus = gpus
 
     def _setup_graph(self):
-        assert self.trainer.model is not None
-        cbs = self._input_source.setup(self.trainer.model.get_inputs_desc())
-        # build each predict tower
         self._handles = []
-        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
-            for idx, t in enumerate(self._gpus):
-                tower_name = self._tower_names[idx]
-                device = '/gpu:{}'.format(t)
-                self._handles.append(
-                    self.trainer.predictor_factory.build(
-                        tower_name, device, self._input_source))
+        if hasattr(self.trainer, 'model'):
+            # old Trainer API
+            input_callbacks = self._input_source.setup(self.trainer.model.get_inputs_desc())
+            # build each predict tower
+            with tf.variable_scope(tf.get_variable_scope(), reuse=True):
+                for idx, t in enumerate(self._gpus):
+                    tower_name = self._tower_names[idx]
+                    device = '/gpu:{}'.format(t)
+                    self._handles.append(
+                        self.trainer.predictor_factory.build(
+                            tower_name, device, self._input_source))
+        else:
+            # new Trainer API
+            input_callbacks = self._input_source.setup(self.trainer.inputs_desc)
+            with tf.variable_scope(tf.get_variable_scope(), reuse=True):
+                for idx, t in enumerate(self._gpus):
+                    tower_name = self._tower_names[idx]
+                    SimplePredictBuilder(
+                        ns_name=tower_name,
+                        vs_name='', device=t).build(  # TODO fix vs_name and maybe device
+                            self._input_source, self.trainer.get_cost_fn)
+                    self._handles.append(self.trainer.get_cost_fn.towers[-1])
 
         # setup callbacks and hooks
-        self._input_callbacks = Callbacks(cbs)
+        self._input_callbacks = Callbacks(input_callbacks)
 
         # InputSource might have hooks which break us.
         # e.g. hooks from StagingInputWrapper will force the consumption
...
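The per-GPU loop relies on a small invariant: each `SimplePredictBuilder.build()` call invokes the wrapped cost function once, and `TowerFuncWrapper` records a handle for every invocation, so reading `towers[-1]` right after each build picks up that GPU's tower. A minimal sketch of that bookkeeping (variable names assumed):

```python
# tower_fn is a TowerFuncWrapper; each build() below calls it once more.
handles = []
for gpu_id, name in [(0, 'InferenceTower0'), (1, 'InferenceTower1')]:
    SimplePredictBuilder(ns_name=name, vs_name='', device=gpu_id).build(
        input_source, tower_fn)
    handles.append(tower_fn.towers[-1])   # handle of the tower just built
```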
@@ -3,13 +3,59 @@
 # File: predictor_factory.py
 
 import tensorflow as tf
+from contextlib import contextmanager
 from ..utils import logger
 from ..tfutils.tower import TowerContext, TowerFuncWrapper
 from ..tfutils.collection import freeze_collection
 from ..utils.naming import TOWER_FREEZE_KEYS
 from ..input_source import PlaceholderInput
+from .training import GraphBuilder
 
-__all__ = []
+__all__ = ['SimplePredictBuilder']
+
+
+class SimplePredictBuilder(GraphBuilder):
+    """
+    Single-tower predictor.
+    """
+    def __init__(self, ns_name='', vs_name='', device=0):
+        """
+        Args:
+            ns_name (str): name scope to build the predictor tower under
+            vs_name (str): variable scope to open, if non-empty
+            device (int): GPU id, or -1 for CPU
+        """
+        # TODO does vs_name work properly here when different from ns_name?
+        self._ns_name = ns_name
+        self._vs_name = vs_name
+
+        device = '/gpu:{}'.format(device) if device >= 0 else '/cpu:0'
+        self._device = device
+
+    @contextmanager
+    def _maybe_open_vs(self):
+        if len(self._vs_name):
+            with tf.variable_scope(self._vs_name):
+                yield
+        else:
+            yield
+
+    def build(self, input, tower_fn):
+        assert input.setup_done()
+        logger.info("Building predictor tower '{}' on device {} ...".format(
+            self._ns_name, self._device))
+
+        with tf.device(self._device), \
+                self._maybe_open_vs(), \
+                TowerContext(self._ns_name, is_training=False), \
+                freeze_collection(TOWER_FREEZE_KEYS + [tf.GraphKeys.UPDATE_OPS]):
+            # also freeze UPDATE_OPS in inference, because they should never be used
+            # TODO a better way to log and warn about collection change during build_graph.
+            inputs = input.get_input_tensors()
+            assert isinstance(inputs, (list, tuple)), inputs
+            return tower_fn(*inputs)
 
 
 class PredictorFactory(object):
...
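A standalone usage sketch of the new builder, assuming `PlaceholderInput` is wired up the way tensorpack's offline predictors do it; only `SimplePredictBuilder`'s own signature comes from this diff:

```python
input = PlaceholderInput()
input.setup(inputs_desc)                  # build() asserts setup_done()

tower_fn = TowerFuncWrapper(get_cost_fn, inputs_desc)
SimplePredictBuilder(ns_name='pred', vs_name='', device=-1).build(  # -1 -> /cpu:0
    input, tower_fn)
pred_tower = tower_fn.towers[-1]          # handle to the freshly built tower
```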
@@ -27,7 +27,7 @@ class TrainConfig(object):
                  callbacks=None, extra_callbacks=None, monitors=None,
                  session_creator=None, session_config=None, session_init=None,
                  starting_epoch=1, steps_per_epoch=None, max_epoch=99999,
-                 nr_tower=1, tower=None, predict_tower=None,
+                 nr_tower=1, tower=None,
                  **kwargs):
         """
         Note:
@@ -127,6 +127,7 @@ class TrainConfig(object):
             assert self.nr_tower == 1, "Cannot set both nr_tower and tower in TrainConfig!"
             self.tower = tower
 
+        predict_tower = kwargs.pop('predict_tower', None)
         if predict_tower is None:
             predict_tower = [0]
         self.predict_tower = predict_tower
...
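Moving `predict_tower` out of the explicit signature and into `**kwargs` keeps old call sites working while de-emphasizing the argument. A sketch (other `TrainConfig` arguments elided):

```python
config = TrainConfig(
    dataflow=my_dataflow,        # placeholder; model, callbacks etc. elided
    predict_tower=[1],           # still accepted: popped back out of **kwargs
)
# Omitting predict_tower entirely falls back to the old default, [0].
```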
@@ -16,6 +16,7 @@
 from ..callbacks.monitor import Monitors, TrainingMonitor
 from ..tfutils.model_utils import describe_trainable_vars
 from ..tfutils.sessinit import JustCurrentSession
 from ..tfutils.sesscreate import ReuseSessionCreator
+from ..tfutils.tower import TowerFuncWrapper
 from ..callbacks.steps import MaintainStepCounter
 from ..train.base import StopTraining, TrainLoop
@@ -240,9 +241,13 @@ class SingleCostTrainer(Trainer):
         These callbacks will be automatically added when you call `train()`.
         So you can usually ignore the return value.
         """
+        get_cost_fn = TowerFuncWrapper(get_cost_fn, inputs_desc)
+
         input_callbacks = self._setup_input(inputs_desc, input)
         train_callbacks = self._setup_graph(input, get_cost_fn, get_opt_fn)
         self._internal_callbacks = input_callbacks + train_callbacks
+        self.inputs_desc = inputs_desc
+        self.get_cost_fn = get_cost_fn
         return self._internal_callbacks
 
     @abstractmethod
...
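The wrapping is the crux of the commit: `TowerFuncWrapper` turns the user's cost function into something that can be re-invoked later under a different `TowerContext` (here, an inference tower) while recording every tower it builds. A rough sketch of the contract, assuming the wrapped function must be called inside an active `TowerContext` and that all helper names here are hypothetical:

```python
def get_cost_fn(image, label):            # signature matches inputs_desc
    logits = build_my_network(image)      # hypothetical model-building code
    return compute_my_loss(logits, label)

tower_fn = TowerFuncWrapper(get_cost_fn, inputs_desc)
with TowerContext('', is_training=True):
    cost = tower_fn(image_t, label_t)     # builds the training tower
# Later, SimplePredictBuilder calls tower_fn again under an inference
# TowerContext; each call appends a handle to tower_fn.towers.
```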