Commit 86ec2d15 authored by Yuxin Wu

offline predictor

parent 4af48399
@@ -39,8 +39,7 @@ class ModelDesc(object):
def reuse_input_vars(self):
""" Find and return already-defined input_vars in default graph"""
input_var_names = [k.name for k in self._get_input_vars()]
g = tf.get_default_graph()
return [g.get_tensor_by_name(name + ":0") for name in input_var_names]
return get_vars_by_names(input_var_names)
def get_input_vars_desc(self):
""" return a list of `InputVar` instance"""
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: base.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from abc import abstractmethod, ABCMeta, abstractproperty
import tensorflow as tf
from ..tfutils import get_vars_by_names
class PredictorBase(object):
__metaclass__ = ABCMeta
@abstractproperty
def session(self):
""" return the session the predictor is running on"""
pass
def __call__(self, dp):
assert len(dp) == len(self.input_var_names), \
"{} != {}".format(len(dp), len(self.input_var_names))
output = self._do_call(dp)
if self.return_input:
return (dp, output)
else:
return output
@abstractmethod
def _do_call(self, dp):
"""
:param dp: input datapoint. Must have the same length as input_var_names.
:return: output as defined by the config
"""
pass
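A minimal sketch of the contract this base class imposes (the subclass below is hypothetical and not part of this commit): a concrete predictor only needs to provide the `session` property and `_do_call`; `__call__` then adds the length check and the optional (input, output) return.
class IdentityPredictor(PredictorBase):
    """ Hypothetical minimal subclass: echoes the datapoint back unchanged. """
    def __init__(self, sess, input_var_names, return_input=False):
        self._sess = sess
        self.input_var_names = input_var_names
        self.return_input = return_input
    @property
    def session(self):
        return self._sess
    def _do_call(self, dp):
        # a real predictor would call self.session.run(...) here
        return dp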
class OfflinePredictor(PredictorBase):
""" Build a predictor from a given config, in an independent graph"""
def __init__(self, config):
self.graph = tf.Graph()
with self.graph.as_default():
input_vars = config.model.get_input_vars()
config.model._build_graph(input_vars, False)
self.input_var_names = config.input_var_names
self.output_var_names = config.output_var_names
self.return_input = config.return_input
self.input_vars = get_vars_by_names(self.input_var_names)
self.output_vars = get_vars_by_names(self.output_var_names)
sess = tf.Session(config=config.session_config)
config.session_init.init(sess)
self._session = sess
@property
def session(self):
return self._session
def _do_call(self, dp):
feed = dict(zip(self.input_vars, dp))
output = self.session.run(self.output_vars, feed_dict=feed)
return output
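A minimal usage sketch of the new OfflinePredictor (MyModel, the checkpoint path, the tensor names and `img` are placeholders, not part of this commit; imports of PredictConfig / OfflinePredictor / SaverRestore from tensorpack are assumed):
# Hypothetical usage sketch; only PredictConfig and OfflinePredictor come from this commit.
config = PredictConfig(
    model=MyModel(),                                     # a ModelDesc subclass (placeholder)
    session_init=SaverRestore('train_log/model-10000'),  # restore trained weights (placeholder path)
    input_var_names=['input'],                           # tensors to feed, by name
    output_var_names=['prob'])                           # tensors to fetch, by name
predictor = OfflinePredictor(config)                     # builds its own graph and session
outputs = predictor([img])                               # one value per input_var_name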
@@ -4,11 +4,13 @@
import tensorflow as tf
from collections import namedtuple
import six
from six.moves import zip
from tensorpack.models import ModelDesc
from ..utils import logger
from ..tfutils import *
from .base import OfflinePredictor
import multiprocessing
@@ -29,20 +31,20 @@ class PredictConfig(object):
:param output_var_names: a list of names of the output tensors to predict; these
can be any computable tensors in the graph.
Predicting a specific output might not require all input variables.
:param return_input: whether to produce (input, output) pair or just output. default to False.
It's only effective for `DatasetPredictorBase`.
:param return_input: whether to return the (input, output) pair or just the output. Defaults to False.
"""
def assert_type(v, tp):
assert isinstance(v, tp), v.__class__
# XXX does it work? start with minimal memory, but allow growth.
# allow_growth doesn't seem to work very well in TF.
self.session_config = kwargs.pop('session_config', get_default_sess_config(0.3))
self.session_config = kwargs.pop('session_config', get_default_sess_config(0.4))
self.session_init = kwargs.pop('session_init', JustCurrentSession())
assert_type(self.session_init, SessionInit)
self.model = kwargs.pop('model')
assert_type(self.model, ModelDesc)
self.input_var_names = kwargs.pop('input_var_names', None)
# inputs & outputs
self.input_var_names = kwargs.pop('input_var_names', None)
input_mapping = kwargs.pop('input_data_mapping', None)
if input_mapping:
raw_vars = self.model.get_input_vars_desc()
@@ -55,32 +57,19 @@ Use \'input_var_names=[{}]\' instead'.format(', '.join(self.input_var_names)))
self.input_var_names = [k.name for k in raw_vars]
self.output_var_names = kwargs.pop('output_var_names')
assert len(self.input_var_names), self.input_var_names
for v in self.input_var_names: assert_type(v, six.string_types)
assert len(self.output_var_names), self.output_var_names
self.return_input = kwargs.pop('return_input', False)
assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))
def get_predict_func(config):
"""
Produce a simple predictor function run inside a new session.
Produce an offline predictor which runs inside a new session.
:param config: a `PredictConfig` instance.
:returns: A prediction function that takes a list of input values, and return
:returns: A callable predictor that takes a list of input values, and returns
a list of output values defined in ``config.output_var_names``.
"""
# build graph
input_vars = config.model.get_input_vars()
config.model._build_graph(input_vars, False)
input_vars = get_vars_by_names(config.input_var_names)
output_vars = get_vars_by_names(config.output_var_names)
sess = tf.Session(config=config.session_config)
config.session_init.init(sess)
return OfflinePredictor(config)
def run_input(dp):
assert len(input_vars) == len(dp), "{} != {}".format(len(input_vars), len(dp))
feed = dict(zip(input_vars, dp))
return sess.run(output_vars, feed_dict=feed)
# XXX hack. so the caller can get access to the session.
run_input.session = sess
return run_input
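After this change get_predict_func is only a thin compatibility wrapper; a usage sketch (config built as in the OfflinePredictor example above, `img` is again a placeholder):
predict_func = get_predict_func(config)   # now simply returns an OfflinePredictor
outputs = predict_func([img])             # values of config.output_var_names, in order
Callers that relied on the old `run_input.session` hack can use the `session` property of the returned predictor instead.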