Commit bd0ca738 authored by Yuxin Wu's avatar Yuxin Wu

a better load_vgg

parent 5476b488
...@@ -20,6 +20,12 @@ from tensorpack.callbacks import * ...@@ -20,6 +20,12 @@ from tensorpack.callbacks import *
from tensorpack.dataflow import * from tensorpack.dataflow import *
from tensorpack.dataflow.dataset import ILSVRCMeta from tensorpack.dataflow.dataset import ILSVRCMeta
"""
Usage:
python2 -m tensorpack.utils.loadcaffe PATH/TO/models/VGG/{VGG_ILSVRC_16_layers_deploy.prototxt,VGG_ILSVRC_16_layers.caffemodel} vgg16.npy
./load_vgg16.py --load vgg16.npy --input cat.png
"""
class Model(ModelDesc): class Model(ModelDesc):
def _get_input_vars(self): def _get_input_vars(self):
return [InputVar(tf.float32, (None, 224, 224, 3), 'input'), return [InputVar(tf.float32, (None, 224, 224, 3), 'input'),
...@@ -31,32 +37,33 @@ class Model(ModelDesc): ...@@ -31,32 +37,33 @@ class Model(ModelDesc):
image, label = inputs image, label = inputs
with argscope(Conv2D, kernel_shape=3):
# 224 # 224
l = Conv2D('conv1_1', image, out_channel=64, kernel_shape=3) l = Conv2D('conv1_1', image, 64)
l = Conv2D('conv1_2', l, out_channel=64, kernel_shape=3) l = Conv2D('conv1_2', l, 64)
l = MaxPooling('pool1', l, 2, stride=2, padding='VALID') l = MaxPooling('pool1', l, 2, stride=2, padding='VALID')
# 112 # 112
l = Conv2D('conv2_1', l, out_channel=128, kernel_shape=3) l = Conv2D('conv2_1', l, 128)
l = Conv2D('conv2_2', l, out_channel=128, kernel_shape=3) l = Conv2D('conv2_2', l, 128)
l = MaxPooling('pool2', l, 2, stride=2, padding='VALID') l = MaxPooling('pool2', l, 2, stride=2, padding='VALID')
# 56 # 56
l = Conv2D('conv3_1', l, out_channel=256, kernel_shape=3) l = Conv2D('conv3_1', l, 256)
l = Conv2D('conv3_2', l, out_channel=256, kernel_shape=3) l = Conv2D('conv3_2', l, 256)
l = Conv2D('conv3_3', l, out_channel=256, kernel_shape=3) l = Conv2D('conv3_3', l, 256)
l = MaxPooling('pool3', l, 2, stride=2, padding='VALID') l = MaxPooling('pool3', l, 2, stride=2, padding='VALID')
# 28 # 28
l = Conv2D('conv4_1', l, out_channel=512, kernel_shape=3) l = Conv2D('conv4_1', l, 512)
l = Conv2D('conv4_2', l, out_channel=512, kernel_shape=3) l = Conv2D('conv4_2', l, 512)
l = Conv2D('conv4_3', l, out_channel=512, kernel_shape=3) l = Conv2D('conv4_3', l, 512)
l = MaxPooling('pool4', l, 2, stride=2, padding='VALID') l = MaxPooling('pool4', l, 2, stride=2, padding='VALID')
# 14 # 14
l = Conv2D('conv5_1', l, out_channel=512, kernel_shape=3) l = Conv2D('conv5_1', l, 512)
l = Conv2D('conv5_2', l, out_channel=512, kernel_shape=3) l = Conv2D('conv5_2', l, 512)
l = Conv2D('conv5_3', l, out_channel=512, kernel_shape=3) l = Conv2D('conv5_3', l, 512)
l = MaxPooling('pool5', l, 2, stride=2, padding='VALID') l = MaxPooling('pool5', l, 2, stride=2, padding='VALID')
# 7 # 7
...@@ -67,26 +74,9 @@ class Model(ModelDesc): ...@@ -67,26 +74,9 @@ class Model(ModelDesc):
logits = FullyConnected('fc8', l, out_dim=1000, nl=tf.identity) logits = FullyConnected('fc8', l, out_dim=1000, nl=tf.identity)
prob = tf.nn.softmax(logits, name='output') prob = tf.nn.softmax(logits, name='output')
y = one_hot(label, 1000) cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, label)
cost = tf.nn.softmax_cross_entropy_with_logits(logits, y) cost = tf.reduce_mean(cost, name='cost')
cost = tf.reduce_mean(cost, name='cross_entropy_loss') return cost
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost)
# compute the number of failed samples, for ValidationError to use at test time
wrong = tf.not_equal(
tf.cast(tf.argmax(prob, 1), tf.int32), label)
wrong = tf.cast(wrong, tf.float32)
nr_wrong = tf.reduce_sum(wrong, name='wrong')
# monitor training error
tf.add_to_collection(
MOVING_SUMMARY_VARS_KEY, tf.reduce_mean(wrong, name='train_error'))
# weight decay on all W of fc layers
wd_cost = tf.mul(1e-4,
regularize_cost('fc.*/W', tf.nn.l2_loss),
name='regularize_loss')
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, wd_cost)
return tf.add_n([wd_cost, cost], name='cost')
def run_test(path, input): def run_test(path, input):
param_dict = np.load(path).item() param_dict = np.load(path).item()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment