Commit 9a777e98 authored by Yuxin Wu's avatar Yuxin Wu

Add NoOpTrainer

parent 7968aabe
...@@ -36,7 +36,7 @@ download the annotation files `instances_minival2014.json`, ...@@ -36,7 +36,7 @@ download the annotation files `instances_minival2014.json`,
[here](https://github.com/rbgirshick/py-faster-rcnn/blob/master/data/README.md) [here](https://github.com/rbgirshick/py-faster-rcnn/blob/master/data/README.md)
to `annotations/` as well. to `annotations/` as well.
<sub><sup>Note that train2017==trainval35k==train2014+val2014-minival2014, and val2017==minival2014</sup></sub> <sup>Note that train2017==trainval35k==train2014+val2014-minival2014, and val2017==minival2014</sup>
## Usage ## Usage
......
...@@ -19,7 +19,16 @@ __all__ = ['ScalarStats', 'Inferencer', ...@@ -19,7 +19,16 @@ __all__ = ['ScalarStats', 'Inferencer',
@six.add_metaclass(ABCMeta) @six.add_metaclass(ABCMeta)
class Inferencer(Callback): class Inferencer(Callback):
""" Base class of Inferencer. """ Base class of Inferencer.
Inferencer is a special kind of callback that should be called by :class:`InferenceRunner`. """ Inferencer is a special kind of callback that should be called by :class:`InferenceRunner`.
It has the methods `_get_fetches` and `_on_fetches` which are like
:class:`SessionRunHooks`, except that they will be used only by :class:`InferenceRunner`.
.. document private functions
.. automethod:: _before_inference
.. automethod:: _after_inference
.. automethod:: _get_fetches
.. automethod:: _on_fetches
"""
def _before_epoch(self): def _before_epoch(self):
self._before_inference() self._before_inference()
...@@ -58,6 +67,9 @@ class Inferencer(Callback): ...@@ -58,6 +67,9 @@ class Inferencer(Callback):
return [get_op_tensor_name(n)[1] for n in ret] return [get_op_tensor_name(n)[1] for n in ret]
def _get_fetches(self): def _get_fetches(self):
"""
To be implemented by subclasses
"""
raise NotImplementedError() raise NotImplementedError()
def on_fetches(self, results): def on_fetches(self, results):
...@@ -71,6 +83,9 @@ class Inferencer(Callback): ...@@ -71,6 +83,9 @@ class Inferencer(Callback):
self._on_fetches(results) self._on_fetches(results)
def _on_fetches(self, results): def _on_fetches(self, results):
"""
To be implemented by subclasses
"""
raise NotImplementedError() raise NotImplementedError()
......
...@@ -200,8 +200,11 @@ class Monitors(Callback): ...@@ -200,8 +200,11 @@ class Monitors(Callback):
If you run multiprocess training, keep in mind that If you run multiprocess training, keep in mind that
the data is perhaps only available on chief process. the data is perhaps only available on chief process.
Returns:
scalar
""" """
return self._scalar_history.get_latest(name) return self._scalar_history.get_latest(name)[1]
def get_history(self, name): def get_history(self, name):
""" """
......
...@@ -318,6 +318,8 @@ class MapDataComponent(MapData): ...@@ -318,6 +318,8 @@ class MapDataComponent(MapData):
if r is None: if r is None:
return None return None
dp = copy(dp) # shallow copy to avoid modifying the datapoint dp = copy(dp) # shallow copy to avoid modifying the datapoint
if isinstance(dp, tuple):
dp = list(dp) # to be able to modify it in the next line
dp[self._index] = r dp[self._index] = r
return dp return dp
......
...@@ -17,8 +17,8 @@ __all__ = ['PredictConfig'] ...@@ -17,8 +17,8 @@ __all__ = ['PredictConfig']
class PredictConfig(object): class PredictConfig(object):
def __init__(self, def __init__(self,
model=None, model=None,
inputs_desc=None,
tower_func=None, tower_func=None,
inputs_desc=None,
input_names=None, input_names=None,
output_names=None, output_names=None,
...@@ -35,8 +35,10 @@ class PredictConfig(object): ...@@ -35,8 +35,10 @@ class PredictConfig(object):
Args: Args:
model (ModelDescBase): to be used to obtain inputs_desc and tower_func. model (ModelDescBase): to be used to obtain inputs_desc and tower_func.
inputs_desc ([InputDesc]):
tower_func: a callable which takes input tensors (by positional args) and construct a tower. tower_func: a callable which takes input tensors (by positional args) and construct a tower.
or a :class:`tfutils.TowerFuncWrapper` instance, which packs both `inputs_desc` and function together.
inputs_desc ([InputDesc]): if tower_func is a plain function (instead of a TowerFuncWrapper), this describes
the list of inputs it takes.
input_names (list): a list of input tensor names. Defaults to match inputs_desc. input_names (list): a list of input tensor names. Defaults to match inputs_desc.
output_names (list): a list of names of the output tensors to predict, the output_names (list): a list of names of the output tensors to predict, the
...@@ -59,6 +61,8 @@ class PredictConfig(object): ...@@ -59,6 +61,8 @@ class PredictConfig(object):
self.inputs_desc = model.get_inputs_desc() self.inputs_desc = model.get_inputs_desc()
self.tower_func = TowerFuncWrapper(model.build_graph, self.inputs_desc) self.tower_func = TowerFuncWrapper(model.build_graph, self.inputs_desc)
else: else:
if isinstance(tower_func, TowerFuncWrapper):
inputs_desc = tower_func.inputs_desc
assert inputs_desc is not None and tower_func is not None assert inputs_desc is not None and tower_func is not None
self.inputs_desc = inputs_desc self.inputs_desc = inputs_desc
self.tower_func = TowerFuncWrapper(tower_func, inputs_desc) self.tower_func = TowerFuncWrapper(tower_func, inputs_desc)
......
...@@ -25,7 +25,7 @@ from ..graph_builder.utils import override_to_local_variable ...@@ -25,7 +25,7 @@ from ..graph_builder.utils import override_to_local_variable
from .tower import SingleCostTrainer from .tower import SingleCostTrainer
__all__ = ['SimpleTrainer', __all__ = ['NoOpTrainer', 'SimpleTrainer',
'QueueInputTrainer', 'QueueInputTrainer',
'SyncMultiGPUTrainer', 'SyncMultiGPUTrainer',
'SyncMultiGPUTrainerReplicated', 'SyncMultiGPUTrainerReplicated',
...@@ -56,6 +56,18 @@ class SimpleTrainer(SingleCostTrainer): ...@@ -56,6 +56,18 @@ class SimpleTrainer(SingleCostTrainer):
return [] return []
class NoOpTrainer(SimpleTrainer):
    """
    A trainer whose training step does nothing.

    The graph is still built (when a tower function is given), but each
    call to :meth:`run_step` performs no computation — useful when you
    only want the registered callbacks to run.

    Note that ``steps_per_epoch`` and ``max_epochs`` remain valid options.
    """

    def run_step(self):
        # Intentionally empty: a "step" of this trainer does no work.
        pass
# Only exists for type check & back-compatibility # Only exists for type check & back-compatibility
class QueueInputTrainer(SimpleTrainer): class QueueInputTrainer(SimpleTrainer):
def _setup_graph(self, input, get_cost_fn, get_opt_fn): def _setup_graph(self, input, get_cost_fn, get_opt_fn):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment