Commit db475954 authored by ppwwyyxx

validation extension

parent f1c514a8
@@ -8,7 +8,6 @@ import numpy as np
 __all__ = ['BatchData']

 class BatchData(object):
     def __init__(self, ds, batch_size):
         self.ds = ds
         self.batch_size = batch_size
@@ -30,4 +29,3 @@ class BatchData(object):
                 np.array([x[k] for x in data_holder],
                          dtype=data_holder[0][k].dtype))
         return tuple(result)
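
Note (not part of the commit): `BatchData` wraps any dataflow that exposes a `get_data()` generator and yields each datapoint component stacked into a numpy array per batch. A minimal sketch with a toy dataflow; the import path is an assumption, for illustration only:

```python
import numpy as np
from dataflow.batch import BatchData   # assumed import path, for illustration

class FakeDataFlow(object):
    """Toy dataflow with the same get_data() generator interface."""
    def get_data(self):
        for i in range(100):
            img = np.random.rand(28 * 28).astype('float32')
            label = np.int32(i % 10)
            yield (img, label)

ds = BatchData(FakeDataFlow(), 32)
for img_batch, label_batch in ds.get_data():
    # each component is stacked along a new leading batch axis
    print img_batch.shape, label_batch.shape   # (32, 784) (32,)
    break
```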
@@ -5,6 +5,9 @@
 import tensorflow as tf
 import numpy as np
+from itertools import count
 from layers import *
 from utils import *
 from dataflow.dataset import Mnist
@@ -31,7 +34,7 @@ def get_model(input, label):
     fc1 = FullyConnected('fc1', fc0, out_dim=200)
     fc1 = tf.nn.relu(fc1)
     fc2 = FullyConnected('lr', fc1, out_dim=10)
-    prob = tf.nn.softmax(fc2)
+    prob = tf.nn.softmax(fc2, name='output')
     logprob = tf.log(prob)
     y = one_hot(label, NUM_CLASS)
@@ -41,63 +44,59 @@ def get_model(input, label):
     tf.scalar_summary(cost.op.name, cost)
     return prob, cost

-def get_eval(prob, labels):
-    """
-    Args:
-        prob: bx10
-        labels: b
-    Returns:
-        scalar float: accuracy
-    """
-    correct = tf.nn.in_top_k(prob, labels, 1)
-    nr_correct = tf.reduce_sum(tf.cast(correct, tf.int32))
-    return tf.cast(nr_correct, tf.float32) / tf.cast(tf.size(labels), tf.float32)
+#def get_eval(prob, labels):
+    #"""
+    #Args:
+        #prob: bx10
+        #labels: b
+    #Returns:
+        #scalar float: accuracy
+    #"""
+    #correct = tf.nn.in_top_k(prob, labels, 1)
+    #nr_correct = tf.reduce_sum(tf.cast(correct, tf.int32))
+    #return tf.cast(nr_correct, tf.float32) / tf.cast(tf.size(labels), tf.float32)

 def main():
-    dataset_train = BatchData(Mnist('train'), batch_size)
-    dataset_test = BatchData(Mnist('test'), batch_size)
+    dataset_train = Mnist('train')
+    dataset_test = Mnist('test')
     with tf.Graph().as_default():
-        input_var = tf.placeholder(tf.float32, shape=(batch_size, PIXELS))
-        label_var = tf.placeholder(tf.int32, shape=(batch_size,))
+        input_var = tf.placeholder(tf.float32, shape=(None, PIXELS), name='input')
+        label_var = tf.placeholder(tf.int32, shape=(None,), name='label')
         prob, cost = get_model(input_var, label_var)

         optimizer = tf.train.AdagradOptimizer(0.01)
         train_op = optimizer.minimize(cost)
-        eval_op = get_eval(prob, label_var)
+        validation_ext = OnehotClassificationValidation(
+            BatchData(dataset_test, batch_size), 'test')
+        validation_ext.init()

         summary_op = tf.merge_all_summaries()
         saver = tf.train.Saver()
         sess = tf.Session()
-        init = tf.initialize_all_variables()
-        sess.run(init)
+        sess.run(tf.initialize_all_variables())

         summary_writer = tf.train.SummaryWriter(LOG_DIR,
                                                 graph_def=sess.graph_def)

-        epoch = 0
-        while True:
-            epoch += 1
-            for (img, label) in dataset_train.get_data():
-                feed = {input_var: img,
-                        label_var: label}
-                _, cost_value = sess.run([train_op, cost], feed_dict=feed)
-            print('Epoch %d: cost = %.2f' % (epoch, cost_value))
-            summary_str = sess.run(summary_op, feed_dict=feed)
-            summary_writer.add_summary(summary_str, epoch)
-            if epoch % 2 == 0:
-                saver.save(sess, LOG_DIR, global_step=epoch)
-                scores = []
-                for (img, label) in dataset_test.get_data():
-                    feed = {input_var: img, label_var: label}
-                    scores.append(sess.run(eval_op, feed_dict=feed))
-                print "Test Scores: {}".format(np.array(scores).mean())
+        with sess.as_default():
+            for epoch in count(1):
+                for (img, label) in BatchData(dataset_train, batch_size).get_data():
+                    feed = {input_var: img,
+                            label_var: label}
+                    _, cost_value = sess.run([train_op, cost], feed_dict=feed)
+                print('Epoch %d: last batch cost = %.2f' % (epoch, cost_value))
+                summary_str = sess.run(summary_op, feed_dict=feed)
+                summary_writer.add_summary(summary_str, epoch)
+                if epoch % 2 == 0:
+                    saver.save(sess, LOG_DIR, global_step=epoch)
+                validation_ext.trigger()
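
Note on the shape change (not part of the diff itself): the placeholders move from a fixed `(batch_size, ...)` shape to `(None, ...)` and gain names, so the same graph accepts both training batches and whatever batch size the validation dataflow produces, and the extension can look the tensors up by op name. A tiny sketch of the variable-batch behaviour:

```python
import numpy as np
import tensorflow as tf

# a None leading dimension lets one placeholder accept any batch size
x = tf.placeholder(tf.float32, shape=(None, 784), name='input')
total = tf.reduce_sum(x)

with tf.Session() as sess:
    print sess.run(total, feed_dict={x: np.ones((32, 784), 'float32')})    # 32 * 784
    print sess.run(total, feed_dict={x: np.ones((100, 784), 'float32')})   # 100 * 784
```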
@@ -3,16 +3,17 @@
 # File: __init__.py
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>

-import tensorflow as tf
+from pkgutil import walk_packages
+import os
+import os.path
+
+def global_import(name):
+    p = __import__(name, globals(), locals())
+    lst = p.__all__ if '__all__' in dir(p) else dir(p)
+    for k in lst:
+        globals()[k] = p.__dict__[k]

-def one_hot(y, num_labels):
-    batch_size = y.get_shape().as_list()[0]
-    assert type(batch_size) == int, type(batch_size)
-    y = tf.expand_dims(y, 1)
-    indices = tf.expand_dims(tf.range(0, batch_size), 1)
-    concated = tf.concat(1, [indices, y])
-    onehot_labels = tf.sparse_to_dense(
-        concated, tf.pack([batch_size, num_labels]), 1.0, 0.0)
-    onehot_labels.set_shape([batch_size, num_labels])
-    return tf.cast(onehot_labels, tf.float32)
+for _, module_name, _ in walk_packages(
+        [os.path.dirname(__file__)]):
+    if not module_name.startswith('_'):
+        global_import(module_name)
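
Side note (not from the commit): the new `__init__.py` walks its own package directory and copies every public symbol of each submodule into the package namespace, so wildcard imports of the package keep working as code is split into modules. A rough, self-contained illustration of what `global_import` does for a single module name, using a stdlib module purely as an example:

```python
def global_import(name):
    # import the module and copy its public names into this module's globals
    p = __import__(name, globals(), locals())
    lst = p.__all__ if '__all__' in dir(p) else dir(p)
    for k in lst:
        globals()[k] = p.__dict__[k]

global_import('collections')   # stdlib module, used only for illustration
print OrderedDict              # now reachable without the collections. prefix
```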
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: extension.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>

import tensorflow as tf
import numpy as np

class OnehotClassificationValidation(object):
    """
    use with output: bxn probability
    and label: (b,) vector
    """
    def __init__(self, ds, prefix,
                 input_op_name='input',
                 label_op_name='label',
                 output_op_name='output'):
        self.ds = ds
        self.input_op_name = input_op_name
        self.output_op_name = output_op_name
        self.label_op_name = label_op_name

    def init(self):
        self.graph = tf.get_default_graph()

        with tf.name_scope('validation'):
            self.input_var = self.graph.get_operation_by_name(self.input_op_name).outputs[0]
            self.label_var = self.graph.get_operation_by_name(self.label_op_name).outputs[0]
            self.output_var = self.graph.get_operation_by_name(self.output_op_name).outputs[0]
            correct = tf.equal(tf.cast(tf.argmax(self.output_var, 1), tf.int32),
                               self.label_var)
            # TODO: add cost
            self.accuracy_var = tf.reduce_mean(tf.cast(correct, tf.float32))

    def trigger(self):
        scores = []
        for (img, label) in self.ds.get_data():
            feed = {self.input_var: img, self.label_var: label}
            scores.append(
                self.accuracy_var.eval(feed_dict=feed))
        acc = np.array(scores, dtype='float32').mean()
        # TODO write to summary?
        print "Accuracy: ", acc
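
How the extension is meant to be wired (a sketch, not part of the commit): build a graph whose input, label, and output ops carry the default names, call `init()` once the graph exists, then call `trigger()` under a default session. The toy model and dataflow below are stand-ins for the MNIST example above:

```python
import tensorflow as tf
import numpy as np

class FakeTestSet(object):
    """Stand-in dataflow yielding two (image, label) batches."""
    def get_data(self):
        for _ in range(2):
            yield (np.random.rand(4, 784).astype('float32'),
                   np.random.randint(0, 10, size=(4,)).astype('int32'))

with tf.Graph().as_default():
    # the op names must match what the extension looks up
    x = tf.placeholder(tf.float32, shape=(None, 784), name='input')
    y = tf.placeholder(tf.int32, shape=(None,), name='label')
    w = tf.Variable(tf.zeros([784, 10]))
    prob = tf.nn.softmax(tf.matmul(x, w), name='output')

    ext = OnehotClassificationValidation(FakeTestSet(), 'test')
    ext.init()    # resolves the named tensors and builds the accuracy op

    sess = tf.Session()
    sess.run(tf.initialize_all_variables())
    with sess.as_default():   # trigger() calls Tensor.eval(), so a default
        ext.trigger()         # session must be active; prints mean accuracy
```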
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: utils.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>

import tensorflow as tf

__all__ = ['one_hot']

def one_hot(y, num_labels):
    batch_size = tf.size(y)
    y = tf.expand_dims(y, 1)
    indices = tf.expand_dims(tf.range(0, batch_size), 1)
    concated = tf.concat(1, [indices, y])
    onehot_labels = tf.sparse_to_dense(
        concated, tf.pack([batch_size, num_labels]), 1.0, 0.0)
    onehot_labels.set_shape([None, num_labels])
    return tf.cast(onehot_labels, tf.float32)
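
The rewritten `one_hot` derives the batch size from `tf.size(y)` rather than a static shape, so it works with the new `None`-shaped placeholders. A quick sanity sketch (not part of the commit), assuming `one_hot` is importable from `utils` as the `__all__` above suggests:

```python
import numpy as np
import tensorflow as tf
from utils import one_hot   # assumed import, matching __all__ above

labels = tf.placeholder(tf.int32, shape=(None,))
onehot = one_hot(labels, 10)

with tf.Session() as sess:
    out = sess.run(onehot, feed_dict={labels: np.array([3, 0, 7], dtype='int32')})
    print out.shape          # (3, 10)
    print out[0].argmax()    # 3
```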