Commit 5013ea51 authored by Yuxin Wu's avatar Yuxin Wu

create_predict_config for ImageNetModels

parent decf8310
......@@ -345,6 +345,7 @@ class ImageNetModel(ModelDesc):
image = tf.transpose(image, [0, 3, 1, 2])
logits = self.get_logits(image)
tf.nn.softmax(logits, name='prob')
loss = ImageNetModel.compute_loss_and_error(
logits, label, label_smoothing=self.label_smoothing)
......@@ -423,6 +424,20 @@ class ImageNetModel(ModelDesc):
return loss
def create_predict_config(self, session_init):
"""
Returns:
a :class:`PredictConfig` to be used for inference.
The predictor will take inputs and return probabilities.
Examples:
pred = OfflinePredictor(model.create_predict_config(get_model_loader(args.load)))
prob = pred(NCHW_image)[0] # Nx1000 probabilities
"""
return PredictConfig(model=self, input_names=['input'], output_names=['prob'], session_init=session_init)
if __name__ == '__main__':
import argparse
from tensorpack.dataflow import TestDataSpeed
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment