Commit 9908baa2 authored by Yuxin Wu's avatar Yuxin Wu

new model

parent 13e2a3e3
......@@ -19,55 +19,70 @@ class Model(ModelDesc):
def _build_graph(self, input_vars, _):
x, label = input_vars
x = x / 255.0
x = x / 256.0
def tanh_round_bit(x, name=None):
x = tf.tanh(x) * 0.5
return (((x + 0.5) * 3.0 + 0.5) // 1) / 3.0 - 0.5
def quantize(x, name=None):
# quantize to 2 bit
return ((x * 3.0 + 0.5) // 1) / 3.0
x = Conv2D('conv1_1', x, 96, 12, nl=tanh_round_bit, stride=4, padding='VALID')
bn = lambda x, name: BatchNorm('bn', x, False, epsilon=1e-4)
bnc = lambda x, name: tf.clip_by_value(bn(x, None), 0.0, 1.0, name=name)
bnl = lambda x, name: BatchNorm('bn', x, False, epsilon=1e-4)
with argscope([Conv2D, FullyConnected], nl=bnl):
x = Conv2D('conv2_1', x, 256, 5, padding='SAME')
x = tf.pad(x, [[0,0], [1,1], [1,1], [0,0]], "SYMMETRIC")
x = MaxPooling('pool1', x, 3, stride=2, padding='VALID')
x = tanh_round_bit(x)
def conv_split(name, x, channel, shape):
inputs = tf.split(3, 2, x)
x0 = Conv2D(name + 'a', inputs[0], channel/2, shape)
x1 = Conv2D(name + 'b', inputs[1], channel/2, shape)
return tf.concat(3, [x0, x1])
with argscope([Conv2D, FullyConnected], nl=bnc):
x = Conv2D('conv1_1', x, 96, 12, stride=4, padding='VALID')
x = quantize(x)
x = conv_split('conv2_1', x, 256, 5)
x = tf.pad(x, [[0,0], [1,1], [1,1], [0,0]])
x = MaxPooling('pool1', x, 3, 2)
x = quantize(x)
x = Conv2D('conv3_1', x, 384, 3)
x = tf.pad(x, [[0,0], [1,1], [1,1], [0,0]], "SYMMETRIC")
x = MaxPooling('pool2', x, 3, stride=2, padding='VALID')
x = tanh_round_bit(x)
x = tf.pad(x, [[0,0], [1,1], [1,1], [0,0]])
x = MaxPooling('pool2', x, 3, 2)
x = quantize(x)
x = Conv2D('conv4_1', x, 384, 3)
x = tanh_round_bit(x)
x = conv_split('conv4_1', x, 384, 3)
x = quantize(x)
x = Conv2D('conv5_1', x, 256, 3)
x = MaxPooling('pool3', x, 3, stride=2, padding='VALID')
x = tanh_round_bit(x)
x = conv_split('conv5_1', x, 256, 3)
x = MaxPooling('pool3', x, 3, 2)
x = quantize(x)
x = tf.transpose(x, perm=[0,3,1,2])
x = tf.nn.dropout(x, keep_prob=1.)
x = FullyConnected('fc0', x, out_dim=4096)
x = tanh_round_bit(x)
x = quantize(x)
x = tf.nn.dropout(x, keep_prob=1.)
x = FullyConnected('fc1', x, out_dim=4096)
x = tf.tanh(x) * 0.5
logits = FullyConnected('fct', x, out_dim=1000)
logits = FullyConnected('fct', x, out_dim=1000, nl=bn)
prob = tf.nn.softmax(logits, name='prob')
nr_wrong = tf.reduce_sum(prediction_incorrect(logits, label), name='wrong-top1')
nr_wrong = tf.reduce_sum(prediction_incorrect(logits, label, 5), name='wrong-top5')
def eval_on_ILSVRC12(model, sess_init, data_dir):
ds = dataset.ILSVRC12(data_dir, 'val', shuffle=False)
def resize_func(im):
h, w = im.shape[:2]
scale = 256.0 / min(h, w)
desSize = map(int, (max(224, min(w, scale * w)),\
max(224, min(h, scale * h))))
im = cv2.resize(im, tuple(desSize), interpolation=cv2.INTER_CUBIC)
return im
transformers = [
imgaug.Resize((256, 256)),
imgaug.AugmentWithFunc(resize_func),
imgaug.CenterCrop((224, 224)),
]
ds = AugmentImageComponent(ds, transformers)
ds = BatchData(ds, 128, remainder=True)
ds = PrefetchData(ds, 10, nr_proc=1)
ds = PrefetchData(ds, 10, 1)
cfg = PredictConfig(
model=model,
......@@ -83,8 +98,6 @@ def eval_on_ILSVRC12(model, sess_init, data_dir):
batch_size = output.shape[0]
acc1.feed(w1, batch_size)
acc5.feed(w5, batch_size)
if idx == 10:
print("Top1 Error: {} after {} images".format(acc1.ratio, acc1.count))
print("Top1 Error: {}".format(acc1.ratio))
print("Top5 Error: {}".format(acc5.ratio))
......@@ -111,12 +124,6 @@ def run_test(model, sess_init, inputs):
print(f + ":")
print(list(zip(names, prob[ret])))
# save the metagraph
#saver = tf.train.Saver()
#saver.export_meta_graph('graph.meta', collection_list=
#[INPUT_VARS_KEY, tf.GraphKeys.VARIABLES, tf.GraphKeys.TRAINABLE_VARIABLES], as_text=True)
#saver.save(predict_func.session, 'alexnet.tfmodel', write_meta_graph=False)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--load', help='path to the saved model parameters', required=True)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment