Commit a55d81ca authored by Yuxin Wu

use sess.make_callable for predictors

parent f363d2e8
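For context, a minimal sketch of the pattern this commit switches the predictors to (the tiny `x`/`y` graph below is made up for illustration, not part of the commit): `tf.Session.make_callable` (TF >= 1.2) binds `fetches` and `feed_list` once, so repeated predictions avoid rebuilding a `feed_dict` and re-validating the fetch list on every `sess.run` call.

import numpy as np
import tensorflow as tf

# Hypothetical one-op graph, only for illustration.
x = tf.placeholder(tf.float32, [None, 3], name='x')
y = tf.identity(x * 2.0, name='y')

with tf.Session() as sess:
    # Old path: build a feed_dict and call sess.run for every prediction.
    out_old = sess.run(y, feed_dict={x: np.ones((1, 3))})

    # New path: bind fetches/feed_list once, then call with positional args.
    run_y = sess.make_callable(fetches=y, feed_list=[x])
    out_new = run_y(np.ones((1, 3)))

    assert np.allclose(out_old, out_new)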
@@ -12,7 +12,6 @@ import multiprocessing
 from tensorpack import *
 from tensorpack.utils import logger
-from tensorpack.utils.gpu import get_nr_gpu
 from tensorpack.utils.viz import *
 from tensorpack.utils.argtools import shape2d, shape4d
 from tensorpack.dataflow import dataset
...
@@ -4,9 +4,8 @@
 import os
 import argparse
-from tensorpack import *
-from tensorpack.utils.gpu import get_nr_gpu
 import tensorflow as tf
+from tensorpack import *

 """
 This is a boiler-plate template.
...
@@ -7,10 +7,11 @@ from abc import abstractmethod, ABCMeta
 import tensorflow as tf
 import six

-from ..tfutils.common import get_tensors_by_names
+from ..tfutils.common import get_tensors_by_names, get_tf_version_number
 from ..tfutils.tower import TowerContext
 from ..input_source import PlaceholderInput
 from ..utils.develop import log_deprecated
+from ..utils.argtools import log_once

 __all__ = ['PredictorBase', 'AsyncPredictorBase',
            'OnlinePredictor', 'OfflinePredictor',
@@ -106,15 +107,40 @@ class OnlinePredictor(PredictorBase):
         self.input_tensors = input_tensors
         self.output_tensors = output_tensors
         self.sess = sess
+        self._use_callable = get_tf_version_number() >= 1.2
+        if self._use_callable:
+            if sess is not None:
+                self._callable = sess.make_callable(
+                    fetches=output_tensors,
+                    feed_list=input_tensors)
+            else:
+                self._callable = None
+        else:
+            log_once(
+                "TF>=1.2 is recommended for better performance of predictor!", 'warn')
+
+    def _do_call_old(self, dp):
+        feed = dict(zip(self.input_tensors, dp))
+        output = self.sess.run(self.output_tensors, feed_dict=feed)
+        return output
+
+    def _do_call_new(self, dp):
+        if self._callable is None:
+            self._callable = self.sess.make_callable(
+                fetches=self.output_tensors,
+                feed_list=self.input_tensors)
+        return self._callable(*dp)

     def _do_call(self, dp):
         assert len(dp) == len(self.input_tensors), \
             "{} != {}".format(len(dp), len(self.input_tensors))
-        feed = dict(zip(self.input_tensors, dp))
         if self.sess is None:
             self.sess = tf.get_default_session()
-        output = self.sess.run(self.output_tensors, feed_dict=feed)
-        return output
+
+        if self._use_callable:
+            return self._do_call_new(dp)
+        else:
+            return self._do_call_old(dp)


 class OfflinePredictor(OnlinePredictor):
...
@@ -278,7 +278,13 @@ class TowerTensorHandle(object):
         """
         return self._output

-    # def make_callable(self, input_names, output_names):
+    # should move to somewhere else.
+    # def get_predictor(self, input_names, output_names):
+    #     """
+    #     Get a predictor with tensors inside this tower.
+    #     """
     #     input_tensors = self.get_tensors(input_names)
     #     output_tensors = self.get_tensors(output_names)
-    #     pass
+    #     # TODO sort out the import order
+    #     from ..predict.base import OnlinePredictor  # noqa
+    #     return OnlinePredictor(input_tensors, output_tensors)