Commit 12f2866e authored by ppwwyyxx's avatar ppwwyyxx

extension

parent db475954
......@@ -8,9 +8,15 @@ import numpy as np
__all__ = ['BatchData']
class BatchData(object):
def __init__(self, ds, batch_size):
def __init__(self, ds, batch_size, remainder=False):
"""
Args:
ds: a dataflow
remainder: whether to return the remaining data smaller than a batch_size
"""
self.ds = ds
self.batch_size = batch_size
self.remainder = remainder
def get_data(self):
holder = []
......@@ -19,6 +25,8 @@ class BatchData(object):
if len(holder) == self.batch_size:
yield BatchData.aggregate_batch(holder)
holder = []
if self.remainder and len(holder) > 0:
yield BatchData.aggregate_batch(holder)
@staticmethod
def aggregate_batch(data_holder):
......
......@@ -44,22 +44,16 @@ def get_model(input, label):
tf.scalar_summary(cost.op.name, cost)
return prob, cost
#def get_eval(prob, labels):
#"""
#Args:
#prob: bx10
#labels: b
#Returns:
#scalar float: accuracy
#"""
#correct = tf.nn.in_top_k(prob, labels, 1)
#nr_correct = tf.reduce_sum(tf.cast(correct, tf.int32))
#return tf.cast(nr_correct, tf.float32) / tf.cast(tf.size(labels), tf.float32)
def main():
dataset_train = Mnist('train')
dataset_test = Mnist('test')
extensions = [
OnehotClassificationValidation(
BatchData(dataset_test, batch_size, remainder=True),
prefix='test', period=2),
PeriodicSaver(LOG_DIR, period=2)
]
with tf.Graph().as_default():
input_var = tf.placeholder(tf.float32, shape=(None, PIXELS), name='input')
label_var = tf.placeholder(tf.int32, shape=(None,), name='label')
......@@ -69,17 +63,14 @@ def main():
optimizer = tf.train.AdagradOptimizer(0.01)
train_op = optimizer.minimize(cost)
validation_ext = OnehotClassificationValidation(
BatchData(dataset_test, batch_size), 'test')
validation_ext.init()
for ext in extensions:
ext.init()
summary_op = tf.merge_all_summaries()
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.initialize_all_variables())
summary_writer = tf.train.SummaryWriter(LOG_DIR,
graph_def=sess.graph_def)
summary_writer = tf.train.SummaryWriter(LOG_DIR, graph_def=sess.graph_def)
with sess.as_default():
for epoch in count(1):
......@@ -90,13 +81,11 @@ def main():
_, cost_value = sess.run([train_op, cost], feed_dict=feed)
print('Epoch %d: last batch cost = %.2f' % (epoch, cost_value))
summary_str = sess.run(summary_op, feed_dict=feed)
summary_str = summary_op.eval(feed_dict=feed)
summary_writer.add_summary(summary_str, epoch)
if epoch % 2 == 0:
saver.save(sess, LOG_DIR, global_step=epoch)
validation_ext.trigger()
for ext in extensions:
ext.trigger()
......
......@@ -4,17 +4,47 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf
import sys
import numpy as np
import os
from abc import abstractmethod
class OnehotClassificationValidation(object):
class Extension(object):
def init(self):
pass
@abstractmethod
def trigger(self):
pass
class PeriodicExtension(Extension):
def __init__(self, period):
self.__period = period
self.epoch_num = 0
def init(self):
pass
def trigger(self):
self.epoch_num += 1
if self.epoch_num % self.__period == 0:
self._trigger()
@abstractmethod
def _trigger(self):
pass
class OnehotClassificationValidation(PeriodicExtension):
"""
use with output: bxn probability
and label: (b,) vector
"""
def __init__(self, ds, prefix,
period=1,
input_op_name='input',
label_op_name='label',
output_op_name='output'):
super(OnehotClassificationValidation, self).__init__(period)
self.ds = ds
self.input_op_name = input_op_name
self.output_op_name = output_op_name
......@@ -30,15 +60,29 @@ class OnehotClassificationValidation(object):
correct = tf.equal(tf.cast(tf.argmax(self.output_var, 1), tf.int32),
self.label_var)
# TODO: add cost
self.accuracy_var = tf.reduce_mean(tf.cast(correct, tf.float32))
self.nr_correct_var = tf.reduce_sum(tf.cast(correct, tf.int32))
def trigger(self):
scores = []
def _trigger(self):
cnt = 0
cnt_correct = 0
for (img, label) in self.ds.get_data():
# TODO dropout?
feed = {self.input_var: img, self.label_var: label}
scores.append(
self.accuracy_var.eval(feed_dict=feed))
acc = np.array(scores, dtype='float32').mean()
cnt += img.shape[0]
cnt_correct += self.nr_correct_var.eval(feed_dict=feed)
# TODO write to summary?
print "Accuracy: ", acc
print "Accuracy at epoch {}: {}".format(
self.epoch_num, cnt_correct * 1.0 / cnt)
class PeriodicSaver(PeriodicExtension):
def __init__(self, log_dir, period=1):
super(PeriodicSaver, self).__init__(period)
self.path = os.path.join(log_dir, 'model')
def init(self):
self.saver = tf.train.Saver(max_to_keep=99999)
def _trigger(self):
self.saver.save(tf.get_default_session(), self.path,
global_step=self.epoch_num, latest_filename='latest')
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment