Commit d979ab78 authored by Yuxin Wu's avatar Yuxin Wu

load vgg

parent 3e512ea6
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: load_vgg16.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf
import numpy as np
import os
import argparse
import cPickle as pkl
from tensorpack.train import TrainConfig, start_train
from tensorpack.predict import PredictConfig, get_predict_func
from tensorpack.models import *
from tensorpack.utils import *
from tensorpack.tfutils import *
from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import *
from tensorpack.callbacks import *
from tensorpack.dataflow import *
class Model(ModelDesc):
def _get_input_vars(self):
return [InputVar(tf.float32, (None, 224, 224, 3), 'input'),
InputVar(tf.int32, (None,), 'label') ]
def _get_cost(self, inputs, is_training):
is_training = bool(is_training)
keep_prob = tf.constant(0.5 if is_training else 1.0)
image, label = inputs
# 224
l = Conv2D('conv1_1', image, out_channel=64, kernel_shape=3)
l = Conv2D('conv1_2', l, out_channel=64, kernel_shape=3)
l = MaxPooling('pool1', l, 2, stride=2, padding='VALID')
# 112
l = Conv2D('conv2_1', l, out_channel=128, kernel_shape=3)
l = Conv2D('conv2_2', l, out_channel=128, kernel_shape=3)
l = MaxPooling('pool2', l, 2, stride=2, padding='VALID')
# 56
l = Conv2D('conv3_1', l, out_channel=256, kernel_shape=3)
l = Conv2D('conv3_2', l, out_channel=256, kernel_shape=3)
l = Conv2D('conv3_3', l, out_channel=256, kernel_shape=3)
l = MaxPooling('pool3', l, 2, stride=2, padding='VALID')
# 28
l = Conv2D('conv4_1', l, out_channel=512, kernel_shape=3)
l = Conv2D('conv4_2', l, out_channel=512, kernel_shape=3)
l = Conv2D('conv4_3', l, out_channel=512, kernel_shape=3)
l = MaxPooling('pool4', l, 2, stride=2, padding='VALID')
# 14
l = Conv2D('conv5_1', l, out_channel=512, kernel_shape=3)
l = Conv2D('conv5_2', l, out_channel=512, kernel_shape=3)
l = Conv2D('conv5_3', l, out_channel=512, kernel_shape=3)
l = MaxPooling('pool5', l, 2, stride=2, padding='VALID')
# 7
l = FullyConnected('fc6', l, 4096)
l = tf.nn.dropout(l, keep_prob)
l = FullyConnected('fc7', l, 4096)
l = tf.nn.dropout(l, keep_prob)
logits = FullyConnected('fc8', l, out_dim=1000, summary_activation=False, nl=tf.identity)
prob = tf.nn.softmax(logits, name='output')
y = one_hot(label, 1000)
cost = tf.nn.softmax_cross_entropy_with_logits(logits, y)
cost = tf.reduce_mean(cost, name='cross_entropy_loss')
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost)
# compute the number of failed samples, for ValidationError to use at test time
wrong = tf.not_equal(
tf.cast(tf.argmax(prob, 1), tf.int32), label)
wrong = tf.cast(wrong, tf.float32)
nr_wrong = tf.reduce_sum(wrong, name='wrong')
# monitor training error
tf.add_to_collection(
MOVING_SUMMARY_VARS_KEY, tf.reduce_mean(wrong, name='train_error'))
# weight decay on all W of fc layers
wd_cost = tf.mul(1e-4,
regularize_cost('fc.*/W', tf.nn.l2_loss),
name='regularize_loss')
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, wd_cost)
return tf.add_n([wd_cost, cost], name='cost')
def run_test(path, input):
param_dict = np.load(path).item()
pred_config = PredictConfig(
model=Model(),
input_data_mapping=[0],
session_init=ParamRestore(param_dict),
output_var_names=['output:0'] # output:0 is the probability distribution
)
predict_func = get_predict_func(pred_config)
import cv2
im = cv2.imread(input)
assert im is not None
im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
im = cv2.resize(im, (224, 224))
im = np.reshape(im, (1, 224, 224, 3)).astype('float32')
outputs = predict_func([im])[0]
prob = outputs[0]
print prob.shape
ret = prob.argsort()[-10:][::-1]
print ret
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', default='0',
help='comma separated list of GPU(s) to use.') # nargs='*' in multi mode
parser.add_argument('--load', required=True,
help='.npy model file generated by tensorpack.utils.loadcaffe')
parser.add_argument('--input', help='an input image', required=True)
args = parser.parse_args()
if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
run_test(args.load, args.input)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment