Commit 4d2a7b4c authored by Yuxin Wu

Simplify the connection between ModelDesc and InputSource

parent e46e6bca
...@@ -8,7 +8,8 @@ import numpy as np ...@@ -8,7 +8,8 @@ import numpy as np
import time import time
from tensorpack import (FeedfreeTrainerBase, QueueInput, from tensorpack import (FeedfreeTrainerBase, QueueInput,
ModelDesc, DataFlow, StagingInputWrapper, ModelDesc, DataFlow, StagingInputWrapper,
MultiGPUTrainerBase, LeastLoadedDeviceSetter) MultiGPUTrainerBase, LeastLoadedDeviceSetter,
TowerContext)
from tensorpack.tfutils.summary import add_moving_summary from tensorpack.tfutils.summary import add_moving_summary
...@@ -65,7 +66,8 @@ class GANTrainer(FeedfreeTrainerBase): ...@@ -65,7 +66,8 @@ class GANTrainer(FeedfreeTrainerBase):
def _setup(self): def _setup(self):
super(GANTrainer, self)._setup() super(GANTrainer, self)._setup()
self.build_train_tower() with TowerContext('', is_training=True):
self.model.build_graph(self._input_source)
opt = self.model.get_optimizer() opt = self.model.get_optimizer()
# by default, run one d_min after one g_min # by default, run one d_min after one g_min
...@@ -91,7 +93,8 @@ class SeparateGANTrainer(FeedfreeTrainerBase): ...@@ -91,7 +93,8 @@ class SeparateGANTrainer(FeedfreeTrainerBase):
def _setup(self): def _setup(self):
super(SeparateGANTrainer, self)._setup() super(SeparateGANTrainer, self)._setup()
self.build_train_tower() with TowerContext('', is_training=True):
self.model.build_graph(self._input_source)
opt = self.model.get_optimizer() opt = self.model.get_optimizer()
self.d_min = opt.minimize( self.d_min = opt.minimize(
...@@ -123,8 +126,9 @@ class MultiGPUGANTrainer(MultiGPUTrainerBase, FeedfreeTrainerBase): ...@@ -123,8 +126,9 @@ class MultiGPUGANTrainer(MultiGPUTrainerBase, FeedfreeTrainerBase):
super(MultiGPUGANTrainer, self)._setup() super(MultiGPUGANTrainer, self)._setup()
devices = [LeastLoadedDeviceSetter(d, self._raw_devices) for d in self._raw_devices] devices = [LeastLoadedDeviceSetter(d, self._raw_devices) for d in self._raw_devices]
# NOTE trainer internal APIs subject to change in the future
def get_cost(): def get_cost():
self.build_train_tower() self.model.build_graph(self._input_source)
return [self.model.d_loss, self.model.g_loss] return [self.model.d_loss, self.model.g_loss]
cost_list = MultiGPUTrainerBase.build_on_multi_tower( cost_list = MultiGPUTrainerBase.build_on_multi_tower(
self.config.tower, get_cost, devices) self.config.tower, get_cost, devices)
......
...@@ -90,8 +90,7 @@ class InferenceRunnerBase(Callback): ...@@ -90,8 +90,7 @@ class InferenceRunnerBase(Callback):
self._predict_tower_id = self.trainer.config.predict_tower[0] self._predict_tower_id = self.trainer.config.predict_tower[0]
def fn(_): def fn(_):
in_tensors = self._input_source.get_input_tensors() self.trainer.model.build_graph(self._input_source)
self.trainer.model.build_graph(in_tensors)
with tf.variable_scope(self.trainer.vs_name_for_predictor, reuse=True): with tf.variable_scope(self.trainer.vs_name_for_predictor, reuse=True):
PredictorTowerBuilder(fn, self._prefix).build(self._predict_tower_id) PredictorTowerBuilder(fn, self._prefix).build(self._predict_tower_id)
...@@ -203,8 +202,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase): ...@@ -203,8 +202,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
# build graph # build graph
def build_tower(k): def build_tower(k):
# inputs (placeholders) for this tower only # inputs (placeholders) for this tower only
input_tensors = self._input_source.get_input_tensors() model.build_graph(self._input_source)
model.build_graph(input_tensors)
builder = PredictorTowerBuilder(build_tower, prefix=self._prefix) builder = PredictorTowerBuilder(build_tower, prefix=self._prefix)
with tf.variable_scope(tf.get_variable_scope(), reuse=True): with tf.variable_scope(tf.get_variable_scope(), reuse=True):
......
...@@ -9,6 +9,8 @@ import tensorflow as tf ...@@ -9,6 +9,8 @@ import tensorflow as tf
import six import six
from ..utils.argtools import memoized from ..utils.argtools import memoized
# TODO sort out import issues
# from ..train.input_source import InputSource
from .regularize import regularize_cost_from_collection from .regularize import regularize_cost_from_collection
__all__ = ['InputDesc', 'ModelDesc'] __all__ = ['InputDesc', 'ModelDesc']
...@@ -130,15 +132,21 @@ class ModelDesc(object): ...@@ -130,15 +132,21 @@ class ModelDesc(object):
:returns: a list of InputDesc :returns: a list of InputDesc
""" """
def build_graph(self, model_inputs): # TODO only use InputSource in the future? Now mainly used in predict/
def build_graph(self, inputs):
""" """
Build the whole symbolic graph. Build the whole symbolic graph.
Args: Args:
model_inputs (list[tf.Tensor]): a list of inputs, corresponding to inputs (list[tf.Tensor] or InputSource): a list of tensors, or an :class:`InputSource`,
InputDesc of this model. that match the list of :class:`InputDesc` defined by ``_get_inputs``.
""" """
self._build_graph(model_inputs) if not isinstance(inputs, (list, tuple)):
inputs = inputs.get_input_tensors()
assert len(inputs) == len(self.get_inputs_desc()), \
"Number of inputs passed to the graph != number of inputs defined " \
"in ModelDesc! ({} != {})".format(len(inputs), len(self.get_inputs_desc()))
self._build_graph(inputs)
@abstractmethod @abstractmethod
def _build_graph(self, inputs): def _build_graph(self, inputs):
......
...@@ -108,6 +108,15 @@ class Trainer(object): ...@@ -108,6 +108,15 @@ class Trainer(object):
""" Abstract method: run one iteration. Subclass should define what is "iteration". """ Abstract method: run one iteration. Subclass should define what is "iteration".
""" """
def _setup_input_source(self, input_source):
"""
Setup InputSource on this trainer.
"""
input_source.setup(self.model.get_inputs_desc())
cbs = input_source.get_callbacks()
for cb in cbs:
self.register_callback(cb)
def setup(self): def setup(self):
""" """
Setup the trainer and be ready for the main loop. Setup the trainer and be ready for the main loop.
......
...@@ -20,27 +20,10 @@ class FeedfreeTrainerBase(Trainer): ...@@ -20,27 +20,10 @@ class FeedfreeTrainerBase(Trainer):
""" A base trainer which runs iteration without feed_dict (therefore faster) """ A base trainer which runs iteration without feed_dict (therefore faster)
Expect ``self.data`` to be a :class:`FeedfreeInput`. Expect ``self.data`` to be a :class:`FeedfreeInput`.
""" """
def build_train_tower(self):
"""
Get input tensors from `self.input_source` and build the forward graph.
"""
def f():
self._input_tensors = self._input_source.get_input_tensors()
self.model.build_graph(self._input_tensors)
ctx = get_current_tower_context()
if ctx is None: # call without a context, use a default one
with TowerContext('', is_training=True):
f()
else:
assert ctx.is_training, ctx
f()
def _setup(self): def _setup(self):
assert isinstance(self._input_source, FeedfreeInput), type(self._input_source) assert isinstance(self._input_source, FeedfreeInput), type(self._input_source)
self._input_source.setup(self.model.get_inputs_desc()) self._setup_input_source(self._input_source)
input_callbacks = self._input_source.get_callbacks()
for cb in input_callbacks:
self.register_callback(cb)
def run_step(self): def run_step(self):
""" Simply run ``self.train_op``.""" """ Simply run ``self.train_op``."""
...@@ -51,10 +34,14 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainerBase): ...@@ -51,10 +34,14 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
""" A feedfree Trainer which assumes a single cost. """ """ A feedfree Trainer which assumes a single cost. """
def _get_cost_and_grad(self): def _get_cost_and_grad(self):
""" get the cost and gradient""" """ get the cost and gradient"""
self.build_train_tower() ctx = get_current_tower_context()
assert ctx.is_training, ctx
self.model.build_graph(self._input_source)
cost = self.model.get_cost() # assume single cost cost = self.model.get_cost() # assume single cost
# produce gradients
varlist = tf.trainable_variables() varlist = tf.trainable_variables()
ctx = get_current_tower_context()
if ctx is not None and ctx.has_own_variables and ctx.vs_name: if ctx is not None and ctx.has_own_variables and ctx.vs_name:
# only optimize w.r.t vars in this tower # only optimize w.r.t vars in this tower
# TODO use ctx.vars? # TODO use ctx.vars?
...@@ -93,8 +80,6 @@ class SimpleFeedfreeTrainer(SingleCostFeedfreeTrainer): ...@@ -93,8 +80,6 @@ class SimpleFeedfreeTrainer(SingleCostFeedfreeTrainer):
cost, grads = self._get_cost_and_grad() cost, grads = self._get_cost_and_grad()
opt = self.model.get_optimizer() opt = self.model.get_optimizer()
self.train_op = opt.apply_gradients(grads, name='min_op') self.train_op = opt.apply_gradients(grads, name='min_op')
# skip training
# self.train_op = tf.group(*self._input_tensors)
def QueueInputTrainer(config, input_queue=None): def QueueInputTrainer(config, input_queue=None):
......
...@@ -34,15 +34,11 @@ class SimpleTrainer(Trainer): ...@@ -34,15 +34,11 @@ class SimpleTrainer(Trainer):
self.hooked_sess.run(self.train_op) self.hooked_sess.run(self.train_op)
def _setup(self): def _setup(self):
model = self.model self._setup_input_source(self._input_source)
self._input_source.setup(model.get_inputs_desc())
cbs = self._input_source.get_callbacks()
for cb in cbs:
self.register_callback(cb)
self.inputs = self._input_source.get_input_tensors()
with TowerContext('', is_training=True): with TowerContext('', is_training=True):
model.build_graph(self.inputs) self.model.build_graph(self._input_source)
cost_var = model.get_cost() cost_var = self.model.get_cost()
opt = self.model.get_optimizer() opt = self.model.get_optimizer()
self.train_op = opt.minimize(cost_var, name='min_op') self.train_op = opt.minimize(cost_var, name='min_op')
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment