Commit 14a28c01 authored by Yuxin Wu's avatar Yuxin Wu

online predictor in trainer

parent 86ec2d15
...@@ -7,18 +7,18 @@ from abc import abstractmethod, ABCMeta, abstractproperty ...@@ -7,18 +7,18 @@ from abc import abstractmethod, ABCMeta, abstractproperty
import tensorflow as tf import tensorflow as tf
from ..tfutils import get_vars_by_names from ..tfutils import get_vars_by_names
__all__ = ['OnlinePredictor', 'OfflinePredictor']
class PredictorBase(object): class PredictorBase(object):
__metaclass__ = ABCMeta __metaclass__ = ABCMeta
"""
@abstractproperty Property:
def session(self): session
""" return the session the predictor is running on""" return_input
pass """
def __call__(self, dp): def __call__(self, dp):
assert len(dp) == len(self.input_var_names), \
"{} != {}".format(len(dp), len(self.input_var_names))
output = self._do_call(dp) output = self._do_call(dp)
if self.return_input: if self.return_input:
return (dp, output) return (dp, output)
...@@ -33,8 +33,23 @@ class PredictorBase(object): ...@@ -33,8 +33,23 @@ class PredictorBase(object):
""" """
pass pass
class OnlinePredictor(PredictorBase):
def __init__(self, sess, input_vars, output_vars, return_input=False):
self.session = sess
self.return_input = return_input
class OfflinePredictor(PredictorBase): self.input_vars = input_vars
self.output_vars = output_vars
def _do_call(self, dp):
assert len(dp) == len(self.input_vars), \
"{} != {}".format(len(dp), len(self.input_vars))
feed = dict(zip(self.input_vars, dp))
output = self.session.run(self.output_vars, feed_dict=feed)
return output
class OfflinePredictor(OnlinePredictor):
""" Build a predictor from a given config, in an independent graph""" """ Build a predictor from a given config, in an independent graph"""
def __init__(self, config): def __init__(self, config):
self.graph = tf.Graph() self.graph = tf.Graph()
...@@ -42,22 +57,10 @@ class OfflinePredictor(PredictorBase): ...@@ -42,22 +57,10 @@ class OfflinePredictor(PredictorBase):
input_vars = config.model.get_input_vars() input_vars = config.model.get_input_vars()
config.model._build_graph(input_vars, False) config.model._build_graph(input_vars, False)
self.input_var_names = config.input_var_names input_vars = get_vars_by_names(config.input_var_names)
self.output_var_names = config.output_var_names output_vars = get_vars_by_names(config.output_var_names)
self.return_input = config.return_input
self.input_vars = get_vars_by_names(self.input_var_names)
self.output_vars = get_vars_by_names(self.output_var_names)
sess = tf.Session(config=config.session_config) sess = tf.Session(config=config.session_config)
config.session_init.init(sess) config.session_init.init(sess)
self._session = sess super(OfflinePredictor, self).__init__(
sess, input_vars, output_vars, config.return_input)
@property
def session(self):
return self._session
def _do_call(self, dp):
feed = dict(zip(self.input_vars, dp))
output = self.session.run(self.output_vars, feed_dict=feed)
return output
...@@ -8,11 +8,14 @@ import time ...@@ -8,11 +8,14 @@ import time
from six.moves import zip from six.moves import zip
from .base import Trainer from .base import Trainer
from ..dataflow.common import RepeatedData from ..dataflow.common import RepeatedData
from ..utils import *
from ..tfutils.summary import summary_moving_average from ..tfutils.summary import summary_moving_average
from ..tfutils.modelutils import describe_model from ..tfutils.modelutils import describe_model
from ..utils import *
from ..tfutils import * from ..tfutils import *
from ..predict import OnlinePredictor
__all__ = ['SimpleTrainer', 'QueueInputTrainer'] __all__ = ['SimpleTrainer', 'QueueInputTrainer']
...@@ -56,11 +59,7 @@ class SimpleTrainer(Trainer): ...@@ -56,11 +59,7 @@ class SimpleTrainer(Trainer):
for v in input_vars: for v in input_vars:
assert v in self.input_vars assert v in self.input_vars
output_vars = get_vars_by_names(output_names) output_vars = get_vars_by_names(output_names)
def func(inputs): return OnlinePredictor(self.sess, input_vars, output_vars)
assert len(inputs) == len(input_vars)
feed = dict(zip(input_vars, inputs))
return self.sess.run(output_vars, feed_dict=feed)
return func
class EnqueueThread(threading.Thread): class EnqueueThread(threading.Thread):
def __init__(self, trainer): def __init__(self, trainer):
...@@ -218,11 +217,7 @@ class QueueInputTrainer(Trainer): ...@@ -218,11 +217,7 @@ class QueueInputTrainer(Trainer):
raw_input_vars = get_vars_by_names(input_names) raw_input_vars = get_vars_by_names(input_names)
output_names = ['towerp{}/'.format(tower) + n for n in output_names] output_names = ['towerp{}/'.format(tower) + n for n in output_names]
output_vars = get_vars_by_names(output_names) output_vars = get_vars_by_names(output_names)
def func(inputs): return OnlinePredictor(self.sess, raw_input_vars, output_vars)
assert len(inputs) == len(raw_input_vars)
feed = dict(zip(raw_input_vars, inputs))
return self.sess.run(output_vars, feed_dict=feed)
return func
def get_predict_funcs(self, input_names, output_names, n): def get_predict_funcs(self, input_names, output_names, n):
""" return n predicts functions evenly on each predict_tower""" """ return n predicts functions evenly on each predict_tower"""
......
...@@ -102,5 +102,5 @@ if __name__ == '__main__': ...@@ -102,5 +102,5 @@ if __name__ == '__main__':
x = Rect(2, 1, 3, 3, allow_neg=True) x = Rect(2, 1, 3, 3, allow_neg=True)
img = np.random.rand(3,3) img = np.random.rand(3,3)
print img print(img)
print x.roi_zeropad(img) print(x.roi_zeropad(img))
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment