Commit 12f2866e authored by ppwwyyxx

extension

parent db475954
@@ -8,9 +8,15 @@ import numpy as np
 __all__ = ['BatchData']
 
 class BatchData(object):
-    def __init__(self, ds, batch_size):
+    def __init__(self, ds, batch_size, remainder=False):
+        """
+        Args:
+            ds: a dataflow
+            remainder: whether to return the remaining data smaller than a batch_size
+        """
         self.ds = ds
         self.batch_size = batch_size
+        self.remainder = remainder
 
     def get_data(self):
         holder = []
@@ -19,6 +25,8 @@ class BatchData(object):
             if len(holder) == self.batch_size:
                 yield BatchData.aggregate_batch(holder)
                 holder = []
+        if self.remainder and len(holder) > 0:
+            yield BatchData.aggregate_batch(holder)
 
     @staticmethod
     def aggregate_batch(data_holder):
...
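For reference, a minimal sketch of what the new remainder flag does, under the assumption that BatchData.aggregate_batch stacks each datapoint component into a numpy array; the CounterData dataflow below is a hypothetical stand-in, not part of this commit:

import numpy as np

class CounterData(object):
    # toy dataflow for illustration: yields 10 datapoints, each with one component
    def get_data(self):
        for i in range(10):
            yield [np.array([i])]

ds = BatchData(CounterData(), 4, remainder=True)
for batch in ds.get_data():
    # expected batch sizes: 4, 4, then 2 -- the final partial batch is
    # only produced because remainder=True
    print(batch[0].shape[0])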
@@ -44,22 +44,16 @@ def get_model(input, label):
     tf.scalar_summary(cost.op.name, cost)
     return prob, cost
 
-#def get_eval(prob, labels):
-    #"""
-    #Args:
-        #prob: bx10
-        #labels: b
-    #Returns:
-        #scalar float: accuracy
-    #"""
-    #correct = tf.nn.in_top_k(prob, labels, 1)
-    #nr_correct = tf.reduce_sum(tf.cast(correct, tf.int32))
-    #return tf.cast(nr_correct, tf.float32) / tf.cast(tf.size(labels), tf.float32)
 
 def main():
     dataset_train = Mnist('train')
     dataset_test = Mnist('test')
+    extensions = [
+        OnehotClassificationValidation(
+            BatchData(dataset_test, batch_size, remainder=True),
+            prefix='test', period=2),
+        PeriodicSaver(LOG_DIR, period=2)
+    ]
 
     with tf.Graph().as_default():
         input_var = tf.placeholder(tf.float32, shape=(None, PIXELS), name='input')
         label_var = tf.placeholder(tf.int32, shape=(None,), name='label')
@@ -69,17 +63,14 @@ def main():
         optimizer = tf.train.AdagradOptimizer(0.01)
         train_op = optimizer.minimize(cost)
 
-        validation_ext = OnehotClassificationValidation(
-            BatchData(dataset_test, batch_size), 'test')
-        validation_ext.init()
+        for ext in extensions:
+            ext.init()
 
         summary_op = tf.merge_all_summaries()
-        saver = tf.train.Saver()
 
         sess = tf.Session()
         sess.run(tf.initialize_all_variables())
-        summary_writer = tf.train.SummaryWriter(LOG_DIR,
-            graph_def=sess.graph_def)
+        summary_writer = tf.train.SummaryWriter(LOG_DIR, graph_def=sess.graph_def)
 
         with sess.as_default():
             for epoch in count(1):
@@ -90,13 +81,11 @@ def main():
                     _, cost_value = sess.run([train_op, cost], feed_dict=feed)
                 print('Epoch %d: last batch cost = %.2f' % (epoch, cost_value))
 
-                summary_str = summary_op.eval(feed_dict=feed)
+                summary_str = sess.run(summary_op, feed_dict=feed)
                 summary_writer.add_summary(summary_str, epoch)
 
-                if epoch % 2 == 0:
-                    saver.save(sess, LOG_DIR, global_step=epoch)
-                validation_ext.trigger()
+                for ext in extensions:
+                    ext.trigger()
...
@@ -4,17 +4,47 @@
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>
 
 import tensorflow as tf
+import sys
 import numpy as np
+import os
+from abc import abstractmethod
 
-class OnehotClassificationValidation(object):
+class Extension(object):
+    def init(self):
+        pass
+
+    @abstractmethod
+    def trigger(self):
+        pass
+
+class PeriodicExtension(Extension):
+    def __init__(self, period):
+        self.__period = period
+        self.epoch_num = 0
+
+    def init(self):
+        pass
+
+    def trigger(self):
+        self.epoch_num += 1
+        if self.epoch_num % self.__period == 0:
+            self._trigger()
+
+    @abstractmethod
+    def _trigger(self):
+        pass
+
+class OnehotClassificationValidation(PeriodicExtension):
     """
     use with output: bxn probability
     and label: (b,) vector
     """
     def __init__(self, ds, prefix,
+                 period=1,
                  input_op_name='input',
                  label_op_name='label',
                  output_op_name='output'):
+        super(OnehotClassificationValidation, self).__init__(period)
         self.ds = ds
         self.input_op_name = input_op_name
         self.output_op_name = output_op_name
@@ -30,15 +60,29 @@ class OnehotClassificationValidation(object):
         correct = tf.equal(tf.cast(tf.argmax(self.output_var, 1), tf.int32),
                            self.label_var)
         # TODO: add cost
-        self.accuracy_var = tf.reduce_mean(tf.cast(correct, tf.float32))
+        self.nr_correct_var = tf.reduce_sum(tf.cast(correct, tf.int32))
 
-    def trigger(self):
-        scores = []
+    def _trigger(self):
+        cnt = 0
+        cnt_correct = 0
         for (img, label) in self.ds.get_data():
+            # TODO dropout?
             feed = {self.input_var: img, self.label_var: label}
-            scores.append(
-                self.accuracy_var.eval(feed_dict=feed))
-        acc = np.array(scores, dtype='float32').mean()
+            cnt += img.shape[0]
+            cnt_correct += self.nr_correct_var.eval(feed_dict=feed)
         # TODO write to summary?
-        print "Accuracy: ", acc
+        print "Accuracy at epoch {}: {}".format(
+            self.epoch_num, cnt_correct * 1.0 / cnt)
+
+class PeriodicSaver(PeriodicExtension):
+    def __init__(self, log_dir, period=1):
+        super(PeriodicSaver, self).__init__(period)
+        self.path = os.path.join(log_dir, 'model')
+
+    def init(self):
+        self.saver = tf.train.Saver(max_to_keep=99999)
+
+    def _trigger(self):
+        self.saver.save(tf.get_default_session(), self.path,
+            global_step=self.epoch_num, latest_filename='latest')