Commit fdab3db2 authored by Yuxin Wu's avatar Yuxin Wu

deprecation in inferencer

parent f0273bee
...@@ -51,7 +51,6 @@ The components are designed to be independent. You can use only Model or DataFlo ...@@ -51,7 +51,6 @@ The components are designed to be independent. You can use only Model or DataFlo
pip install --user -r requirements.txt pip install --user -r requirements.txt
pip install --user -r opt-requirements.txt (some optional dependencies, you can install later if needed) pip install --user -r opt-requirements.txt (some optional dependencies, you can install later if needed)
``` ```
+ [tcmalloc](http://goog-perftools.sourceforge.net/doc/tcmalloc.html) usually helps.
+ Enable `import tensorpack`: + Enable `import tensorpack`:
``` ```
export PYTHONPATH=$PYTHONPATH:`readlink -f path/to/tensorpack` export PYTHONPATH=$PYTHONPATH:`readlink -f path/to/tensorpack`
......
...@@ -6,6 +6,7 @@ import tensorflow as tf ...@@ -6,6 +6,7 @@ import tensorflow as tf
import numpy as np import numpy as np
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from collections import namedtuple from collections import namedtuple
import sys
import six import six
from six.moves import zip, map from six.moves import zip, map
...@@ -31,20 +32,21 @@ class Inferencer(object): ...@@ -31,20 +32,21 @@ class Inferencer(object):
def _before_inference(self): def _before_inference(self):
pass pass
def datapoint(self, _, output): def datapoint(self, output):
""" """
Called after complete running every data point Called after complete running every data point
""" """
self._datapoint(_, output) self._datapoint(output)
@abstractmethod @abstractmethod
def _datapoint(self, _, output): def _datapoint(self, output):
pass pass
def after_inference(self): def after_inference(self):
""" """
Called after a round of inference ends. Called after a round of inference ends.
Returns a dict of statistics. Returns a dict of statistics which will be logged by the InferenceRunner.
The inferencer needs to handle other kind of logging by their own.
""" """
return self._after_inference() return self._after_inference()
...@@ -129,9 +131,11 @@ class InferenceRunner(Callback): ...@@ -129,9 +131,11 @@ class InferenceRunner(Callback):
for inf, tensormap in zip(self.infs, self.inf_to_tensors): for inf, tensormap in zip(self.infs, self.inf_to_tensors):
inf_output = [(outputs if k.isOutput else dp)[k.index] inf_output = [(outputs if k.isOutput else dp)[k.index]
for k in tensormap] for k in tensormap]
inf.datapoint(dp, inf_output) inf.datapoint(inf_output)
pbar.update() pbar.update()
self._write_summary_after_inference()
def _write_summary_after_inference(self):
for inf in self.infs: for inf in self.infs:
ret = inf.after_inference() ret = inf.after_inference()
for k, v in six.iteritems(ret): for k, v in six.iteritems(ret):
...@@ -165,7 +169,7 @@ class ScalarStats(Inferencer): ...@@ -165,7 +169,7 @@ class ScalarStats(Inferencer):
def _before_inference(self): def _before_inference(self):
self.stats = [] self.stats = []
def _datapoint(self, _, output): def _datapoint(self, output):
self.stats.append(output) self.stats.append(output)
def _after_inference(self): def _after_inference(self):
...@@ -206,13 +210,11 @@ class ClassificationError(Inferencer): ...@@ -206,13 +210,11 @@ class ClassificationError(Inferencer):
def _before_inference(self): def _before_inference(self):
self.err_stat = RatioCounter() self.err_stat = RatioCounter()
def _datapoint(self, _, outputs): def _datapoint(self, outputs):
vec = outputs[0] vec = outputs[0]
if vec.ndim == 0: if vec.ndim == 0:
if execute_only_once(): logger.error("[DEPRECATED] use a 'wrong vector' for ClassificationError instead of nr_wrong")
logger.warn("[DEPRECATED] use a 'wrong vector' for ClassificationError instead of nr_wrong") sys.exit(1)
batch_size = _[0].shape[0] # assume batched input
wrong = int(vec)
else: else:
# TODO put shape assertion into inferencerrunner # TODO put shape assertion into inferencerrunner
assert vec.ndim == 1, "{} is not a vector!".format(self.wrong_var_name) assert vec.ndim == 1, "{} is not a vector!".format(self.wrong_var_name)
...@@ -243,7 +245,7 @@ class BinaryClassificationStats(Inferencer): ...@@ -243,7 +245,7 @@ class BinaryClassificationStats(Inferencer):
def _before_inference(self): def _before_inference(self):
self.stat = BinaryStatistics() self.stat = BinaryStatistics()
def _datapoint(self, _, outputs): def _datapoint(self, outputs):
pred, label = outputs pred, label = outputs
self.stat.feed(pred, label) self.stat.feed(pred, label)
......
...@@ -9,10 +9,15 @@ from ..utils import get_rng ...@@ -9,10 +9,15 @@ from ..utils import get_rng
__all__ = ['DataFlow', 'ProxyDataFlow', 'RNGDataFlow'] __all__ = ['DataFlow', 'ProxyDataFlow', 'RNGDataFlow']
class DataFlow(object): class DataFlow(object):
""" Base class for all DataFlow """ """ Base class for all DataFlow """
__metaclass__ = ABCMeta __metaclass__ = ABCMeta
class Infinity:
pass
@abstractmethod @abstractmethod
def get_data(self): def get_data(self):
""" """
......
...@@ -165,16 +165,18 @@ class RepeatedData(ProxyDataFlow): ...@@ -165,16 +165,18 @@ class RepeatedData(ProxyDataFlow):
:param nr: number of times to repeat ds. :param nr: number of times to repeat ds.
If nr == -1, repeat ds infinitely many times. If nr == -1, repeat ds infinitely many times.
""" """
if nr == -1:
nr = DataFlow.Infinity
self.nr = nr self.nr = nr
super(RepeatedData, self).__init__(ds) super(RepeatedData, self).__init__(ds)
def size(self): def size(self):
if self.nr == -1: if self.nr == DataFlow.Infinity:
raise RuntimeError("size() is unavailable for infinite dataflow") raise RuntimeError("size() is unavailable for infinite dataflow")
return self.ds.size() * self.nr return self.ds.size() * self.nr
def get_data(self): def get_data(self):
if self.nr == -1: if self.nr == DataFlow.Infinity:
while True: while True:
for dp in self.ds.get_data(): for dp in self.ds.get_data():
yield dp yield dp
......
...@@ -50,7 +50,6 @@ class Trainer(object): ...@@ -50,7 +50,6 @@ class Trainer(object):
assert isinstance(config, TrainConfig), type(config) assert isinstance(config, TrainConfig), type(config)
self.config = config self.config = config
self.model = config.model self.model = config.model
self.model.get_input_vars() # ensure they are present
self.sess = tf.Session(config=self.config.session_config) self.sess = tf.Session(config=self.config.session_config)
self.coord = tf.train.Coordinator() self.coord = tf.train.Coordinator()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment