Commit 4c2bcb94 authored by Yuxin Wu's avatar Yuxin Wu

clean-up legacy placeholder methods in ModelDesc

parent f603636c
@@ -10,6 +10,7 @@ import imp
 from tensorpack import TowerContext, logger
 from tensorpack.tfutils import sessinit, varmanip
+from tensorpack.graph_builder.input_source import PlaceholderInput

 parser = argparse.ArgumentParser()
 parser.add_argument('--config', help='config file')
@@ -26,7 +27,9 @@ with tf.Graph().as_default() as G:
     MODEL = imp.load_source('config_script', args.config).Model
     M = MODEL()
     with TowerContext('', is_training=False):
-        M.build_graph(M.get_reused_placehdrs())
+        input = PlaceholderInput()
+        input.setup(M.get_inputs_desc())
+        M.build_graph(input)
 else:
     tf.train.import_meta_graph(args.meta)
...
@@ -93,32 +93,6 @@ class InputDesc(
 class ModelDescBase(object):
     """ Base class for a model description.
     """
-    # TODO remove this method? Now mainly used in predict/
-    @memoized
-    def get_reused_placehdrs(self):
-        """
-        Create or return (if already created) raw input TF placeholders in the graph.
-
-        Returns:
-            list[tf.Tensor]: the list of input placeholders in the graph.
-        """
-        return [v.build_placeholder_reuse() for v in self.get_inputs_desc()]
-
-    def build_placeholders(self, prefix=''):
-        """
-        For each InputDesc, create new placeholders with optional prefix and
-        return them. Useful when building new towers.
-
-        Returns:
-            list[tf.Tensor]: the list of built placeholders.
-        """
-        inputs = self.get_inputs_desc()
-        ret = []
-        for v in inputs:
-            ret.append(v.build_placeholder(prefix))
-        return ret
-
     @memoized
     def get_inputs_desc(self):
         """
...
@@ -8,6 +8,7 @@ from ..tfutils.common import get_op_tensor_name, get_tensors_by_names
 from ..tfutils.tower import TowerContext
 from ..tfutils.collection import freeze_collection
 from ..utils.naming import TOWER_FREEZE_KEYS
+from .input_source import PlaceholderInput

 __all__ = ['PredictorFactory']
@@ -62,9 +63,9 @@ class PredictorFactory(object):
                 TowerContext(tower_name, is_training=False), \
                 freeze_collection(TOWER_FREEZE_KEYS):
             if input is None:
-                input = self._model.get_reused_placehdrs()
-            else:
-                input = input.get_input_tensors()
+                input = PlaceholderInput()
+                input.setup(self._model.get_inputs_desc())
+            input = input.get_input_tensors()
             assert isinstance(input, (list, tuple)), input
             self._model.build_graph(input)
...
@@ -9,6 +9,7 @@ import six
 from ..tfutils.common import get_tensors_by_names
 from ..tfutils.tower import TowerContext
+from ..graph_builder.input_source import PlaceholderInput

 __all__ = ['PredictorBase', 'AsyncPredictorBase',
            'OnlinePredictor', 'OfflinePredictor',
@@ -127,9 +128,10 @@ class OfflinePredictor(OnlinePredictor):
         """
         self.graph = config._maybe_create_graph()
         with self.graph.as_default():
-            input_placehdrs = config.model.get_reused_placehdrs()
+            input = PlaceholderInput()
+            input.setup(config.model.get_inputs_desc())
             with TowerContext('', is_training=False):
-                config.model.build_graph(input_placehdrs)
+                config.model.build_graph(input)
             input_tensors = get_tensors_by_names(config.input_names)
             output_tensors = get_tensors_by_names(config.output_names)
...
@@ -10,6 +10,7 @@ This simplifies the process of exporting a model for TensorFlow serving.
 import tensorflow as tf
 from ..utils import logger
 from ..graph_builder.model_desc import ModelDesc
+from ..graph_builder.input_source import PlaceholderInput
 from ..tfutils import TowerContext, sessinit
@@ -61,7 +62,8 @@ class ModelExport(object):
         logger.info('[export] prepare new model export')
         super(ModelExport, self).__init__()
         self.model = model
-        self.placehdrs = self.model.get_reused_placehdrs()
+        self.input = PlaceholderInput()
+        self.input.setup(self.model.get_inputs_desc())
         self.output_names = output_names
         self.input_names = input_names
@@ -87,7 +89,7 @@ class ModelExport(object):
         """
         logger.info('[export] build model for %s' % checkpoint)
         with TowerContext('', is_training=False):
-            self.model._build_graph(self.placehdrs)
+            self.model._build_graph(self.input)
         self.sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
         # load values from latest checkpoint
...
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment