Commit 8f8fe80d authored by Yuxin Wu

FeedfreePredictor and example on ImageNet eval (fix #772)

parent 7f505225
@@ -4,15 +4,17 @@
 import cv2
 import numpy as np
+import tqdm
 import multiprocessing
 import tensorflow as tf
 from abc import abstractmethod

-from tensorpack import imgaug, dataset, ModelDesc
+from tensorpack import ModelDesc
+from tensorpack.input_source import QueueInput, StagingInput
 from tensorpack.dataflow import (
-    AugmentImageComponent, PrefetchDataZMQ,
+    imgaug, dataset, AugmentImageComponent, PrefetchDataZMQ,
     BatchData, MultiThreadMapData)
-from tensorpack.predict import PredictConfig, SimpleDatasetPredictor
+from tensorpack.predict import PredictConfig, FeedfreePredictor
 from tensorpack.utils.stats import RatioCounter
 from tensorpack.models import regularize_cost
 from tensorpack.tfutils.summary import add_moving_summary
@@ -126,12 +128,17 @@ def eval_on_ILSVRC12(model, sessinit, dataflow):
         input_names=['input', 'label'],
         output_names=['wrong-top1', 'wrong-top5']
     )
-    pred = SimpleDatasetPredictor(pred_config, dataflow)
     acc1, acc5 = RatioCounter(), RatioCounter()
-    for top1, top5 in pred.get_result():
+    # This does not have a visible improvement over naive predictor,
+    # but will have an improvement if image_dtype is set to float32.
+    pred = FeedfreePredictor(pred_config, StagingInput(QueueInput(dataflow), device='/gpu:0'))
+    for _ in tqdm.trange(dataflow.size()):
+        top1, top5 = pred()
         batch_size = top1.shape[0]
         acc1.feed(top1.sum(), batch_size)
         acc5.feed(top5.sum(), batch_size)
     print("Top1 Error: {}".format(acc1.ratio))
     print("Top5 Error: {}".format(acc5.ratio))
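Untangled from the diff for readability, the old and new evaluation loops compare as follows. This is a sketch only; both halves are taken from the hunk above, with `pred_config` and `dataflow` as defined there.

# Before: feed-based evaluation, one feed_dict per batch.
pred = SimpleDatasetPredictor(pred_config, dataflow)
for top1, top5 in pred.get_result():
    ...

# After: feedfree evaluation; batches flow through a queue (plus an optional
# GPU StagingArea) and the predictor is simply called once per batch.
pred = FeedfreePredictor(pred_config, StagingInput(QueueInput(dataflow), device='/gpu:0'))
for _ in tqdm.trange(dataflow.size()):
    top1, top5 = pred()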
@@ -547,10 +547,10 @@ class StagingInput(FeedfreeInput):
             self.fetches = tf.train.SessionRunArgs(
                 fetches=[self.stage_op, unstage_op])

-        def _prefill(self):
+        def _prefill(self, sess):
             logger.info("Pre-filling StagingArea ...")
             for k in range(self.nr_stage):
-                self.stage_op.run()
+                self.stage_op.run(session=sess)
             logger.info("{} element{} put into StagingArea on each tower.".format(
                 self.nr_stage, "s were" if self.nr_stage > 1 else " was"))

@@ -559,7 +559,7 @@ class StagingInput(FeedfreeInput):
             # doing it in `before_train` may not work because QueueInput happens in before_train.
             if not self._initialized:
                 self._initialized = True
-                self._prefill(ctx.session)
+                self._prefill(ctx.session)
             # Only step the stagingarea when the input is evaluated in this sess.run
             fetches = ctx.original_args.fetches
             if dependency_of_fetches(fetches, self._check_dependency_op):
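The reason for threading the session through explicitly: `tf.Operation.run()` resolves the default session, which is only installed inside a `with sess.as_default():` block, and a predictor's hooks may run outside any such block. A minimal sketch of the difference; the `fake_stage` op below is a hypothetical stand-in, not tensorpack code.

import tensorflow as tf

stage_op = tf.no_op(name='fake_stage')  # hypothetical stand-in for the real stage op
sess = tf.Session()

# Works anywhere, as long as a reference to the session is available:
stage_op.run(session=sess)

# Relies on a default session being installed; raises otherwise:
with sess.as_default():
    stage_op.run()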
@@ -118,6 +118,13 @@ class InputSource(object):
         All callbacks will be automatically marked as `chief_only=False`,
         so they will run on all nodes.

+        Callbacks returned by an :class:`InputSource` only support a subset of a callback's functionality:
+
+        1. They cannot access the trainer, because an :class:`InputSource` can be used in pure inference.
+        2. They cannot use the following methods: `trigger_{step,epoch}`, `{before,after}_epoch`.
+
+        In other words, these callbacks should only have the basic functionality of `tf.train.SessionRunHook`.
+
         Returns:
             list[Callback]: extra callbacks needed by this InputSource.
         """
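A condensed view of how a consumer treats such restricted callbacks purely as session hooks, following the pattern used by the new `FeedfreePredictor` below; `input_source` and `inputs_desc` are placeholders for an InputSource and its matching list of InputDesc.

from tensorpack.callbacks import Callbacks

callbacks = Callbacks(input_source.setup(inputs_desc))
callbacks.setup_graph(None)      # no trainer available -- pure inference
hooks = callbacks.get_hooks()    # plain tf.train.SessionRunHook objects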
#!/usr/bin/env python

from tensorflow.python.training.monitored_session \
    import _HookedSession as HookedSession

from .base import PredictorBase
from ..tfutils.tower import PredictTowerContext
from ..callbacks import Callbacks

__all__ = ['FeedfreePredictor']


class FeedfreePredictor(PredictorBase):
    """
    Create a predictor that takes inputs from an :class:`InputSource`, instead of from feeds.
    An instance `pred` of :class:`FeedfreePredictor` can only be called as `pred()`, which returns
    a list of output values as defined in config.output_names.
    """

    def __init__(self, config, input_source):
        """
        Args:
            config (PredictConfig): the config to use.
            input_source (InputSource): the feedfree InputSource to use.
                Must match the inputs_desc in config.
        """
        self._config = config
        self._input_source = input_source
        assert config.return_input is False, \
            "return_input is not supported in FeedfreePredictor! " \
            "If you need to fetch inputs, add the names to the output_names!"

        self._hooks = []
        self.graph = config._maybe_create_graph()
        with self.graph.as_default():
            self._input_callbacks = Callbacks(
                self._input_source.setup(config.inputs_desc))
            with PredictTowerContext(''):
                self._input_tensors = self._input_source.get_input_tensors()
                config.tower_func(*self._input_tensors)
                self._tower_handle = config.tower_func.towers[-1]

            self._output_tensors = self._tower_handle.get_tensors(config.output_names)

            self._input_callbacks.setup_graph(None)

            for h in self._input_callbacks.get_hooks():
                self._register_hook(h)
            self._initialize_session()

    def _register_hook(self, hook):
        """
        Args:
            hook (tf.train.SessionRunHook):
        """
        self._hooks.append(hook)

    def _initialize_session(self):
        # init the session
        self._config.session_init._setup_graph()
        self._sess = self._config.session_creator.create_session()
        self._config.session_init._run_init(self._sess)

        with self._sess.as_default():
            self._input_callbacks.before_train()
            self._hooked_sess = HookedSession(self._sess, self._hooks)

    def __call__(self):
        return self._hooked_sess.run(self._output_tensors)

    def _do_call(self):
        raise NotImplementedError("You're calling the wrong function!")
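For reference, a usage sketch of the new class that mirrors the `eval_on_ILSVRC12()` change above; `model`, `sessinit` and `dataflow` are placeholders for a ModelDesc, a SessionInit and a DataFlow yielding `['input', 'label']` batches.

from tensorpack.input_source import QueueInput
from tensorpack.predict import PredictConfig, FeedfreePredictor

pred_config = PredictConfig(
    model=model,
    session_init=sessinit,
    input_names=['input', 'label'],
    output_names=['wrong-top1', 'wrong-top5'])
# Wrap the DataFlow into a feedfree InputSource instead of feeding it by hand.
pred = FeedfreePredictor(pred_config, QueueInput(dataflow))
wrong1, wrong5 = pred()   # each call returns the outputs for one batch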
@@ -51,7 +51,8 @@ def dependency_of_fetches(fetches, op):
     """
     try:
         from tensorflow.python.client.session import _FetchHandler as FetchHandler
-        handler = FetchHandler(tf.get_default_graph(), fetches, {})
+        # use the graph of the op, so that this function can be called without being under a default graph
+        handler = FetchHandler(op.graph, fetches, {})
         targets = tuple(handler.fetches() + handler.targets())
     except ImportError:
         if isinstance(fetches, list):
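A hedged illustration of what the change buys: because the handler is now built from `op.graph`, the check can run without any default graph installed. `dependency_of_fetches` refers to the function in the hunk above; the graph and op names below are made up for the sketch.

import tensorflow as tf

g = tf.Graph()
with g.as_default():
    x = tf.constant(1.0, name='x')
    probe = tf.no_op(name='probe')
    y = tf.identity(x, name='y')

# No `with g.as_default():` here -- previously the handler would have been
# built from the wrong (default) graph; now it inspects g, where y and probe live.
print(dependency_of_fetches([y], probe))   # False: y does not depend on probe
print(dependency_of_fetches([y], x.op))    # True: y depends on x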
@@ -22,7 +22,7 @@ class NewSessionCreator(tf.train.ChiefSessionCreator):
        """
        Args:
            target, graph, config: same as :meth:`Session.__init__()`.
-            config: defaults to :func:`tfutils.get_default_sess_config()`
+            config: a :class:`tf.ConfigProto` instance, defaults to :func:`tfutils.get_default_sess_config()`
        """
        assert graph is None
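A small usage sketch of the parameter as now documented, assuming `NewSessionCreator` is imported from tensorpack.tfutils.sesscreate.

import tensorflow as tf
from tensorpack.tfutils.sesscreate import NewSessionCreator

# Pass a tf.ConfigProto explicitly instead of relying on the default config.
creator = NewSessionCreator(config=tf.ConfigProto(allow_soft_placement=True))
sess = creator.create_session()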
@@ -377,6 +377,9 @@ class TowerTensorHandle(object):
         1. The name of the tensor without any tower prefix.
         2. The name of an :class:`InputDesc`, if it is used when building the tower.
+
+        In the second case, this method will return the tensor that's used as the corresponding
+        input to the tower. Note that this tensor may have a different name (e.g. may be an output of a queue).
         """
         name = get_op_tensor_name(name)[1]
         if len(self.ns_name):

@@ -392,10 +395,12 @@ class TowerTensorHandle(object):
                raise
        else:
            if name in self._extra_tensor_names:
-                logger.warn(
-                    "'{}' may refer to both the tensor '{}' or the input '{}'.".format(
-                        name, ret.name, self._extra_tensor_names[name].name) +
-                    "Assuming it is the tensor '{}'.".format(ret.name))
+                mapped_tensor = self._extra_tensor_names[name]
+                logger.info(
+                    "'{}' may refer to both the Tensor/Placeholder '{}' or the input to the tower '{}'.".format(
+                        name, ret.name, mapped_tensor.name) +
+                    " Assuming it is the input '{}'.".format(mapped_tensor.name))
+                return mapped_tensor
            return ret

    def get_tensors(self, names):
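A hedged sketch of the new resolution rule, assuming a tower built with a queue-backed input named 'input'; the tensor names in the comments are illustrative only.

# handle is a TowerTensorHandle, e.g. obtained as in FeedfreePredictor above:
handle = config.tower_func.towers[-1]

# 'input' is both a Tensor/Placeholder name in the graph and the name of an
# InputDesc that was fed from a queue. get_tensor() now returns the tensor
# actually used as the tower's input (e.g. a dequeue output), and logs at
# INFO level instead of warning.
img = handle.get_tensor('input')
outputs = handle.get_tensors(['wrong-top1', 'wrong-top5'])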