Commit 61c113b8 authored by Yuxin Wu

Add a ModelDescBase without single-cost assumptions

parent 7240f877
@@ -12,7 +12,7 @@ from ..utils.argtools import memoized
 from .input_source_base import InputSource
 from ..models.regularize import regularize_cost_from_collection
 
-__all__ = ['InputDesc', 'ModelDesc']
+__all__ = ['InputDesc', 'ModelDesc', 'ModelDescBase']
 
 class InputDesc(
@@ -87,12 +87,10 @@ class InputDesc(
 @six.add_metaclass(ABCMeta)
-class ModelDesc(object):
-    """ Base class for a model description.
-    """
+class ModelDescBase(object):
+    """ Base class for a model description. """
 
     # inputs:
-    # TODO remove this method?
+    # TODO remove this method? Now mainly used in predict/
     @memoized
     def get_reused_placehdrs(self):
         """
@@ -151,11 +149,18 @@
     def _build_graph(self, inputs):
         pass
+
+
+class ModelDesc(ModelDescBase):
+    """
+    A ModelDesc with a single cost and a single optimizer.
+    """
     def get_cost(self):
         """
         Return the cost tensor in the graph.
         Used by some of the tensorpack :class:`Trainer` which assume single-cost models.
-        You can ignore this method if you use your own trainer with more than one cost.
+        You can ignore this method (or just use :class:`ModelDescBase`)
+        if you use your own trainer with more than one cost.
         It calls :meth:`ModelDesc._get_cost()` which by default returns
         ``self.cost``. You can override :meth:`_get_cost()` if needed.
@@ -175,7 +180,7 @@
         """
         Return the optimizer used in the task.
         Used by some of the tensorpack :class:`Trainer` which assume a single optimizer.
-        You can (and should) ignore this method if you use a custom trainer with more than one optimizers.
+        You should use :class:`ModelDescBase` if you use a custom trainer with more than one optimizer.
         Users of :class:`ModelDesc` will need to implement `_get_optimizer()`,
         which will only be called once per model.
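
For context beyond the diff, here is a minimal sketch of the single-cost contract that :class:`ModelDesc` now documents. The model, shapes, and learning rate are hypothetical; it assumes the TF 1.x-era tensorpack API where a subclass declares inputs in `_get_inputs()`, sets `self.cost` inside `_build_graph()`, and implements `_get_optimizer()`:

```python
# A minimal sketch, not part of the commit: a hypothetical single-cost model.
import tensorflow as tf
from tensorpack import InputDesc, ModelDesc

class LinearExample(ModelDesc):
    def _get_inputs(self):
        # Declare the placeholders this model consumes.
        return [InputDesc(tf.float32, (None, 784), 'input'),
                InputDesc(tf.int32, (None,), 'label')]

    def _build_graph(self, inputs):
        image, label = inputs
        logits = tf.layers.dense(image, 10)
        # get_cost() returns self.cost by default, so single-cost
        # trainers only need this attribute to be set here.
        self.cost = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=logits, labels=label),
            name='cost')

    def _get_optimizer(self):
        # Called once per model; single-optimizer trainers use the result.
        return tf.train.AdamOptimizer(1e-3)
```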
@@ -187,9 +192,3 @@
     def _get_optimizer(self):
         raise NotImplementedError()
-
-    def get_gradient_processor(self):
-        return self._get_gradient_processor()
-
-    def _get_gradient_processor(self):
-        return []
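
For the multi-cost case the commit message targets, a subclass can derive from :class:`ModelDescBase` directly and skip `get_cost()`/`get_optimizer()` entirely. The two-loss layout and the attribute names `g_loss`/`d_loss` below are illustrative assumptions (the import also assumes the top-level re-export implied by `__all__`); a custom trainer would fetch these attributes and drive its own optimizers:

```python
# A sketch of a model with two costs, built on the new ModelDescBase.
import tensorflow as tf
from tensorpack import InputDesc, ModelDescBase

class TwoCostExample(ModelDescBase):
    def _get_inputs(self):
        return [InputDesc(tf.float32, (None, 100), 'noise'),
                InputDesc(tf.float32, (None, 784), 'real')]

    def _build_graph(self, inputs):
        noise, real = inputs
        fake = tf.layers.dense(noise, 784, name='gen')
        score_real = tf.layers.dense(real, 1, name='disc')
        score_fake = tf.layers.dense(fake, 1, name='disc', reuse=True)
        # No single self.cost: expose each loss as an attribute for a
        # custom trainer to fetch and minimize with its own optimizer.
        self.d_loss = tf.reduce_mean(score_fake - score_real, name='d_loss')
        self.g_loss = tf.reduce_mean(-score_fake, name='g_loss')
```

Keeping the single-cost assumptions in the :class:`ModelDesc` subclass rather than the base class lets existing single-cost trainers keep their interface, while multi-cost setups opt out cleanly.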