Commit 8f8fe80d authored by Yuxin Wu

FeedfreePredictor and example on ImageNet eval (fix #772)

parent 7f505225
......@@ -4,15 +4,17 @@
import cv2
import numpy as np
import tqdm
import multiprocessing
import tensorflow as tf
from abc import abstractmethod
from tensorpack import imgaug, dataset, ModelDesc
from tensorpack import ModelDesc
from tensorpack.input_source import QueueInput, StagingInput
from tensorpack.dataflow import (
AugmentImageComponent, PrefetchDataZMQ,
imgaug, dataset, AugmentImageComponent, PrefetchDataZMQ,
BatchData, MultiThreadMapData)
from tensorpack.predict import PredictConfig, SimpleDatasetPredictor
from tensorpack.predict import PredictConfig, FeedfreePredictor
from tensorpack.utils.stats import RatioCounter
from tensorpack.models import regularize_cost
from tensorpack.tfutils.summary import add_moving_summary
......@@ -126,12 +128,17 @@ def eval_on_ILSVRC12(model, sessinit, dataflow):
input_names=['input', 'label'],
output_names=['wrong-top1', 'wrong-top5']
)
pred = SimpleDatasetPredictor(pred_config, dataflow)
acc1, acc5 = RatioCounter(), RatioCounter()
for top1, top5 in pred.get_result():
# This does not show a visible improvement over the naive predictor,
# but it will if image_dtype is set to float32.
pred = FeedfreePredictor(pred_config, StagingInput(QueueInput(dataflow), device='/gpu:0'))
for _ in tqdm.trange(dataflow.size()):
top1, top5 = pred()
batch_size = top1.shape[0]
acc1.feed(top1.sum(), batch_size)
acc5.feed(top5.sum(), batch_size)
print("Top1 Error: {}".format(acc1.ratio))
print("Top5 Error: {}".format(acc5.ratio))
......
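For context, the `dataflow` consumed by the loop above can be any DataFlow that yields (image, label) batches. Below is a minimal sketch of building such an evaluation dataflow, assuming the standard ILSVRC12 directory layout; the augmentor choices and batch size are illustrative, not part of this commit.

import cv2
from tensorpack.dataflow import imgaug, dataset, AugmentImageComponent, BatchData

def get_eval_dataflow(datadir, batch_size=128):
    # standard ImageNet eval preprocessing: resize shortest edge, then center crop
    ds = dataset.ILSVRC12(datadir, 'val', shuffle=False)
    augmentors = [
        imgaug.ResizeShortestEdge(256, cv2.INTER_CUBIC),
        imgaug.CenterCrop((224, 224)),
    ]
    ds = AugmentImageComponent(ds, augmentors)
    # keep the last partial batch so every validation image is counted
    ds = BatchData(ds, batch_size, remainder=True)
    return ds

The result would then be passed as `dataflow` to `eval_on_ILSVRC12` above.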
......@@ -547,10 +547,10 @@ class StagingInput(FeedfreeInput):
self.fetches = tf.train.SessionRunArgs(
fetches=[self.stage_op, unstage_op])
def _prefill(self):
def _prefill(self, sess):
logger.info("Pre-filling StagingArea ...")
for k in range(self.nr_stage):
self.stage_op.run()
self.stage_op.run(session=sess)
logger.info("{} element{} put into StagingArea on each tower.".format(
self.nr_stage, "s were" if self.nr_stage > 1 else " was"))
......@@ -559,7 +559,7 @@ class StagingInput(FeedfreeInput):
# doing it in `before_train` may not work because QueueInput happens in before_train.
if not self._initialized:
self._initialized = True
self._prefill()
self._prefill(ctx.session)
# Only step the stagingarea when the input is evaluated in this sess.run
fetches = ctx.original_args.fetches
if dependency_of_fetches(fetches, self._check_dependency_op):
......
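The change above threads the session into `_prefill` explicitly, because during pure inference there may be no default session registered when the hook fires. For illustration, here is the same one-time-prefill pattern written as a plain `tf.train.SessionRunHook`; the staging op is an assumed argument, not tensorpack API.

import tensorflow as tf

class PrefillStagingHook(tf.train.SessionRunHook):
    """Put a few elements into a StagingArea before the first real sess.run."""
    def __init__(self, stage_op, nr_stage=1):
        self._stage_op = stage_op      # the StagingArea put op, assumed to be given
        self._nr_stage = nr_stage
        self._prefilled = False

    def before_run(self, run_context):
        if not self._prefilled:
            self._prefilled = True
            for _ in range(self._nr_stage):
                # use the session from the run context; no default session is needed
                self._stage_op.run(session=run_context.session)
        # keep the pipeline moving by staging one more element per step
        return tf.train.SessionRunArgs(fetches=[self._stage_op])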
......@@ -118,6 +118,13 @@ class InputSource(object):
All callbacks will be automatically marked as `chief_only=False`,
so they will run on all nodes.
Callbacks returned by an :class:`InputSource` only support a subset of a callback's functionality:
1. They cannot access the trainer, because an :class:`InputSource` can be used in pure inference.
2. They cannot use the following methods: `trigger_{step,epoch}`, `{before,after}_epoch`.
In other words, these callbacks should only have the basic functionality of `tf.train.SessionRunHook`.
Returns:
list[Callback]: extra callbacks needed by this InputSource.
"""
......
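To make the restriction above concrete, a callback returned by an `InputSource` typically looks like the following sketch: it builds ops in `_setup_graph` and piggy-backs on `sess.run` via `_before_run`, without touching the trainer or any step/epoch triggers. The `enqueue_op` here is a hypothetical op, not part of this commit.

import tensorflow as tf
from tensorpack.callbacks import Callback

class EnqueueCallback(Callback):
    _chief_only = False                 # InputSource callbacks run on every node

    def __init__(self, enqueue_op):
        self._enqueue_op = enqueue_op   # hypothetical op created by the InputSource

    def _setup_graph(self):
        # may create ops here, but must not rely on self.trainer
        self._run_args = tf.train.SessionRunArgs(fetches=[self._enqueue_op])

    def _before_run(self, ctx):
        # SessionRunHook-level behavior: attach extra fetches to every sess.run
        return self._run_args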
#!/usr/bin/env python
from tensorflow.python.training.monitored_session \
import _HookedSession as HookedSession
from .base import PredictorBase
from ..tfutils.tower import PredictTowerContext
from ..callbacks import Callbacks
__all__ = ['FeedfreePredictor']
class FeedfreePredictor(PredictorBase):
"""
Create a predictor that takes inputs from an :class:`InputSource`, instead of from feeds.
An instance `pred` of :class:`FeedfreePredictor` can only be called as `pred()`, which returns
a list of output values as defined in `config.output_names`.
"""
def __init__(self, config, input_source):
"""
Args:
config (PredictConfig): the config to use.
input_source (InputSource): the feedfree InputSource to use.
Must match the inputs_desc in config.
"""
self._config = config
self._input_source = input_source
assert config.return_input is False, \
"return_input is not supported in FeedfreePredictor! " \
"If you need to fetch inputs, add the names to the output_names!"
self._hooks = []
self.graph = config._maybe_create_graph()
with self.graph.as_default():
self._input_callbacks = Callbacks(
self._input_source.setup(config.inputs_desc))
with PredictTowerContext(''):
self._input_tensors = self._input_source.get_input_tensors()
config.tower_func(*self._input_tensors)
self._tower_handle = config.tower_func.towers[-1]
self._output_tensors = self._tower_handle.get_tensors(config.output_names)
self._input_callbacks.setup_graph(None)
for h in self._input_callbacks.get_hooks():
self._register_hook(h)
self._initialize_session()
def _register_hook(self, hook):
"""
Args:
hook (tf.train.SessionRunHook):
"""
self._hooks.append(hook)
def _initialize_session(self):
# init the session
self._config.session_init._setup_graph()
self._sess = self._config.session_creator.create_session()
self._config.session_init._run_init(self._sess)
with self._sess.as_default():
self._input_callbacks.before_train()
self._hooked_sess = HookedSession(self._sess, self._hooks)
def __call__(self):
return self._hooked_sess.run(self._output_tensors)
def _do_call(self):
raise NotImplementedError("You're calling the wrong function!")
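A minimal usage sketch of the class above, wiring a DataFlow into the predictor the same way the ImageNet example in this commit does. `dataflow` and `pred_config` are assumed to exist, with `pred_config.inputs_desc` matching the dataflow.

from tensorpack.input_source import QueueInput, StagingInput
from tensorpack.predict import FeedfreePredictor

# StagingInput is optional; it prefetches batches onto the GPU.
input_source = StagingInput(QueueInput(dataflow), device='/gpu:0')
pred = FeedfreePredictor(pred_config, input_source)

for _ in range(dataflow.size()):
    outputs = pred()    # one value per name in pred_config.output_names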
......@@ -51,7 +51,8 @@ def dependency_of_fetches(fetches, op):
"""
try:
from tensorflow.python.client.session import _FetchHandler as FetchHandler
handler = FetchHandler(tf.get_default_graph(), fetches, {})
# use the graph of the op, so that this function can be called without being under a default graph
handler = FetchHandler(op.graph, fetches, {})
targets = tuple(handler.fetches() + handler.targets())
except ImportError:
if isinstance(fetches, list):
......
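A short sketch of what this helper answers, and of the effect of the fix: it can now be called without an active default graph, because the op's own graph is used. The `tensorpack.tfutils.dependency` module path is an assumption.

import tensorflow as tf
from tensorpack.tfutils.dependency import dependency_of_fetches   # path assumed

g = tf.Graph()
with g.as_default():
    a = tf.constant(1.0, name='a')
    b = tf.constant(2.0, name='b')
    c = tf.add(a, b, name='c')

# Called outside any default graph; the function now looks up op.graph itself.
assert dependency_of_fetches(c, a.op)        # c is computed from a
assert not dependency_of_fetches(a, c.op)    # a does not depend on c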
......@@ -22,7 +22,7 @@ class NewSessionCreator(tf.train.ChiefSessionCreator):
"""
Args:
target, graph, config: same as :meth:`Session.__init__()`.
config: defaults to :func:`tfutils.get_default_sess_config()`
config: a :class:`tf.ConfigProto` instance, defaults to :func:`tfutils.get_default_sess_config()`
"""
assert graph is None
......
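A quick illustration of the clarified parameter: `config` expects a full `tf.ConfigProto`. The import path of `NewSessionCreator` below is an assumption.

import tensorflow as tf
from tensorpack.tfutils.sesscreate import NewSessionCreator   # path assumed

conf = tf.ConfigProto(allow_soft_placement=True)
conf.gpu_options.allow_growth = True
sess = NewSessionCreator(config=conf).create_session()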
......@@ -377,6 +377,9 @@ class TowerTensorHandle(object):
1. The name of the tensor without any tower prefix.
2. The name of an :class:`InputDesc`, if it is used when building the tower.
In the second case, this method will return the tensor that's used as the corresponding
input to the tower. Note that this tensor may have a different name (e.g. may be an output of a queue).
"""
name = get_op_tensor_name(name)[1]
if len(self.ns_name):
......@@ -392,10 +395,12 @@ class TowerTensorHandle(object):
raise
else:
if name in self._extra_tensor_names:
logger.warn(
"'{}' may refer to both the tensor '{}' or the input '{}'.".format(
name, ret.name, self._extra_tensor_names[name].name) +
"Assuming it is the tensor '{}'.".format(ret.name))
mapped_tensor = self._extra_tensor_names[name]
logger.info(
"'{}' may refer to both the Tensor/Placeholder '{}' or the input to the tower '{}'.".format(
name, ret.name, mapped_tensor.name) +
" Assuming it is the input '{}'.".format(mapped_tensor.name))
return mapped_tensor
return ret
def get_tensors(self, names):
......
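In practice, the new resolution order matters when an input name collides with a tensor name inside the tower. A sketch, assuming `tower` is a `TowerTensorHandle` such as `config.tower_func.towers[-1]` built by `FeedfreePredictor` above; the printed name is illustrative only.

# 'input' names an InputDesc; when the tower is fed from a queue, the tensor the
# tower actually consumes is the dequeue output, and that is what gets returned,
# accompanied by a logger.info message instead of a warning.
input_tensor, = tower.get_tensors(['input'])
print(input_tensor.name)   # e.g. a dequeue output, not the 'input' placeholder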