Commit a988fc18 authored by Yuxin Wu

Update the use of exportmodel

parent 2e2bbcac
@@ -23,8 +23,15 @@ It saves models to standard checkpoint format, plus a metagraph protobuf file.
They are sufficient to use with whatever deployment methods TensorFlow supports.
But you'll need to read TF docs and do it on your own.
Please note that the metagraph saved during training is the training graph,
but you may need a different one for inference.
For example, you may need a different data layout for CPU inference,
or you may need placeholders in the inference graph, or the training graph may contain multi-GPU replication
which you want to remove.
In such cases, you can always create a new graph yourself with TF symbolic functions.
The only thing tensorpack has is `OfflinePredictor`,
a simple function to build the graph and return a callable for you.
It is mainly for quick demo purposes.
It only runs inference on Python data, so it may not be the most efficient way to run inference.
Check out some examples for the usage.
Check out some examples for its usage.
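For reference, here is a minimal sketch of how `OfflinePredictor` is typically used; the model class, checkpoint path and tensor names below are placeholders, not taken from this commit:

```python
from tensorpack.predict import PredictConfig, OfflinePredictor
from tensorpack.tfutils.sessinit import SaverRestore

# MyModel, the checkpoint path and the tensor names are illustrative only.
pred = OfflinePredictor(PredictConfig(
    model=MyModel(),
    session_init=SaverRestore('/path/to/checkpoint'),
    input_names=['input'],          # names of input tensors to feed
    output_names=['prediction']))   # names of output tensors to fetch

outputs = pred(input_array)         # numpy arrays in, list of numpy arrays out
```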
@@ -66,7 +66,8 @@ with TowerContext('', is_training=True):
Some layers (in particular ``BatchNorm``) have different train/test-time behavior, which is controlled
by ``TowerContext``. If you need to use the tensorpack version of them at test time, you'll need to create the ops for them under another context.
```python
with tf.variable_scope(tf.get_variable_scope(), reuse=True), TowerContext('predict', is_training=False):
# Open a `reuse=True` variable scope here if you're sharing variables, then:
with TowerContext('some_name_or_empty_string', is_training=False):
    # build the graph again
```
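Putting this together, a minimal sketch (assuming the training tower has already been built, and `image` / `build_my_graph` stand in for your own placeholder and symbolic function):

```python
image = tf.placeholder(tf.float32, [None, 224, 224, 3], name='image')
with tf.variable_scope(tf.get_variable_scope(), reuse=True), \
        TowerContext('', is_training=False):
    build_my_graph(image)   # the same symbolic function used for the training graph
```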
@@ -58,25 +58,22 @@ class ModelExport(object):
assert isinstance(input_names, list)
assert isinstance(output_names, list)
assert isinstance(model, ModelDescBase)
logger.info('[export] prepare new model export')
super(ModelExport, self).__init__()
self.model = model
self.input = PlaceholderInput()
self.input.setup(self.model.get_inputs_desc())
self.output_names = output_names
self.input_names = input_names
def export(self, checkpoint, export_path, version=1, tags=[tf.saved_model.tag_constants.SERVING],
def export(self, checkpoint, export_path,
tags=[tf.saved_model.tag_constants.SERVING],
signature_name='prediction_pipeline'):
"""Use SavedModelBuilder to export a trained model without TensorPack depency.
"""
Use SavedModelBuilder to export a trained model without tensorpack dependency.
Remarks:
This produces
variables/ # output from the vanilla Saver
variables.data-?????-of-?????
variables.index
saved_model.pb # saved model in protocol buffer format
saved_model.pb # a `SavedModel` protobuf
Currently, we only support a single signature, which is the general PredictSignatureDef:
https://github.com/tensorflow/serving/blob/master/tensorflow_serving/g3doc/signature_defs.md
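(As a hypothetical usage sketch, not part of this diff: the model class, paths and tensor names below are placeholders.)

```python
exporter = ModelExport(model=MyModel(),
                       input_names=['input'],
                       output_names=['prediction'])
exporter.export('/path/to/checkpoint', '/path/to/export_dir')
```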
@@ -89,47 +86,49 @@ class ModelExport(object):
"""
logger.info('[export] build model for %s' % checkpoint)
with TowerContext('', is_training=False):
self.model.build_graph(*self.input.get_input_tensors())
self.sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
# load values from latest checkpoint
init = sessinit.SaverRestore(checkpoint)
self.sess.run(tf.global_variables_initializer())
init.init(self.sess)
self.inputs = []
for n in self.input_names:
tensor = tf.get_default_graph().get_tensor_by_name('%s:0' % n)
logger.info('[export] add input-tensor "%s"' % tensor.name)
self.inputs.append(tensor)
self.outputs = []
for n in self.output_names:
tensor = tf.get_default_graph().get_tensor_by_name('%s:0' % n)
logger.info('[export] add output-tensor "%s"' % tensor.name)
self.outputs.append(tensor)
logger.info('[export] exporting trained model to %s' % export_path)
builder = tf.saved_model.builder.SavedModelBuilder(export_path)
logger.info('[export] build signatures')
# build inputs
inputs_signature = dict()
for n, v in zip(self.input_names, self.inputs):
logger.info('[export] add input signature: %s' % v)
inputs_signature[n] = tf.saved_model.utils.build_tensor_info(v)
outputs_signature = dict()
for n, v in zip(self.output_names, self.outputs):
logger.info('[export] add output signature: %s' % v)
outputs_signature[n] = tf.saved_model.utils.build_tensor_info(v)
prediction_signature = tf.saved_model.signature_def_utils.build_signature_def(
inputs=inputs_signature,
outputs=outputs_signature,
method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)
builder.add_meta_graph_and_variables(
self.sess, tags,
signature_def_map={signature_name: prediction_signature})
builder.save()
input = PlaceholderInput()
input.setup(self.model.get_inputs_desc())
self.model.build_graph(*input.get_input_tensors())
self.sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
# load values from latest checkpoint
init = sessinit.SaverRestore(checkpoint)
self.sess.run(tf.global_variables_initializer())
init.init(self.sess)
self.inputs = []
for n in self.input_names:
tensor = tf.get_default_graph().get_tensor_by_name('%s:0' % n)
logger.info('[export] add input-tensor "%s"' % tensor.name)
self.inputs.append(tensor)
self.outputs = []
for n in self.output_names:
tensor = tf.get_default_graph().get_tensor_by_name('%s:0' % n)
logger.info('[export] add output-tensor "%s"' % tensor.name)
self.outputs.append(tensor)
logger.info('[export] exporting trained model to %s' % export_path)
builder = tf.saved_model.builder.SavedModelBuilder(export_path)
logger.info('[export] build signatures')
# build inputs
inputs_signature = dict()
for n, v in zip(self.input_names, self.inputs):
logger.info('[export] add input signature: %s' % v)
inputs_signature[n] = tf.saved_model.utils.build_tensor_info(v)
outputs_signature = dict()
for n, v in zip(self.output_names, self.outputs):
logger.info('[export] add output signature: %s' % v)
outputs_signature[n] = tf.saved_model.utils.build_tensor_info(v)
prediction_signature = tf.saved_model.signature_def_utils.build_signature_def(
inputs=inputs_signature,
outputs=outputs_signature,
method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)
builder.add_meta_graph_and_variables(
self.sess, tags,
signature_def_map={signature_name: prediction_signature})
builder.save()
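Once exported, the directory can be served with TensorFlow Serving or loaded back with vanilla TensorFlow. A minimal sketch of the latter, with the export path and tensor names assumed:

```python
import tensorflow as tf

with tf.Session(graph=tf.Graph()) as sess:
    # Load the SavedModel produced by ModelExport.export() under the SERVING tag.
    tf.saved_model.loader.load(
        sess, [tf.saved_model.tag_constants.SERVING], '/path/to/export_dir')
    x = sess.graph.get_tensor_by_name('input:0')
    y = sess.graph.get_tensor_by_name('prediction:0')
    # result = sess.run(y, feed_dict={x: some_input_data})
```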