Commit 86ec2d15 authored by Yuxin Wu's avatar Yuxin Wu

offline predictor

parent 4af48399
...@@ -39,8 +39,7 @@ class ModelDesc(object): ...@@ -39,8 +39,7 @@ class ModelDesc(object):
def reuse_input_vars(self): def reuse_input_vars(self):
""" Find and return already-defined input_vars in default graph""" """ Find and return already-defined input_vars in default graph"""
input_var_names = [k.name for k in self._get_input_vars()] input_var_names = [k.name for k in self._get_input_vars()]
g = tf.get_default_graph() return get_vars_by_names(input_var_names)
return [g.get_tensor_by_name(name + ":0") for name in input_var_names]
def get_input_vars_desc(self): def get_input_vars_desc(self):
""" return a list of `InputVar` instance""" """ return a list of `InputVar` instance"""
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: base.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from abc import abstractmethod, ABCMeta, abstractproperty
import tensorflow as tf
from ..tfutils import get_vars_by_names
class PredictorBase(object):
__metaclass__ = ABCMeta
@abstractproperty
def session(self):
""" return the session the predictor is running on"""
pass
def __call__(self, dp):
assert len(dp) == len(self.input_var_names), \
"{} != {}".format(len(dp), len(self.input_var_names))
output = self._do_call(dp)
if self.return_input:
return (dp, output)
else:
return output
@abstractmethod
def _do_call(self, dp):
"""
:param dp: input datapoint. must have the same length as input_var_names
:return: output as defined by the config
"""
pass
class OfflinePredictor(PredictorBase):
""" Build a predictor from a given config, in an independent graph"""
def __init__(self, config):
self.graph = tf.Graph()
with self.graph.as_default():
input_vars = config.model.get_input_vars()
config.model._build_graph(input_vars, False)
self.input_var_names = config.input_var_names
self.output_var_names = config.output_var_names
self.return_input = config.return_input
self.input_vars = get_vars_by_names(self.input_var_names)
self.output_vars = get_vars_by_names(self.output_var_names)
sess = tf.Session(config=config.session_config)
config.session_init.init(sess)
self._session = sess
@property
def session(self):
return self._session
def _do_call(self, dp):
feed = dict(zip(self.input_vars, dp))
output = self.session.run(self.output_vars, feed_dict=feed)
return output
...@@ -4,11 +4,13 @@ ...@@ -4,11 +4,13 @@
import tensorflow as tf import tensorflow as tf
from collections import namedtuple from collections import namedtuple
import six
from six.moves import zip from six.moves import zip
from tensorpack.models import ModelDesc from tensorpack.models import ModelDesc
from ..utils import logger from ..utils import logger
from ..tfutils import * from ..tfutils import *
from .base import OfflinePredictor
import multiprocessing import multiprocessing
...@@ -29,20 +31,20 @@ class PredictConfig(object): ...@@ -29,20 +31,20 @@ class PredictConfig(object):
:param output_var_names: a list of names of the output tensors to predict, the :param output_var_names: a list of names of the output tensors to predict, the
variables can be any computable tensor in the graph. variables can be any computable tensor in the graph.
Predict specific output might not require all input variables. Predict specific output might not require all input variables.
:param return_input: whether to produce (input, output) pair or just output. default to False. :param return_input: whether to return (input, output) pair or just output. default to False.
It's only effective for `DatasetPredictorBase`.
""" """
def assert_type(v, tp): def assert_type(v, tp):
assert isinstance(v, tp), v.__class__ assert isinstance(v, tp), v.__class__
# XXX does it work? start with minimal memory, but allow growth. # XXX does it work? start with minimal memory, but allow growth.
# allow_growth doesn't seem to work very well in TF. # allow_growth doesn't seem to work very well in TF.
self.session_config = kwargs.pop('session_config', get_default_sess_config(0.3)) self.session_config = kwargs.pop('session_config', get_default_sess_config(0.4))
self.session_init = kwargs.pop('session_init', JustCurrentSession()) self.session_init = kwargs.pop('session_init', JustCurrentSession())
assert_type(self.session_init, SessionInit) assert_type(self.session_init, SessionInit)
self.model = kwargs.pop('model') self.model = kwargs.pop('model')
assert_type(self.model, ModelDesc) assert_type(self.model, ModelDesc)
self.input_var_names = kwargs.pop('input_var_names', None)
# inputs & outputs
self.input_var_names = kwargs.pop('input_var_names', None)
input_mapping = kwargs.pop('input_data_mapping', None) input_mapping = kwargs.pop('input_data_mapping', None)
if input_mapping: if input_mapping:
raw_vars = self.model.get_input_vars_desc() raw_vars = self.model.get_input_vars_desc()
...@@ -55,32 +57,19 @@ Use \'input_var_names=[{}]\' instead'.format(', '.join(self.input_var_names))) ...@@ -55,32 +57,19 @@ Use \'input_var_names=[{}]\' instead'.format(', '.join(self.input_var_names)))
self.input_var_names = [k.name for k in raw_vars] self.input_var_names = [k.name for k in raw_vars]
self.output_var_names = kwargs.pop('output_var_names') self.output_var_names = kwargs.pop('output_var_names')
assert len(self.input_var_names), self.input_var_names assert len(self.input_var_names), self.input_var_names
for v in self.input_var_names: assert_type(v, six.string_types)
assert len(self.output_var_names), self.output_var_names assert len(self.output_var_names), self.output_var_names
self.return_input = kwargs.pop('return_input', False) self.return_input = kwargs.pop('return_input', False)
assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys())) assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))
def get_predict_func(config): def get_predict_func(config):
""" """
Produce a simple predictor function run inside a new session. Produce a offline predictor run inside a new session.
:param config: a `PredictConfig` instance. :param config: a `PredictConfig` instance.
:returns: A prediction function that takes a list of input values, and return :returns: A callable predictor that takes a list of input values, and return
a list of output values defined in ``config.output_var_names``. a list of output values defined in ``config.output_var_names``.
""" """
# build graph return OfflinePredictor(config)
input_vars = config.model.get_input_vars()
config.model._build_graph(input_vars, False)
input_vars = get_vars_by_names(config.input_var_names)
output_vars = get_vars_by_names(config.output_var_names)
sess = tf.Session(config=config.session_config)
config.session_init.init(sess)
def run_input(dp):
assert len(input_vars) == len(dp), "{} != {}".format(len(input_vars), len(dp))
feed = dict(zip(input_vars, dp))
return sess.run(output_vars, feed_dict=feed)
# XXX hack. so the caller can get access to the session.
run_input.session = sess
return run_input
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment