Commit 61c113b8 authored by Yuxin Wu's avatar Yuxin Wu

Add a ModelDescBase without single-cost assumptions

parent 7240f877
...@@ -12,7 +12,7 @@ from ..utils.argtools import memoized ...@@ -12,7 +12,7 @@ from ..utils.argtools import memoized
from .input_source_base import InputSource from .input_source_base import InputSource
from ..models.regularize import regularize_cost_from_collection from ..models.regularize import regularize_cost_from_collection
__all__ = ['InputDesc', 'ModelDesc'] __all__ = ['InputDesc', 'ModelDesc', 'ModelDescBase']
class InputDesc( class InputDesc(
...@@ -87,12 +87,10 @@ class InputDesc( ...@@ -87,12 +87,10 @@ class InputDesc(
@six.add_metaclass(ABCMeta) @six.add_metaclass(ABCMeta)
class ModelDesc(object): class ModelDescBase(object):
""" Base class for a model description. """ Base class for a model description. """
"""
# inputs: # TODO remove this method? Now mainly used in predict/
# TODO remove this method?
@memoized @memoized
def get_reused_placehdrs(self): def get_reused_placehdrs(self):
""" """
...@@ -151,11 +149,18 @@ class ModelDesc(object): ...@@ -151,11 +149,18 @@ class ModelDesc(object):
def _build_graph(self, inputs): def _build_graph(self, inputs):
pass pass
class ModelDesc(ModelDescBase):
"""
A ModelDesc with single cost and single optimizers.
"""
def get_cost(self): def get_cost(self):
""" """
Return the cost tensor in the graph. Return the cost tensor in the graph.
Used by some of the tensorpack :class:`Trainer` which assumes single-cost models. Used by some of the tensorpack :class:`Trainer` which assumes single-cost models.
You can ignore this method if you use your own trainer with more than one cost. You can ignore this method (or just use :class:`ModelDescBase`)
if you use your own trainer with more than one cost.
It calls :meth:`ModelDesc._get_cost()` which by default returns It calls :meth:`ModelDesc._get_cost()` which by default returns
``self.cost``. You can override :meth:`_get_cost()` if needed. ``self.cost``. You can override :meth:`_get_cost()` if needed.
...@@ -175,7 +180,7 @@ class ModelDesc(object): ...@@ -175,7 +180,7 @@ class ModelDesc(object):
""" """
Return the optimizer used in the task. Return the optimizer used in the task.
Used by some of the tensorpack :class:`Trainer` which assume single optimizer. Used by some of the tensorpack :class:`Trainer` which assume single optimizer.
You can (and should) ignore this method if you use a custom trainer with more than one optimizers. You should use :class:`ModelDescBase` if you use a custom trainer with more than one optimizers.
Users of :class:`ModelDesc` will need to implement `_get_optimizer()`, Users of :class:`ModelDesc` will need to implement `_get_optimizer()`,
which will only be called once per each model. which will only be called once per each model.
...@@ -187,9 +192,3 @@ class ModelDesc(object): ...@@ -187,9 +192,3 @@ class ModelDesc(object):
def _get_optimizer(self): def _get_optimizer(self):
raise NotImplementedError() raise NotImplementedError()
def get_gradient_processor(self):
return self._get_gradient_processor()
def _get_gradient_processor(self):
return []
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment