Commit a55d81ca authored by Yuxin Wu's avatar Yuxin Wu

use sess.make_callable for predictors

parent f363d2e8
......@@ -12,7 +12,6 @@ import multiprocessing
from tensorpack import *
from tensorpack.utils import logger
from tensorpack.utils.gpu import get_nr_gpu
from tensorpack.utils.viz import *
from tensorpack.utils.argtools import shape2d, shape4d
from tensorpack.dataflow import dataset
......
......@@ -4,9 +4,8 @@
import os
import argparse
from tensorpack import *
from tensorpack.utils.gpu import get_nr_gpu
import tensorflow as tf
from tensorpack import *
"""
This is a boiler-plate template.
......
......@@ -7,10 +7,11 @@ from abc import abstractmethod, ABCMeta
import tensorflow as tf
import six
from ..tfutils.common import get_tensors_by_names
from ..tfutils.common import get_tensors_by_names, get_tf_version_number
from ..tfutils.tower import TowerContext
from ..input_source import PlaceholderInput
from ..utils.develop import log_deprecated
from ..utils.argtools import log_once
__all__ = ['PredictorBase', 'AsyncPredictorBase',
'OnlinePredictor', 'OfflinePredictor',
......@@ -106,15 +107,40 @@ class OnlinePredictor(PredictorBase):
self.input_tensors = input_tensors
self.output_tensors = output_tensors
self.sess = sess
self._use_callable = get_tf_version_number() >= 1.2
if self._use_callable:
if sess is not None:
self._callable = sess.make_callable(
fetches=output_tensors,
feed_list=input_tensors)
else:
log_once(
"TF>=1.2 is recommended for better performance of predictor!", 'warn')
self._callable = None
def _do_call_old(self, dp):
feed = dict(zip(self.input_tensors, dp))
output = self.sess.run(self.output_tensors, feed_dict=feed)
return output
def _do_call_new(self, dp):
if self._callable is None:
self._callable = self.sess.make_callable(
fetches=self.output_tensors,
feed_list=self.input_tensors)
return self._callable(*dp)
def _do_call(self, dp):
assert len(dp) == len(self.input_tensors), \
"{} != {}".format(len(dp), len(self.input_tensors))
feed = dict(zip(self.input_tensors, dp))
if self.sess is None:
self.sess = tf.get_default_session()
output = self.sess.run(self.output_tensors, feed_dict=feed)
return output
if self._use_callable:
return self._do_call_new(dp)
else:
return self._do_call_old(dp)
class OfflinePredictor(OnlinePredictor):
......
......@@ -278,7 +278,13 @@ class TowerTensorHandle(object):
"""
return self._output
# def make_callable(self, input_names, output_names):
# should move to somewhere else.
# def get_predictor(self, input_names, output_names):
# """
# Get a predictor with tensors inside this tower.
# """
# input_tensors = self.get_tensors(input_names)
# output_tensors = self.get_tensors(output_names)
# pass
# # TODO sort out the import order
# from ..predict.base import OnlinePredictor # noqa
# return OnlinePredictor(input_tensors, output_tensors)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment