Commit 9908baa2 authored by Yuxin Wu's avatar Yuxin Wu

new model

parent 13e2a3e3
...@@ -19,55 +19,70 @@ class Model(ModelDesc): ...@@ -19,55 +19,70 @@ class Model(ModelDesc):
def _build_graph(self, input_vars, _): def _build_graph(self, input_vars, _):
x, label = input_vars x, label = input_vars
x = x / 255.0 x = x / 256.0
def tanh_round_bit(x, name=None): def quantize(x, name=None):
x = tf.tanh(x) * 0.5 # quantize to 2 bit
return (((x + 0.5) * 3.0 + 0.5) // 1) / 3.0 - 0.5 return ((x * 3.0 + 0.5) // 1) / 3.0
x = Conv2D('conv1_1', x, 96, 12, nl=tanh_round_bit, stride=4, padding='VALID') bn = lambda x, name: BatchNorm('bn', x, False, epsilon=1e-4)
bnc = lambda x, name: tf.clip_by_value(bn(x, None), 0.0, 1.0, name=name)
bnl = lambda x, name: BatchNorm('bn', x, False, epsilon=1e-4) def conv_split(name, x, channel, shape):
with argscope([Conv2D, FullyConnected], nl=bnl): inputs = tf.split(3, 2, x)
x = Conv2D('conv2_1', x, 256, 5, padding='SAME') x0 = Conv2D(name + 'a', inputs[0], channel/2, shape)
x = tf.pad(x, [[0,0], [1,1], [1,1], [0,0]], "SYMMETRIC") x1 = Conv2D(name + 'b', inputs[1], channel/2, shape)
x = MaxPooling('pool1', x, 3, stride=2, padding='VALID') return tf.concat(3, [x0, x1])
x = tanh_round_bit(x)
with argscope([Conv2D, FullyConnected], nl=bnc):
x = Conv2D('conv1_1', x, 96, 12, stride=4, padding='VALID')
x = quantize(x)
x = conv_split('conv2_1', x, 256, 5)
x = tf.pad(x, [[0,0], [1,1], [1,1], [0,0]])
x = MaxPooling('pool1', x, 3, 2)
x = quantize(x)
x = Conv2D('conv3_1', x, 384, 3) x = Conv2D('conv3_1', x, 384, 3)
x = tf.pad(x, [[0,0], [1,1], [1,1], [0,0]], "SYMMETRIC") x = tf.pad(x, [[0,0], [1,1], [1,1], [0,0]])
x = MaxPooling('pool2', x, 3, stride=2, padding='VALID') x = MaxPooling('pool2', x, 3, 2)
x = tanh_round_bit(x) x = quantize(x)
x = Conv2D('conv4_1', x, 384, 3) x = conv_split('conv4_1', x, 384, 3)
x = tanh_round_bit(x) x = quantize(x)
x = Conv2D('conv5_1', x, 256, 3) x = conv_split('conv5_1', x, 256, 3)
x = MaxPooling('pool3', x, 3, stride=2, padding='VALID') x = MaxPooling('pool3', x, 3, 2)
x = tanh_round_bit(x) x = quantize(x)
x = tf.transpose(x, perm=[0,3,1,2]) x = tf.transpose(x, perm=[0,3,1,2])
x = tf.nn.dropout(x, keep_prob=1.)
x = FullyConnected('fc0', x, out_dim=4096) x = FullyConnected('fc0', x, out_dim=4096)
x = tanh_round_bit(x) x = quantize(x)
x = tf.nn.dropout(x, keep_prob=1.)
x = FullyConnected('fc1', x, out_dim=4096) x = FullyConnected('fc1', x, out_dim=4096)
x = tf.tanh(x) * 0.5 logits = FullyConnected('fct', x, out_dim=1000, nl=bn)
logits = FullyConnected('fct', x, out_dim=1000)
prob = tf.nn.softmax(logits, name='prob') prob = tf.nn.softmax(logits, name='prob')
nr_wrong = tf.reduce_sum(prediction_incorrect(logits, label), name='wrong-top1') nr_wrong = tf.reduce_sum(prediction_incorrect(logits, label), name='wrong-top1')
nr_wrong = tf.reduce_sum(prediction_incorrect(logits, label, 5), name='wrong-top5') nr_wrong = tf.reduce_sum(prediction_incorrect(logits, label, 5), name='wrong-top5')
def eval_on_ILSVRC12(model, sess_init, data_dir): def eval_on_ILSVRC12(model, sess_init, data_dir):
ds = dataset.ILSVRC12(data_dir, 'val', shuffle=False) ds = dataset.ILSVRC12(data_dir, 'val', shuffle=False)
def resize_func(im):
h, w = im.shape[:2]
scale = 256.0 / min(h, w)
desSize = map(int, (max(224, min(w, scale * w)),\
max(224, min(h, scale * h))))
im = cv2.resize(im, tuple(desSize), interpolation=cv2.INTER_CUBIC)
return im
transformers = [ transformers = [
imgaug.Resize((256, 256)), imgaug.AugmentWithFunc(resize_func),
imgaug.CenterCrop((224, 224)), imgaug.CenterCrop((224, 224)),
] ]
ds = AugmentImageComponent(ds, transformers) ds = AugmentImageComponent(ds, transformers)
ds = BatchData(ds, 128, remainder=True) ds = BatchData(ds, 128, remainder=True)
ds = PrefetchData(ds, 10, nr_proc=1) ds = PrefetchData(ds, 10, 1)
cfg = PredictConfig( cfg = PredictConfig(
model=model, model=model,
...@@ -83,8 +98,6 @@ def eval_on_ILSVRC12(model, sess_init, data_dir): ...@@ -83,8 +98,6 @@ def eval_on_ILSVRC12(model, sess_init, data_dir):
batch_size = output.shape[0] batch_size = output.shape[0]
acc1.feed(w1, batch_size) acc1.feed(w1, batch_size)
acc5.feed(w5, batch_size) acc5.feed(w5, batch_size)
if idx == 10:
print("Top1 Error: {} after {} images".format(acc1.ratio, acc1.count))
print("Top1 Error: {}".format(acc1.ratio)) print("Top1 Error: {}".format(acc1.ratio))
print("Top5 Error: {}".format(acc5.ratio)) print("Top5 Error: {}".format(acc5.ratio))
...@@ -111,12 +124,6 @@ def run_test(model, sess_init, inputs): ...@@ -111,12 +124,6 @@ def run_test(model, sess_init, inputs):
print(f + ":") print(f + ":")
print(list(zip(names, prob[ret]))) print(list(zip(names, prob[ret])))
# save the metagraph
#saver = tf.train.Saver()
#saver.export_meta_graph('graph.meta', collection_list=
#[INPUT_VARS_KEY, tf.GraphKeys.VARIABLES, tf.GraphKeys.TRAINABLE_VARIABLES], as_text=True)
#saver.save(predict_func.session, 'alexnet.tfmodel', write_meta_graph=False)
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--load', help='path to the saved model parameters', required=True) parser.add_argument('--load', help='path to the saved model parameters', required=True)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment