Commit fdab3db2 authored by Yuxin Wu's avatar Yuxin Wu

deprecation in inferencer

parent f0273bee
......@@ -51,7 +51,6 @@ The components are designed to be independent. You can use only Model or DataFlo
pip install --user -r requirements.txt
pip install --user -r opt-requirements.txt (some optional dependencies, you can install later if needed)
```
+ [tcmalloc](http://goog-perftools.sourceforge.net/doc/tcmalloc.html) usually helps.
+ Enable `import tensorpack`:
```
export PYTHONPATH=$PYTHONPATH:`readlink -f path/to/tensorpack`
......
......@@ -6,6 +6,7 @@ import tensorflow as tf
import numpy as np
from abc import ABCMeta, abstractmethod
from collections import namedtuple
import sys
import six
from six.moves import zip, map
......@@ -31,20 +32,21 @@ class Inferencer(object):
def _before_inference(self):
pass
def datapoint(self, _, output):
def datapoint(self, output):
"""
Called after complete running every data point
"""
self._datapoint(_, output)
self._datapoint(output)
@abstractmethod
def _datapoint(self, _, output):
def _datapoint(self, output):
pass
def after_inference(self):
"""
Called after a round of inference ends.
Returns a dict of statistics.
Returns a dict of statistics which will be logged by the InferenceRunner.
The inferencer needs to handle other kind of logging by their own.
"""
return self._after_inference()
......@@ -129,9 +131,11 @@ class InferenceRunner(Callback):
for inf, tensormap in zip(self.infs, self.inf_to_tensors):
inf_output = [(outputs if k.isOutput else dp)[k.index]
for k in tensormap]
inf.datapoint(dp, inf_output)
inf.datapoint(inf_output)
pbar.update()
self._write_summary_after_inference()
def _write_summary_after_inference(self):
for inf in self.infs:
ret = inf.after_inference()
for k, v in six.iteritems(ret):
......@@ -165,7 +169,7 @@ class ScalarStats(Inferencer):
def _before_inference(self):
self.stats = []
def _datapoint(self, _, output):
def _datapoint(self, output):
self.stats.append(output)
def _after_inference(self):
......@@ -206,13 +210,11 @@ class ClassificationError(Inferencer):
def _before_inference(self):
self.err_stat = RatioCounter()
def _datapoint(self, _, outputs):
def _datapoint(self, outputs):
vec = outputs[0]
if vec.ndim == 0:
if execute_only_once():
logger.warn("[DEPRECATED] use a 'wrong vector' for ClassificationError instead of nr_wrong")
batch_size = _[0].shape[0] # assume batched input
wrong = int(vec)
logger.error("[DEPRECATED] use a 'wrong vector' for ClassificationError instead of nr_wrong")
sys.exit(1)
else:
# TODO put shape assertion into inferencerrunner
assert vec.ndim == 1, "{} is not a vector!".format(self.wrong_var_name)
......@@ -243,7 +245,7 @@ class BinaryClassificationStats(Inferencer):
def _before_inference(self):
self.stat = BinaryStatistics()
def _datapoint(self, _, outputs):
def _datapoint(self, outputs):
pred, label = outputs
self.stat.feed(pred, label)
......
......@@ -9,10 +9,15 @@ from ..utils import get_rng
__all__ = ['DataFlow', 'ProxyDataFlow', 'RNGDataFlow']
class DataFlow(object):
""" Base class for all DataFlow """
__metaclass__ = ABCMeta
class Infinity:
pass
@abstractmethod
def get_data(self):
"""
......
......@@ -165,16 +165,18 @@ class RepeatedData(ProxyDataFlow):
:param nr: number of times to repeat ds.
If nr == -1, repeat ds infinitely many times.
"""
if nr == -1:
nr = DataFlow.Infinity
self.nr = nr
super(RepeatedData, self).__init__(ds)
def size(self):
if self.nr == -1:
if self.nr == DataFlow.Infinity:
raise RuntimeError("size() is unavailable for infinite dataflow")
return self.ds.size() * self.nr
def get_data(self):
if self.nr == -1:
if self.nr == DataFlow.Infinity:
while True:
for dp in self.ds.get_data():
yield dp
......
......@@ -50,7 +50,6 @@ class Trainer(object):
assert isinstance(config, TrainConfig), type(config)
self.config = config
self.model = config.model
self.model.get_input_vars() # ensure they are present
self.sess = tf.Session(config=self.config.session_config)
self.coord = tf.train.Coordinator()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment