Commit e0b13533 authored by Yuxin Wu

[Trainerv2] Add TowerTrainer on top of singlecost

parent 0f90d4c2
@@ -20,7 +20,6 @@ from ..dataflow.base import DataFlow
 from ..input_source import (
     InputSource, FeedInput, QueueInput)
 from ..graph_builder.predictor_factory import SimplePredictBuilder
-# from ..trainv2 import SingleCostTrainer
 from .base import Callback
 from .group import Callbacks
@@ -125,7 +124,7 @@ class InferenceRunner(InferenceRunnerBase):
         return InferencerToHook(inf, fetches)

     def _setup_graph(self):
-        if hasattr(self.trainer, 'model'):
+        if self.trainer._API_VERSION == 1:
             # old Trainer API
             assert self.trainer.model is not None
             # Use predict_tower in train config. either gpuid or -1
@@ -142,16 +141,16 @@ class InferenceRunner(InferenceRunnerBase):
                 self._tower_name, device, self._input_source)
         else:
             # new Trainer API
-            # only works for singlecost trainer
-            # assert isinstance(self.trainer, SingleCostTrainer), self.trainer
+            from ..trainv2 import TowerTrainer
+            assert isinstance(self.trainer, TowerTrainer), self.trainer
             input_callbacks = self._input_source.setup(self.trainer.inputs_desc)
             with tf.variable_scope(tf.get_variable_scope(), reuse=True):
                 SimplePredictBuilder(
                     ns_name=self._tower_name,
                     vs_name='', device=0).build(  # TODO fix vs_name and maybe device
-                        self._input_source, self.trainer.get_cost_fn)
-            self._tower_handle = self.trainer.get_cost_fn.towers[-1]
+                        self._input_source, self.trainer.tower_func)
+            self._tower_handle = self.trainer.tower_func.towers[-1]
         self._hooks = [self._build_hook(inf) for inf in self.infs]
         # trigger_{step,epoch}, {before,after}_epoch is ignored.
@@ -202,7 +201,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
     def _setup_graph(self):
         self._handles = []
-        if hasattr(self.trainer, 'model'):
+        if self.trainer._API_VERSION == 1:
             # old Trainer API
             input_callbacks = self._input_source.setup(self.trainer.model.get_inputs_desc())
             # build each predict tower
@@ -222,8 +221,8 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
                 SimplePredictBuilder(
                     ns_name=tower_name,
                     vs_name='', device=t).build(  # TODO fix vs_name and maybe device
-                        self._input_source, self.trainer.get_cost_fn)
-                self._handles.append(self.trainer.get_cost_fn.towers[-1])
+                        self._input_source, self.trainer.tower_func)
+                self._handles.append(self.trainer.tower_func.towers[-1])

         # setup callbacks and hooks
         self._input_callbacks = Callbacks(input_callbacks)
......
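The two `_API_VERSION` checks above replace the old `hasattr(self.trainer, 'model')` duck-typing, so a callback can tell the two trainer generations apart explicitly. A minimal sketch of that dispatch pattern, assuming a hypothetical `MyCallback` (only `_API_VERSION`, `model.get_inputs_desc()` and `inputs_desc` come from this commit):

```python
from tensorpack.callbacks import Callback

class MyCallback(Callback):
    """Hypothetical callback illustrating the version check used above."""
    def _setup_graph(self):
        if self.trainer._API_VERSION == 1:
            # old Trainer API: inputs are described by the attached ModelDesc
            inputs_desc = self.trainer.model.get_inputs_desc()
        else:
            # new Trainer API: TowerTrainer forwards inputs_desc
            # from its TowerFuncWrapper
            inputs_desc = self.trainer.inputs_desc
```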
@@ -154,7 +154,7 @@ class TowerFuncWrapper(object):
     each time the function is called.
     """

-    def __init__(self, tower_fn, inputs_desc=None):
+    def __init__(self, tower_fn, inputs_desc):
         """
         Args:
             tower_func: a function which builds one tower in the graph.
@@ -168,7 +168,7 @@ class TowerFuncWrapper(object):
         self._towers = []

-    def __new__(cls, tower_fn, inputs_desc=None):
+    def __new__(cls, tower_fn, inputs_desc):
         # to avoid double-wrapping a function
         if isinstance(tower_fn, TowerFuncWrapper):
             return tower_fn
@@ -188,6 +188,10 @@ class TowerFuncWrapper(object):
         # TODO another wrapper around towerhandlelist
         return self._towers

+    @property
+    def inputs_desc(self):
+        return self._inputs_desc
+

 class TowerTensorHandle(object):
     """
......
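With `inputs_desc` now mandatory (the `=None` defaults are gone) and readable back through the new property, a wrapped tower function fully describes its own inputs. A rough usage sketch, assuming TF 1.x; the toy tower function and the `image` input are made up for illustration:

```python
import tensorflow as tf
from tensorpack import InputDesc
from tensorpack.tfutils.tower import TowerFuncWrapper

def my_tower_fn(image):
    # toy "model": one variable, scalar cost
    w = tf.get_variable('w', shape=[], initializer=tf.constant_initializer(1.0))
    return tf.reduce_mean(image) * w

tower_func = TowerFuncWrapper(
    my_tower_fn, inputs_desc=[InputDesc(tf.float32, [None, 28, 28], 'image')])
print(tower_func.inputs_desc[0].name)  # 'image', via the new property
# __new__ avoids double-wrapping: re-wrapping returns the same object
assert TowerFuncWrapper(tower_func, tower_func.inputs_desc) is tower_func
```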
@@ -101,6 +101,8 @@ class Trainer(object):
         monitors (Monitors): the monitors. Other callbacks can use it for logging.
     """

+    _API_VERSION = 1
+
     is_chief = True
     """
     Whether this process is the chief worker in distributed training.
......
@@ -31,6 +31,8 @@ class Trainer(object):
     """ Base class for a trainer.
     """

+    _API_VERSION = 2
+
     is_chief = True

     def __init__(self):
@@ -215,8 +217,39 @@ for name in ['global_step', 'local_step', 'steps_per_epoch',
     setattr(Trainer, name, _get_property(name))


+class TowerTrainer(Trainer):
+    """
+    Base class for trainers whose model can be built by calling a tower
+    function under a :class:`TowerContext`.
+    This is required by some features that replicate the model
+    automatically, e.g. creating a predictor.
+    """
+
+    tower_func = None
+    """
+    A :class:`TowerFuncWrapper` instance.
+    A callable which takes some input tensors and builds one replicate of the model.
+    """
+
+    @call_only_once
+    def set_tower_func(self, tower_func):
+        """
+        Args:
+            tower_func (TowerFuncWrapper)
+        """
+        self.tower_func = tower_func
+
+    @property
+    def inputs_desc(self):
+        """
+        Returns:
+            list[InputDesc]: metainfo about the inputs to the tower.
+        """
+        return self.tower_func.inputs_desc
+
+
 @six.add_metaclass(ABCMeta)
-class SingleCostTrainer(Trainer):
+class SingleCostTrainer(TowerTrainer):
     """
     Base class for single-cost trainer.
@@ -261,12 +294,11 @@ class SingleCostTrainer(Trainer):
         """
         get_cost_fn = TowerFuncWrapper(get_cost_fn, inputs_desc)
         get_opt_fn = memoized(get_opt_fn)
+        self.set_tower_func(get_cost_fn)

         input_callbacks = self._setup_input(inputs_desc, input)
         train_callbacks = self._setup_graph(input, get_cost_fn, get_opt_fn)
         self._internal_callbacks = input_callbacks + train_callbacks
-        self.inputs_desc = inputs_desc
-        self.get_cost_fn = get_cost_fn
         return self._internal_callbacks

     @abstractmethod
......
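Taken together, the new `TowerTrainer` layer gives every tower-based trainer a uniform surface: `set_tower_func` registers the wrapped tower function once (it is guarded by `@call_only_once`), and `tower_func` / `inputs_desc` replace the ad-hoc `get_cost_fn` / `inputs_desc` attributes that `SingleCostTrainer` used to set on itself. A sketch of the contract for a custom subclass, assuming a hypothetical `MyTrainer` (only the `TowerTrainer` calls are from the commit):

```python
from tensorpack.tfutils.tower import TowerFuncWrapper
from tensorpack.trainv2 import TowerTrainer

class MyTrainer(TowerTrainer):
    """Illustrative subclass, not part of the commit."""
    def __init__(self, inputs_desc, tower_fn):
        super(MyTrainer, self).__init__()
        # register the tower function once; a second call raises,
        # because set_tower_func is decorated with @call_only_once
        self.set_tower_func(TowerFuncWrapper(tower_fn, inputs_desc))

# any code holding a TowerTrainer can then rely on:
#   trainer.inputs_desc        # list[InputDesc] for the tower inputs
#   trainer.tower_func(...)    # build one more replicate of the model
#   trainer.tower_func.towers  # handles of all towers built so far
```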