Commit 20ee19bc authored by Yuxin Wu

move get_cost_and_grad to ModelDesc

parent 696c7830
@@ -126,7 +126,6 @@ class MultiGPUGANTrainer(MultiGPUTrainerBase, FeedfreeTrainerBase):
         super(MultiGPUGANTrainer, self)._setup()
         devices = [LeastLoadedDeviceSetter(d, self._raw_devices) for d in self._raw_devices]
         # NOTE trainer internal APIs subject to change in the future
         def get_cost():
             self.model.build_graph(self._input_source)
             return [self.model.d_loss, self.model.g_loss]
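The GAN trainer above wraps graph construction in a get_cost callback that returns both losses, so the multi-GPU builder can invoke it once per tower and optimize each loss against its own variable subset. Below is a standalone sketch of that two-loss pattern in plain TensorFlow 1.x; it is not tensorpack code, and the scope names ('gen', 'discrim') and toy losses are made up for illustration:

    # Two losses from one graph build; each optimizer only sees the
    # variables of its own scope.  Assumes TensorFlow 1.x.
    import tensorflow as tf

    x = tf.placeholder(tf.float32, [None, 4])
    with tf.variable_scope('gen'):
        fake = tf.layers.dense(x, 4)

    def discrim(v, reuse):
        # explicit layer name so the second call reuses the same variables
        with tf.variable_scope('discrim', reuse=reuse):
            return tf.layers.dense(v, 1, name='out')

    score_real = discrim(x, reuse=False)
    score_fake = discrim(fake, reuse=True)

    d_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.ones_like(score_real), logits=score_real)) + \
        tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
            labels=tf.zeros_like(score_fake), logits=score_fake))
    g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.ones_like(score_fake), logits=score_fake))

    g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='gen')
    d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='discrim')
    d_min = tf.train.AdamOptimizer(2e-4).minimize(d_loss, var_list=d_vars)
    g_min = tf.train.AdamOptimizer(2e-4).minimize(g_loss, var_list=g_vars)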
@@ -9,10 +9,13 @@ import tensorflow as tf
 import six

 from ..utils.argtools import memoized
+from ..tfutils.tower import get_current_tower_context
+from ..tfutils.gradproc import FilterNoneGrad
 from .input_source_base import InputSource
 from ..models.regularize import regularize_cost_from_collection

-__all__ = ['InputDesc', 'ModelDesc', 'ModelDescBase']
+__all__ = ['InputDesc', 'ModelDesc']
+# don't expose ModelDescBase for use right now. API wasn't final.


 class InputDesc(
@@ -88,7 +91,8 @@ class InputDesc(
 @six.add_metaclass(ABCMeta)
 class ModelDescBase(object):
-    """ Base class for a model description. """
+    """ Base class for a model description.
+    """

     # TODO remove this method? Now mainly used in predict/
     @memoized
@@ -152,7 +156,7 @@ class ModelDescBase(object):
 class ModelDesc(ModelDescBase):
     """
-    A ModelDesc with single cost and single optimizers.
+    A ModelDesc with single cost and single optimizer.
     """

     def get_cost(self):
@@ -192,3 +196,28 @@ class ModelDesc(ModelDescBase):
     def _get_optimizer(self):
         raise NotImplementedError()
+
+    def get_cost_and_grad(self):
+        """
+        Compute gradients with ``self.get_optimizer()`` on ``self.get_cost()``.
+
+        Returns:
+            cost (tf.Tensor): the cost tensor returned by ``self.get_cost()``.
+            grads (list[tuple]): list of (grad, variable) tuple.
+        """
+        return self._get_cost_and_grad()
+
+    def _get_cost_and_grad(self):
+        ctx = get_current_tower_context()
+        assert ctx is not None and ctx.is_training, ctx
+        cost = self.get_cost()    # assume single cost
+        # produce gradients
+        varlist = ctx.filter_vars_by_vs_name(tf.trainable_variables())
+        opt = self.get_optimizer()
+        grads = opt.compute_gradients(
+            cost, var_list=varlist,
+            gate_gradients=False, colocate_gradients_with_ops=True)
+        grads = FilterNoneGrad().process(grads)
+        return cost, grads
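The new ModelDesc._get_cost_and_grad asks the optimizer for (grad, var) pairs over the tower's trainable variables, then runs FilterNoneGrad to drop variables the cost does not depend on (their gradient comes back as None). Here is a minimal plain-TensorFlow 1.x sketch of that compute-then-filter step, with the filtering inlined instead of using tensorpack's FilterNoneGrad:

    # Assumes TensorFlow 1.x.  'unused' stands in for a variable outside
    # the cost's dependency graph, e.g. one belonging to another tower.
    import tensorflow as tf

    w = tf.get_variable('w', [3])
    unused = tf.get_variable('unused', [3])   # cost does not depend on this
    cost = tf.reduce_sum(tf.square(w))

    opt = tf.train.GradientDescentOptimizer(0.1)
    grads = opt.compute_gradients(cost, var_list=[w, unused])
    # compute_gradients returns (grad, var) pairs; grad is None for 'unused',
    # and apply_gradients would reject such a pair.
    grads = [(g, v) for g, v in grads if g is not None]
    train_op = opt.apply_gradients(grads)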
@@ -3,12 +3,8 @@
 # File: feedfree.py
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>

-import tensorflow as tf
-from six.moves import zip
-
 from ..utils import logger
-from ..tfutils.gradproc import FilterNoneGrad
-from ..tfutils.tower import TowerContext, get_current_tower_context
+from ..tfutils.tower import TowerContext
 from ..graph_builder.input_source import QueueInput, FeedfreeInput
 from .base import Trainer
@@ -42,22 +38,8 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
     """ A feedfree Trainer which assumes a single cost. """

     def _get_cost_and_grad(self):
         """ get the cost and gradient"""
-        ctx = get_current_tower_context()
-        assert ctx.is_training, ctx
-        self.model.build_graph(self._input_source)
-        cost = self.model.get_cost()    # assume single cost
-        # produce gradients
-        varlist = ctx.filter_vars_by_vs_name(tf.trainable_variables())
-        grads = tf.gradients(
-            cost,
-            varlist,
-            gate_gradients=False,
-            colocate_gradients_with_ops=True)
-        grads = list(zip(grads, varlist))
-        grads = FilterNoneGrad().process(grads)
-        return cost, grads
+        return self.model.get_cost_and_grad()


 class SimpleFeedfreeTrainer(SingleCostFeedfreeTrainer):
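The trainer's hand-rolled tf.gradients-plus-zip code is replaced by a call into the model, whose implementation uses opt.compute_gradients instead. For a plain optimizer the two paths produce the same (grad, variable) pairs; the sketch below (plain TensorFlow 1.x, toy cost, not tensorpack code) illustrates the equivalence this refactor relies on:

    # Assumes TensorFlow 1.x and an optimizer that does no gradient
    # preprocessing, so both paths yield identical pairs.
    import tensorflow as tf

    w = tf.get_variable('w', [2])
    cost = tf.reduce_sum(w * w)
    varlist = [w]
    opt = tf.train.GradientDescentOptimizer(0.1)

    # old trainer path: raw gradients, zipped with the variable list by hand
    g1 = list(zip(tf.gradients(cost, varlist), varlist))
    # new ModelDesc path: the optimizer builds the same pairs itself
    g2 = opt.compute_gradients(cost, var_list=varlist)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(sess.run([g1[0][0], g2[0][0]]))   # identical gradient values

Centralizing this in ModelDesc also means multi-GPU trainers and the GAN trainer no longer need their own copies of the gradient-filtering logic.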