Commit 9a777e98 authored by Yuxin Wu's avatar Yuxin Wu

Add NoOpTrainer

parent 7968aabe
......@@ -36,7 +36,7 @@ download the annotation files `instances_minival2014.json`,
[here](https://github.com/rbgirshick/py-faster-rcnn/blob/master/data/README.md)
to `annotations/` as well.
<sub><sup>Note that train2017==trainval35k==train2014+val2014-minival2014, and val2017==minival2014</sup></sub>
<sup>Note that train2017==trainval35k==train2014+val2014-minival2014, and val2017==minival2014</sup>
## Usage
......
......@@ -19,7 +19,16 @@ __all__ = ['ScalarStats', 'Inferencer',
@six.add_metaclass(ABCMeta)
class Inferencer(Callback):
""" Base class of Inferencer.
Inferencer is a special kind of callback that should be called by :class:`InferenceRunner`. """
Inferencer is a special kind of callback that should be called by :class:`InferenceRunner`.
It has the methods `_get_fetches` and `_on_fetches` which are like
:class:`SessionRunHooks`, except that they will be used only by :class:`InferenceRunner`.
.. document private functions
.. automethod:: _before_inference
.. automethod:: _after_inference
.. automethod:: _get_fetches
.. automethod:: _on_fetches
"""
def _before_epoch(self):
self._before_inference()
......@@ -58,6 +67,9 @@ class Inferencer(Callback):
return [get_op_tensor_name(n)[1] for n in ret]
def _get_fetches(self):
"""
To be implemented by subclasses
"""
raise NotImplementedError()
def on_fetches(self, results):
......@@ -71,6 +83,9 @@ class Inferencer(Callback):
self._on_fetches(results)
def _on_fetches(self, results):
"""
To be implemented by subclasses
"""
raise NotImplementedError()
......
......@@ -200,8 +200,11 @@ class Monitors(Callback):
If you run multiprocess training, keep in mind that
the data is perhaps only available on chief process.
Returns:
scalar
"""
return self._scalar_history.get_latest(name)
return self._scalar_history.get_latest(name)[1]
def get_history(self, name):
"""
......
......@@ -318,6 +318,8 @@ class MapDataComponent(MapData):
if r is None:
return None
dp = copy(dp) # shallow copy to avoid modifying the datapoint
if isinstance(dp, tuple):
dp = list(dp) # to be able to modify it in the next line
dp[self._index] = r
return dp
......
......@@ -17,8 +17,8 @@ __all__ = ['PredictConfig']
class PredictConfig(object):
def __init__(self,
model=None,
inputs_desc=None,
tower_func=None,
inputs_desc=None,
input_names=None,
output_names=None,
......@@ -35,8 +35,10 @@ class PredictConfig(object):
Args:
model (ModelDescBase): to be used to obtain inputs_desc and tower_func.
inputs_desc ([InputDesc]):
tower_func: a callable which takes input tensors (by positional args) and construct a tower.
or a :class:`tfutils.TowerFuncWrapper` instance, which packs both `inputs_desc` and function together.
inputs_desc ([InputDesc]): if tower_func is a plain function (instead of a TowerFuncWrapper), this describes
the list of inputs it takes.
input_names (list): a list of input tensor names. Defaults to match inputs_desc.
output_names (list): a list of names of the output tensors to predict, the
......@@ -59,6 +61,8 @@ class PredictConfig(object):
self.inputs_desc = model.get_inputs_desc()
self.tower_func = TowerFuncWrapper(model.build_graph, self.inputs_desc)
else:
if isinstance(tower_func, TowerFuncWrapper):
inputs_desc = tower_func.inputs_desc
assert inputs_desc is not None and tower_func is not None
self.inputs_desc = inputs_desc
self.tower_func = TowerFuncWrapper(tower_func, inputs_desc)
......
......@@ -25,7 +25,7 @@ from ..graph_builder.utils import override_to_local_variable
from .tower import SingleCostTrainer
__all__ = ['SimpleTrainer',
__all__ = ['NoOpTrainer', 'SimpleTrainer',
'QueueInputTrainer',
'SyncMultiGPUTrainer',
'SyncMultiGPUTrainerReplicated',
......@@ -56,6 +56,18 @@ class SimpleTrainer(SingleCostTrainer):
return []
class NoOpTrainer(SimpleTrainer):
"""
A special trainer that builds the graph (if given a tower function)
and does nothing in each step.
It is used to only run the callbacks.
Note that `steps_per_epoch` and `max_epochs` are still valid options.
"""
def run_step(self):
pass
# Only exists for type check & back-compatibility
class QueueInputTrainer(SimpleTrainer):
def _setup_graph(self, input, get_cost_fn, get_opt_fn):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment