Commit 20ee19bc authored by Yuxin Wu's avatar Yuxin Wu

move get_cost_and_grad to ModelDesc

parent 696c7830
...@@ -126,7 +126,6 @@ class MultiGPUGANTrainer(MultiGPUTrainerBase, FeedfreeTrainerBase): ...@@ -126,7 +126,6 @@ class MultiGPUGANTrainer(MultiGPUTrainerBase, FeedfreeTrainerBase):
super(MultiGPUGANTrainer, self)._setup() super(MultiGPUGANTrainer, self)._setup()
devices = [LeastLoadedDeviceSetter(d, self._raw_devices) for d in self._raw_devices] devices = [LeastLoadedDeviceSetter(d, self._raw_devices) for d in self._raw_devices]
# NOTE trainer internal APIs subject to change in the future
def get_cost(): def get_cost():
self.model.build_graph(self._input_source) self.model.build_graph(self._input_source)
return [self.model.d_loss, self.model.g_loss] return [self.model.d_loss, self.model.g_loss]
......
...@@ -9,10 +9,13 @@ import tensorflow as tf ...@@ -9,10 +9,13 @@ import tensorflow as tf
import six import six
from ..utils.argtools import memoized from ..utils.argtools import memoized
from ..tfutils.tower import get_current_tower_context
from ..tfutils.gradproc import FilterNoneGrad
from .input_source_base import InputSource from .input_source_base import InputSource
from ..models.regularize import regularize_cost_from_collection from ..models.regularize import regularize_cost_from_collection
__all__ = ['InputDesc', 'ModelDesc', 'ModelDescBase'] __all__ = ['InputDesc', 'ModelDesc']
# don't expose ModelDescBase for use right now. API wasn't final.
class InputDesc( class InputDesc(
...@@ -88,7 +91,8 @@ class InputDesc( ...@@ -88,7 +91,8 @@ class InputDesc(
@six.add_metaclass(ABCMeta) @six.add_metaclass(ABCMeta)
class ModelDescBase(object): class ModelDescBase(object):
""" Base class for a model description. """ """ Base class for a model description.
"""
# TODO remove this method? Now mainly used in predict/ # TODO remove this method? Now mainly used in predict/
@memoized @memoized
...@@ -152,7 +156,7 @@ class ModelDescBase(object): ...@@ -152,7 +156,7 @@ class ModelDescBase(object):
class ModelDesc(ModelDescBase): class ModelDesc(ModelDescBase):
""" """
A ModelDesc with single cost and single optimizers. A ModelDesc with single cost and single optimizer.
""" """
def get_cost(self): def get_cost(self):
...@@ -192,3 +196,28 @@ class ModelDesc(ModelDescBase): ...@@ -192,3 +196,28 @@ class ModelDesc(ModelDescBase):
def _get_optimizer(self): def _get_optimizer(self):
raise NotImplementedError() raise NotImplementedError()
def get_cost_and_grad(self):
    """
    Compute the single cost together with its gradients.

    Gradients are taken with ``self.get_optimizer()`` on the tensor
    returned by ``self.get_cost()``; the actual work is done in
    :meth:`_get_cost_and_grad`, which subclasses may override.

    Returns:
        cost (tf.Tensor): the cost tensor returned by ``self.get_cost()``.
        grads (list[tuple]): list of ``(grad, variable)`` pairs.
    """
    return self._get_cost_and_grad()
def _get_cost_and_grad(self):
    """
    Default implementation: compute the cost and let the optimizer
    produce the gradients. Must run inside a training tower context.
    """
    tower_ctx = get_current_tower_context()
    # Gradients only make sense during training, inside a tower.
    assert tower_ctx is not None and tower_ctx.is_training, tower_ctx
    cost = self.get_cost()  # assume single cost
    # Restrict to the trainable variables that belong to this tower's
    # variable scope, so multi-tower replication does not duplicate grads.
    var_list = tower_ctx.filter_vars_by_vs_name(tf.trainable_variables())
    optimizer = self.get_optimizer()
    # Use the optimizer's own compute_gradients so optimizer-specific
    # gradient handling is respected.
    grad_and_vars = optimizer.compute_gradients(
        cost, var_list=var_list,
        gate_gradients=False, colocate_gradients_with_ops=True)
    # Drop variables whose gradient is None (not reachable from cost).
    grad_and_vars = FilterNoneGrad().process(grad_and_vars)
    return cost, grad_and_vars
...@@ -3,12 +3,8 @@ ...@@ -3,12 +3,8 @@
# File: feedfree.py # File: feedfree.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import tensorflow as tf
from six.moves import zip
from ..utils import logger from ..utils import logger
from ..tfutils.gradproc import FilterNoneGrad from ..tfutils.tower import TowerContext
from ..tfutils.tower import TowerContext, get_current_tower_context
from ..graph_builder.input_source import QueueInput, FeedfreeInput from ..graph_builder.input_source import QueueInput, FeedfreeInput
from .base import Trainer from .base import Trainer
...@@ -42,22 +38,8 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainerBase): ...@@ -42,22 +38,8 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
""" A feedfree Trainer which assumes a single cost. """ """ A feedfree Trainer which assumes a single cost. """
def _get_cost_and_grad(self):
    """
    Build the model's graph on this trainer's input source, then delegate
    the cost/gradient computation to the model.

    Returns:
        cost (tf.Tensor): the model's single cost tensor.
        grads (list[tuple]): list of ``(grad, variable)`` pairs.
    """
    # NOTE(review): the pasted source here is diff residue fusing the old
    # inline tf.gradients implementation with the new delegating one; this
    # is the post-commit version, which moves the gradient logic into
    # ModelDesc.get_cost_and_grad().
    self.model.build_graph(self._input_source)
    return self.model.get_cost_and_grad()
class SimpleFeedfreeTrainer(SingleCostFeedfreeTrainer): class SimpleFeedfreeTrainer(SingleCostFeedfreeTrainer):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment