Commit 4267ea4d authored by ppwwyyxx

add dropout

parent 5fab58f3
@@ -34,26 +34,30 @@ def get_model(input, label):
     output: variable
     cost: scalar variable
     """
+    keep_prob = tf.placeholder(tf.float32, name='dropout_prob')
     input = tf.reshape(input, [-1, IMAGE_SIZE, IMAGE_SIZE, 1])
-    conv0 = Conv2D('conv0', input, out_channel=20, kernel_shape=5,
+    conv0 = Conv2D('conv0', input, out_channel=32, kernel_shape=5,
                    padding='valid')
+    conv0 = tf.nn.relu(conv0)
     pool0 = tf.nn.max_pool(conv0, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1],
                            padding='SAME')
     conv1 = Conv2D('conv1', pool0, out_channel=40, kernel_shape=3,
                    padding='valid')
+    conv1 = tf.nn.relu(conv1)
     pool1 = tf.nn.max_pool(conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1],
                            padding='SAME')
-    conv2 = Conv2D('conv2', pool0, out_channel=40, kernel_shape=3,
-                   padding='valid')
-    feature = batch_flatten(conv2)
-    fc0 = FullyConnected('fc0', feature, 512)
+    feature = batch_flatten(pool1)
+    fc0 = FullyConnected('fc0', feature, 1024)
     fc0 = tf.nn.relu(fc0)
+    fc0 = tf.nn.dropout(fc0, keep_prob)
     fc1 = FullyConnected('lr', fc0, out_dim=10)
     prob = tf.nn.softmax(fc1, name='output')
-    logprob = tf.log(prob)
+    logprob = logSoftmax(fc1)
     y = one_hot(label, NUM_CLASS)
     cost = tf.reduce_sum(-y * logprob, 1)
     cost = tf.reduce_mean(cost, name='cost')
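A side note on the dropout wiring above: tf.nn.dropout zeroes each activation with probability 1 - keep_prob and scales the survivors by 1/keep_prob, so driving keep_prob from a placeholder lets the same graph run with dropout active (0.5) during training and disabled (1.0) at evaluation. A minimal sketch of the pattern, assuming the TF 0.x-era API this repo targets (shapes and values are illustrative, not from the commit):

    # Sketch of the placeholder-driven dropout pattern (TF 0.x API).
    import tensorflow as tf

    keep_prob = tf.placeholder(tf.float32, name='dropout_prob')
    x = tf.ones([4, 8])
    # tf.nn.dropout zeroes entries with probability (1 - keep_prob) and
    # multiplies survivors by 1/keep_prob, keeping the expectation fixed.
    y = tf.nn.dropout(x, keep_prob)

    with tf.Session() as sess:
        train_y = sess.run(y, feed_dict={keep_prob: 0.5})  # dropout active
        test_y = sess.run(y, feed_dict={keep_prob: 1.0})   # identity: y == x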
@@ -77,7 +81,7 @@ def main():
     prob, cost = get_model(input_var, label_var)

-    optimizer = tf.train.AdagradOptimizer(0.01)
+    optimizer = tf.train.AdamOptimizer(1e-4)
     train_op = optimizer.minimize(cost)

     for ext in extensions:
@@ -90,11 +94,14 @@ def main():
     sess.run(tf.initialize_all_variables())
     summary_writer = tf.train.SummaryWriter(LOG_DIR, graph_def=sess.graph_def)
+    g = tf.get_default_graph()
+    keep_prob = g.get_tensor_by_name('dropout_prob:0')
     with sess.as_default():
         for epoch in count(1):
             for (img, label) in BatchData(dataset_train, batch_size).get_data():
                 feed = {input_var: img,
-                        label_var: label}
+                        label_var: label,
+                        keep_prob: 0.5}
                 _, cost_value = sess.run([train_op, cost], feed_dict=feed)
...
@@ -41,21 +41,22 @@ class OnehotClassificationValidation(PeriodicExtension):
    """
    def __init__(self, ds, prefix,
                 period=1,
-                input_op_name='input',
-                label_op_name='label',
-                output_op_name='output'):
+                input_var_name='input:0',
+                label_var_name='label:0',
+                output_var_name='output:0'):
        super(OnehotClassificationValidation, self).__init__(period)
        self.ds = ds
-       self.input_op_name = input_op_name
-       self.output_op_name = output_op_name
-       self.label_op_name = label_op_name
+       self.input_var_name = input_var_name
+       self.output_var_name = output_var_name
+       self.label_var_name = label_var_name

    def init(self):
        self.graph = tf.get_default_graph()
        with tf.name_scope('validation'):
-           self.input_var = self.graph.get_operation_by_name(self.input_op_name).outputs[0]
-           self.label_var = self.graph.get_operation_by_name(self.label_op_name).outputs[0]
-           self.output_var = self.graph.get_operation_by_name(self.output_op_name).outputs[0]
+           self.input_var = self.graph.get_tensor_by_name(self.input_var_name)
+           self.label_var = self.graph.get_tensor_by_name(self.label_var_name)
+           self.output_var = self.graph.get_tensor_by_name(self.output_var_name)
+           self.dropout_var = self.graph.get_tensor_by_name('dropout_prob:0')
            correct = tf.equal(tf.cast(tf.argmax(self.output_var, 1), tf.int32),
                               self.label_var)
@@ -66,8 +67,9 @@ class OnehotClassificationValidation(PeriodicExtension):
        cnt = 0
        cnt_correct = 0
        for (img, label) in self.ds.get_data():
-           # TODO dropout?
-           feed = {self.input_var: img, self.label_var: label}
+           feed = {self.input_var: img,
+                   self.label_var: label,
+                   self.dropout_var: 1.0}
            cnt += img.shape[0]
            cnt_correct += self.nr_correct_var.eval(feed_dict=feed)
            # TODO write to summary?
...
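The rename from *_op_name to *_var_name in this hunk hinges on TensorFlow's naming scheme: an operation named 'input' produces output tensors named 'input:0', 'input:1', and so on; Graph.get_tensor_by_name expects the tensor name with the ':N' suffix, while get_operation_by_name takes the bare op name. A short sketch of the distinction (assuming a default graph that actually contains an op named 'input', as the model above does):

    # Op name vs. tensor name in the graph lookup APIs.
    g = tf.get_default_graph()
    input_op = g.get_operation_by_name('input')   # the Operation object
    input_t = g.get_tensor_by_name('input:0')     # its first output Tensor
    assert input_op.outputs[0] is input_t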
@@ -5,7 +5,7 @@
 import tensorflow as tf
 import numpy as np

-__all__ = ['one_hot', 'batch_flatten']
+__all__ = ['one_hot', 'batch_flatten', 'logSoftmax']

 def one_hot(y, num_labels):
     batch_size = tf.size(y)
@@ -20,3 +20,10 @@ def one_hot(y, num_labels):
 def batch_flatten(x):
     total_dim = np.prod(x.get_shape()[1:].as_list())
     return tf.reshape(x, [-1, total_dim])
+
+
+def logSoftmax(x):
+    z = x - tf.reduce_max(x, 1, keep_dims=True)
+    logprob = z - tf.log(tf.reduce_sum(tf.exp(z), 1, keep_dims=True))
+    return logprob
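The new logSoftmax is the numerically stable replacement for tf.log(tf.nn.softmax(...)): it computes log softmax(x)_i = (x_i - m) - log sum_j exp(x_j - m) with m = max_j x_j, so no exp can overflow and the log never sees an exact zero. A quick check of the difference, using the logSoftmax just added above (the logit values are chosen to force the failure and are purely illustrative):

    # Large logits saturate the naive form; the shifted form stays exact.
    import tensorflow as tf

    x = tf.constant([[1000.0, 0.0]])
    naive = tf.log(tf.nn.softmax(x))   # exp(1000.) overflows: nan / -inf
    stable = logSoftmax(x)             # [[0.0, -1000.0]]

    with tf.Session() as sess:
        print(sess.run(naive))    # e.g. [[  nan  -inf]]
        print(sess.run(stable))   # [[    0. -1000.]]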