Commit da98e447 authored by Yuxin Wu

[Trainerv2] Let InferenceRunner run with new Trainer

parent e5ff50e7
......@@ -19,6 +19,8 @@ from ..dataflow.base import DataFlow
from ..input_source import (
InputSource, FeedInput, QueueInput)
from ..graph_builder.predictor_factory import SimplePredictBuilder
# from ..trainv2 import SingleCostTrainer
from .base import Callback
from .group import Callbacks
......@@ -121,16 +123,30 @@ class InferenceRunner(InferenceRunnerBase):
return InferencerToHook(inf, fetches)
def _setup_graph(self):
assert self.trainer.model is not None
# Use predict_tower in train config. either gpuid or -1
tower_id = self.trainer._config.predict_tower[0]
device = '/gpu:{}'.format(tower_id) if tower_id >= 0 else '/cpu:0'
input_callbacks = self._input_source.setup(self.trainer.model.get_inputs_desc())
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
self._tower_handle = self.trainer.predictor_factory.build(
self._tower_name, device, self._input_source)
if hasattr(self.trainer, 'model'):
# old Trainer API
assert self.trainer.model is not None
# Use predict_tower in train config. either gpuid or -1
tower_id = self.trainer._config.predict_tower[0]
device = '/gpu:{}'.format(tower_id) if tower_id >= 0 else '/cpu:0'
input_callbacks = self._input_source.setup(self.trainer.model.get_inputs_desc())
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
self._tower_handle = self.trainer.predictor_factory.build(
self._tower_name, device, self._input_source)
else:
# new Trainer API
# only works for singlecost trainer
# assert isinstance(self.trainer, SingleCostTrainer), self.trainer
input_callbacks = self._input_source.setup(self.trainer.inputs_desc)
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
SimplePredictBuilder(
ns_name=self._tower_name,
vs_name='', device=0).build( # TODO fix vs_name and maybe device
self._input_source, self.trainer.get_cost_fn)
self._tower_handle = self.trainer.get_cost_fn.towers[-1]
self._hooks = [self._build_hook(inf) for inf in self.infs]
# trigger_{step,epoch}, {before,after}_epoch is ignored.
......@@ -180,20 +196,32 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
self._gpus = gpus
def _setup_graph(self):
assert self.trainer.model is not None
cbs = self._input_source.setup(self.trainer.model.get_inputs_desc())
# build each predict tower
self._handles = []
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
for idx, t in enumerate(self._gpus):
tower_name = self._tower_names[idx]
device = '/gpu:{}'.format(t)
self._handles.append(
self.trainer.predictor_factory.build(
tower_name, device, self._input_source))
if hasattr(self.trainer, 'model'):
# old Trainer API
input_callbacks = self._input_source.setup(self.trainer.model.get_inputs_desc())
# build each predict tower
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
for idx, t in enumerate(self._gpus):
tower_name = self._tower_names[idx]
device = '/gpu:{}'.format(t)
self._handles.append(
self.trainer.predictor_factory.build(
tower_name, device, self._input_source))
else:
# new Trainer API
input_callbacks = self._input_source.setup(self.trainer.inputs_desc)
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
for idx, t in enumerate(self._gpus):
tower_name = self._tower_names[idx]
SimplePredictBuilder(
ns_name=tower_name,
vs_name='', device=t).build( # TODO fix vs_name and maybe device
self._input_source, self.trainer.get_cost_fn)
self._handles.append(self.trainer.get_cost_fn.towers[-1])
# setup callbacks and hooks
self._input_callbacks = Callbacks(cbs)
self._input_callbacks = Callbacks(input_callbacks)
# InputSource might have hooks which break us.
# e.g. hooks from StagingInputWrapper will force the consumption
......
......@@ -3,13 +3,59 @@
# File: predictor_factory.py
import tensorflow as tf
from contextlib import contextmanager
from ..utils import logger
from ..tfutils.tower import TowerContext, TowerFuncWrapper
from ..tfutils.collection import freeze_collection
from ..utils.naming import TOWER_FREEZE_KEYS
from ..input_source import PlaceholderInput
from .training import GraphBuilder
__all__ = ['SimplePredictBuilder']
class SimplePredictBuilder(GraphBuilder):
"""
Single-tower predictor.
"""
def __init__(self, ns_name='', vs_name='', device=0):
"""
Args:
ns_name (str):
vs_name (str):
device (int):
"""
# TODO does vs_name work properly here when different from ns_name?
self._ns_name = ns_name
self._vs_name = vs_name
device = '/gpu:{}'.format(device) if device >= 0 else '/cpu:0'
self._device = device
__all__ = []
@contextmanager
def _maybe_open_vs(self):
if len(self._vs_name):
with tf.variable_scope(self._vs_name):
yield
else:
yield
def build(self, input, tower_fn):
assert input.setup_done()
logger.info("Building predictor tower '{}' on device {} ...".format(
self._ns_name, self._device))
with tf.device(self._device), \
self._maybe_open_vs(), \
TowerContext(self._ns_name, is_training=False), \
freeze_collection(TOWER_FREEZE_KEYS + [tf.GraphKeys.UPDATE_OPS]):
# also freeze UPDATE_OPS in inference, because they should never be used
# TODO a better way to log and warn about collection change during build_graph.
inputs = input.get_input_tensors()
assert isinstance(inputs, (list, tuple)), inputs
return tower_fn(*inputs)
class PredictorFactory(object):
......
......@@ -27,7 +27,7 @@ class TrainConfig(object):
callbacks=None, extra_callbacks=None, monitors=None,
session_creator=None, session_config=None, session_init=None,
starting_epoch=1, steps_per_epoch=None, max_epoch=99999,
nr_tower=1, tower=None, predict_tower=None,
nr_tower=1, tower=None,
**kwargs):
"""
Note:
......@@ -127,6 +127,7 @@ class TrainConfig(object):
assert self.nr_tower == 1, "Cannot set both nr_tower and tower in TrainConfig!"
self.tower = tower
predict_tower = kwargs.pop('predict_tower', None)
if predict_tower is None:
predict_tower = [0]
self.predict_tower = predict_tower
......
......@@ -16,6 +16,7 @@ from ..callbacks.monitor import Monitors, TrainingMonitor
from ..tfutils.model_utils import describe_trainable_vars
from ..tfutils.sessinit import JustCurrentSession
from ..tfutils.sesscreate import ReuseSessionCreator
from ..tfutils.tower import TowerFuncWrapper
from ..callbacks.steps import MaintainStepCounter
from ..train.base import StopTraining, TrainLoop
......@@ -240,9 +241,13 @@ class SingleCostTrainer(Trainer):
These callbacks will be automatically added when you call `train()`.
So you can usually ignore the return value.
"""
get_cost_fn = TowerFuncWrapper(get_cost_fn, inputs_desc)
input_callbacks = self._setup_input(inputs_desc, input)
train_callbacks = self._setup_graph(input, get_cost_fn, get_opt_fn)
self._internal_callbacks = input_callbacks + train_callbacks
self.inputs_desc = inputs_desc
self.get_cost_fn = get_cost_fn
return self._internal_callbacks
@abstractmethod
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment