Commit db475954 authored by ppwwyyxx

validation extension

parent f1c514a8
@@ -8,7 +8,6 @@ import numpy as np
__all__ = ['BatchData']
class BatchData(object):
    def __init__(self, ds, batch_size):
        self.ds = ds
        self.batch_size = batch_size
@@ -30,4 +29,3 @@ class BatchData(object):
                np.array([x[k] for x in data_holder],
                         dtype=data_holder[0][k].dtype))
        return tuple(result)
@@ -5,6 +5,9 @@
import tensorflow as tf
import numpy as np
+from itertools import count
from layers import *
+from utils import *
from dataflow.dataset import Mnist
@@ -31,7 +34,7 @@ def get_model(input, label):
    fc1 = FullyConnected('fc1', fc0, out_dim=200)
    fc1 = tf.nn.relu(fc1)
    fc2 = FullyConnected('lr', fc1, out_dim=10)
-    prob = tf.nn.softmax(fc2)
+    prob = tf.nn.softmax(fc2, name='output')
    logprob = tf.log(prob)
    y = one_hot(label, NUM_CLASS)
@@ -41,63 +44,59 @@ def get_model(input, label):
    tf.scalar_summary(cost.op.name, cost)
    return prob, cost
-def get_eval(prob, labels):
-    """
-    Args:
-        prob: bx10
-        labels: b
-    Returns:
-        scalar float: accuracy
-    """
-    correct = tf.nn.in_top_k(prob, labels, 1)
+#def get_eval(prob, labels):
+#"""
+#Args:
+#prob: bx10
+#labels: b
+#Returns:
+#scalar float: accuracy
+#"""
+#correct = tf.nn.in_top_k(prob, labels, 1)
-    nr_correct = tf.reduce_sum(tf.cast(correct, tf.int32))
-    return tf.cast(nr_correct, tf.float32) / tf.cast(tf.size(labels), tf.float32)
+#nr_correct = tf.reduce_sum(tf.cast(correct, tf.int32))
+#return tf.cast(nr_correct, tf.float32) / tf.cast(tf.size(labels), tf.float32)
def main():
-    dataset_train = BatchData(Mnist('train'), batch_size)
-    dataset_test = BatchData(Mnist('test'), batch_size)
+    dataset_train = Mnist('train')
+    dataset_test = Mnist('test')
    with tf.Graph().as_default():
-        input_var = tf.placeholder(tf.float32, shape=(batch_size, PIXELS))
-        label_var = tf.placeholder(tf.int32, shape=(batch_size,))
+        input_var = tf.placeholder(tf.float32, shape=(None, PIXELS), name='input')
+        label_var = tf.placeholder(tf.int32, shape=(None,), name='label')
        prob, cost = get_model(input_var, label_var)
        optimizer = tf.train.AdagradOptimizer(0.01)
        train_op = optimizer.minimize(cost)
-        eval_op = get_eval(prob, label_var)
+        validation_ext = OnehotClassificationValidation(
+            BatchData(dataset_test, batch_size), 'test')
+        validation_ext.init()
        summary_op = tf.merge_all_summaries()
        saver = tf.train.Saver()
        sess = tf.Session()
-        init = tf.initialize_all_variables()
-        sess.run(init)
+        sess.run(tf.initialize_all_variables())
        summary_writer = tf.train.SummaryWriter(LOG_DIR,
                                                graph_def=sess.graph_def)
-        epoch = 0
-        while True:
-            epoch += 1
-            for (img, label) in dataset_train.get_data():
-                feed = {input_var: img,
-                        label_var: label}
-                _, cost_value = sess.run([train_op, cost], feed_dict=feed)
+        with sess.as_default():
+            for epoch in count(1):
+                for (img, label) in BatchData(dataset_train, batch_size).get_data():
+                    feed = {input_var: img,
+                            label_var: label}
-            print('Epoch %d: cost = %.2f' % (epoch, cost_value))
+                    _, cost_value = sess.run([train_op, cost], feed_dict=feed)
-            summary_str = sess.run(summary_op, feed_dict=feed)
-            summary_writer.add_summary(summary_str, epoch)
+                print('Epoch %d: last batch cost = %.2f' % (epoch, cost_value))
-            if epoch % 2 == 0:
-                saver.save(sess, LOG_DIR, global_step=epoch)
+                summary_str = sess.run(summary_op, feed_dict=feed)
+                summary_writer.add_summary(summary_str, epoch)
-            scores = []
-            for (img, label) in dataset_test.get_data():
-                feed = {input_var: img, label_var: label}
-                scores.append(sess.run(eval_op, feed_dict=feed))
-            print "Test Scores: {}".format(np.array(scores).mean())
+                if epoch % 2 == 0:
+                    saver.save(sess, LOG_DIR, global_step=epoch)
+                    validation_ext.trigger()
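The placeholders above are now created with a None batch dimension and explicit names, which is what lets the validation extension below find them again and feed them batches of any size. A minimal sketch of that by-name lookup (not part of the commit; it assumes the TF 0.x-era API used here, with 784 standing in for PIXELS):

    import tensorflow as tf
    g = tf.Graph()
    with g.as_default():
        input_var = tf.placeholder(tf.float32, shape=(None, 784), name='input')
        # any component holding only the graph can recover the tensor by name
        fetched = g.get_operation_by_name('input').outputs[0]
        assert fetched is input_var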
@@ -3,16 +3,17 @@
# File: __init__.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
-import tensorflow as tf
+from pkgutil import walk_packages
+import os
+import os.path
+def global_import(name):
+    p = __import__(name, globals(), locals())
+    lst = p.__all__ if '__all__' in dir(p) else dir(p)
+    for k in lst:
+        globals()[k] = p.__dict__[k]
-def one_hot(y, num_labels):
-    batch_size = y.get_shape().as_list()[0]
-    assert type(batch_size) == int, type(batch_size)
-    y = tf.expand_dims(y, 1)
-    indices = tf.expand_dims(tf.range(0, batch_size), 1)
-    concated = tf.concat(1, [indices, y])
-    onehot_labels = tf.sparse_to_dense(
-        concated, tf.pack([batch_size, num_labels]), 1.0, 0.0)
-    onehot_labels.set_shape([batch_size, num_labels])
-    return tf.cast(onehot_labels, tf.float32)
+for _, module_name, _ in walk_packages(
+        [os.path.dirname(__file__)]):
+    if not module_name.startswith('_'):
+        global_import(module_name)
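With this change the package __init__.py no longer defines helpers itself; it walks its own directory and promotes each submodule's public names into the package namespace. A small illustration with a hypothetical layout (mypkg is not part of the commit):

    # mypkg/foo.py
    __all__ = ['bar']
    def bar():
        return 42

    # with the global_import loop above placed in mypkg/__init__.py, callers can write:
    from mypkg import bar   # bar was promoted from mypkg.foo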
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: extension.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf
import numpy as np

class OnehotClassificationValidation(object):
    """
    use with output: bxn probability
    and label: (b,) vector
    """
    def __init__(self, ds, prefix,
                 input_op_name='input',
                 label_op_name='label',
                 output_op_name='output'):
        self.ds = ds
        self.input_op_name = input_op_name
        self.output_op_name = output_op_name
        self.label_op_name = label_op_name

    def init(self):
        self.graph = tf.get_default_graph()
        with tf.name_scope('validation'):
            self.input_var = self.graph.get_operation_by_name(self.input_op_name).outputs[0]
            self.label_var = self.graph.get_operation_by_name(self.label_op_name).outputs[0]
            self.output_var = self.graph.get_operation_by_name(self.output_op_name).outputs[0]
            correct = tf.equal(tf.cast(tf.argmax(self.output_var, 1), tf.int32),
                               self.label_var)
            # TODO: add cost
            self.accuracy_var = tf.reduce_mean(tf.cast(correct, tf.float32))

    def trigger(self):
        scores = []
        for (img, label) in self.ds.get_data():
            feed = {self.input_var: img, self.label_var: label}
            scores.append(
                self.accuracy_var.eval(feed_dict=feed))
        acc = np.array(scores, dtype='float32').mean()
        # TODO write to summary?
        print "Accuracy: ", acc
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: utils.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf

__all__ = ['one_hot']

def one_hot(y, num_labels):
    batch_size = tf.size(y)
    y = tf.expand_dims(y, 1)
    indices = tf.expand_dims(tf.range(0, batch_size), 1)
    concated = tf.concat(1, [indices, y])
    onehot_labels = tf.sparse_to_dense(
        concated, tf.pack([batch_size, num_labels]), 1.0, 0.0)
    onehot_labels.set_shape([None, num_labels])
    return tf.cast(onehot_labels, tf.float32)
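For reference, a quick check of what one_hot produces (a sketch assuming the same TF 0.x-era ops used above, with one_hot defined as in utils.py; the labels and session setup are illustrative):

    import tensorflow as tf
    with tf.Session() as sess:
        labels = tf.constant([0, 2, 1], dtype=tf.int32)
        print sess.run(one_hot(labels, 3))
        # each row is 1.0 at its label's index:
        # [[ 1.  0.  0.]
        #  [ 0.  0.  1.]
        #  [ 0.  1.  0.]]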