Commit 4d2a7b4c authored by Yuxin Wu

Simplify the connection between ModelDesc and InputSource

parent e46e6bca
......@@ -8,7 +8,8 @@ import numpy as np
import time
from tensorpack import (FeedfreeTrainerBase, QueueInput,
ModelDesc, DataFlow, StagingInputWrapper,
MultiGPUTrainerBase, LeastLoadedDeviceSetter)
MultiGPUTrainerBase, LeastLoadedDeviceSetter,
TowerContext)
from tensorpack.tfutils.summary import add_moving_summary
......@@ -65,7 +66,8 @@ class GANTrainer(FeedfreeTrainerBase):
def _setup(self):
super(GANTrainer, self)._setup()
self.build_train_tower()
with TowerContext('', is_training=True):
self.model.build_graph(self._input_source)
opt = self.model.get_optimizer()
# by default, run one d_min after one g_min
......@@ -91,7 +93,8 @@ class SeparateGANTrainer(FeedfreeTrainerBase):
def _setup(self):
super(SeparateGANTrainer, self)._setup()
self.build_train_tower()
with TowerContext('', is_training=True):
self.model.build_graph(self._input_source)
opt = self.model.get_optimizer()
self.d_min = opt.minimize(
......@@ -123,8 +126,9 @@ class MultiGPUGANTrainer(MultiGPUTrainerBase, FeedfreeTrainerBase):
super(MultiGPUGANTrainer, self)._setup()
devices = [LeastLoadedDeviceSetter(d, self._raw_devices) for d in self._raw_devices]
# NOTE trainer internal APIs subject to change in the future
def get_cost():
self.build_train_tower()
self.model.build_graph(self._input_source)
return [self.model.d_loss, self.model.g_loss]
cost_list = MultiGPUTrainerBase.build_on_multi_tower(
self.config.tower, get_cost, devices)
......
......@@ -90,8 +90,7 @@ class InferenceRunnerBase(Callback):
self._predict_tower_id = self.trainer.config.predict_tower[0]
def fn(_):
in_tensors = self._input_source.get_input_tensors()
self.trainer.model.build_graph(in_tensors)
self.trainer.model.build_graph(self._input_source)
with tf.variable_scope(self.trainer.vs_name_for_predictor, reuse=True):
PredictorTowerBuilder(fn, self._prefix).build(self._predict_tower_id)
......@@ -203,8 +202,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
# build graph
def build_tower(k):
# inputs (placeholders) for this tower only
input_tensors = self._input_source.get_input_tensors()
model.build_graph(input_tensors)
model.build_graph(self._input_source)
builder = PredictorTowerBuilder(build_tower, prefix=self._prefix)
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
......
......@@ -9,6 +9,8 @@ import tensorflow as tf
import six
from ..utils.argtools import memoized
# TODO sort out import issues
# from ..train.input_source import InputSource
from .regularize import regularize_cost_from_collection
__all__ = ['InputDesc', 'ModelDesc']
......@@ -130,15 +132,21 @@ class ModelDesc(object):
:returns: a list of InputDesc
"""
def build_graph(self, model_inputs):
# TODO only use InputSource in the future? Now mainly used in predict/
def build_graph(self, inputs):
"""
Build the whole symbolic graph.
Args:
model_inputs (list[tf.Tensor]): a list of inputs, corresponding to
InputDesc of this model.
"""
self._build_graph(model_inputs)
inputs (list[tf.Tensor] or InputSource): a list of tensors, or an :class:`InputSource`,
that match the list of :class:`InputDesc` defined by ``_get_inputs``.
"""
if not isinstance(inputs, (list, tuple)):
inputs = inputs.get_input_tensors()
assert len(inputs) == len(self.get_inputs_desc()), \
"Number of inputs passed to the graph != number of inputs defined " \
"in ModelDesc! ({} != {})".format(len(inputs), len(self.get_inputs_desc()))
self._build_graph(inputs)
@abstractmethod
def _build_graph(self, inputs):
......
......@@ -108,6 +108,15 @@ class Trainer(object):
""" Abstract method: run one iteration. Subclass should define what is "iteration".
"""
def _setup_input_source(self, input_source):
"""
Set up an InputSource on this trainer.

Calls ``input_source.setup()`` with the model's input descriptions
(``self.model.get_inputs_desc()``), then registers every callback
returned by ``input_source.get_callbacks()`` on this trainer.

Args:
input_source: the InputSource to set up.
"""
input_source.setup(self.model.get_inputs_desc())
cbs = input_source.get_callbacks()
for cb in cbs:
# every callback the input source returns is registered on the trainer
self.register_callback(cb)
def setup(self):
"""
Setup the trainer and be ready for the main loop.
......
......@@ -20,27 +20,10 @@ class FeedfreeTrainerBase(Trainer):
""" A base trainer which runs iteration without feed_dict (therefore faster)
Expect ``self._input_source`` to be a :class:`FeedfreeInput`.
"""
def build_train_tower(self):
"""
Get input tensors from ``self._input_source`` and build the forward graph.

The tensors are stored in ``self._input_tensors`` and passed to
``self.model.build_graph``. If no tower context is active, the graph is
built under a default ``TowerContext('', is_training=True)``; otherwise
the active context must be a training context.
"""
def f():
# fetch the input tensors and build the model graph on them
self._input_tensors = self._input_source.get_input_tensors()
self.model.build_graph(self._input_tensors)
ctx = get_current_tower_context()
if ctx is None: # call without a context, use a default one
with TowerContext('', is_training=True):
f()
else:
# an already-active context must be a training one
assert ctx.is_training, ctx
f()
def _setup(self):
assert isinstance(self._input_source, FeedfreeInput), type(self._input_source)
self._input_source.setup(self.model.get_inputs_desc())
input_callbacks = self._input_source.get_callbacks()
for cb in input_callbacks:
self.register_callback(cb)
self._setup_input_source(self._input_source)
def run_step(self):
""" Simply run ``self.train_op``."""
......@@ -51,10 +34,14 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
""" A feedfree Trainer which assumes a single cost. """
def _get_cost_and_grad(self):
""" get the cost and gradient"""
self.build_train_tower()
ctx = get_current_tower_context()
assert ctx.is_training, ctx
self.model.build_graph(self._input_source)
cost = self.model.get_cost() # assume single cost
# produce gradients
varlist = tf.trainable_variables()
ctx = get_current_tower_context()
if ctx is not None and ctx.has_own_variables and ctx.vs_name:
# only optimize w.r.t vars in this tower
# TODO use ctx.vars?
......@@ -93,8 +80,6 @@ class SimpleFeedfreeTrainer(SingleCostFeedfreeTrainer):
cost, grads = self._get_cost_and_grad()
opt = self.model.get_optimizer()
self.train_op = opt.apply_gradients(grads, name='min_op')
# skip training
# self.train_op = tf.group(*self._input_tensors)
def QueueInputTrainer(config, input_queue=None):
......
......@@ -34,15 +34,11 @@ class SimpleTrainer(Trainer):
self.hooked_sess.run(self.train_op)
def _setup(self):
model = self.model
self._input_source.setup(model.get_inputs_desc())
cbs = self._input_source.get_callbacks()
for cb in cbs:
self.register_callback(cb)
self.inputs = self._input_source.get_input_tensors()
self._setup_input_source(self._input_source)
with TowerContext('', is_training=True):
model.build_graph(self.inputs)
cost_var = model.get_cost()
self.model.build_graph(self._input_source)
cost_var = self.model.get_cost()
opt = self.model.get_optimizer()
self.train_op = opt.minimize(cost_var, name='min_op')
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment