Commit 341a5d43 authored by ppwwyyxx's avatar ppwwyyxx

inputs, outputs, and better ext

parent dea224ac
......@@ -24,18 +24,22 @@ NUM_CLASS = 10
batch_size = 128
LOG_DIR = 'train_log'
def get_model(input, label):
def get_model(inputs):
"""
Args:
input: bx28x28
label: bx1 integer
inputs: a list of input variable,
e.g.: [input, label] with:
input: bx28x28
label: bx1 integer
Returns:
(output, cost)
output: variable
(outputs, cost)
outputs: a list of output variable
cost: scalar variable
"""
# use this dropout variable! it will be set to 1 at test time
keep_prob = tf.placeholder(tf.float32, shape=tuple(), name='dropout_prob')
keep_prob = tf.placeholder(tf.float32, shape=tuple(), name=DROPOUT_PROB_OP_NAME)
input, label = inputs
input = tf.reshape(input, [-1, IMAGE_SIZE, IMAGE_SIZE, 1])
conv0 = Conv2D('conv0', input, out_channel=32, kernel_shape=5,
......@@ -62,17 +66,23 @@ def get_model(input, label):
cost = tf.nn.softmax_cross_entropy_with_logits(fc1, y)
cost = tf.reduce_mean(cost, name='cost')
tf.scalar_summary(cost.op.name, cost)
return prob, cost
# number of correctly classified samples
correct = tf.equal(
tf.cast(tf.argmax(prob, 1), tf.int32), label)
correct = tf.reduce_sum(tf.cast(correct, tf.int32), name='correct')
return [prob, correct], cost
def main():
dataset_train = BatchData(Mnist('train'), batch_size)
dataset_test = BatchData(Mnist('test'), batch_size, remainder=True)
extensions = [
OnehotClassificationValidation(
callbacks = [
SummaryWriter(LOG_DIR),
AccuracyValidation(
dataset_test,
prefix='test', period=2),
PeriodicSaver(LOG_DIR, period=2)
prefix='test', period=1),
TrainingAccuracy(),
PeriodicSaver(LOG_DIR, period=1)
]
optimizer = tf.train.AdamOptimizer(1e-4)
sess_config = tf.ConfigProto()
......@@ -80,38 +90,42 @@ def main():
with tf.Graph().as_default():
G = tf.get_default_graph()
input_var = tf.placeholder(tf.float32, shape=(None, IMAGE_SIZE, IMAGE_SIZE), name='input')
image_var = tf.placeholder(tf.float32, shape=(None, IMAGE_SIZE, IMAGE_SIZE), name='input')
label_var = tf.placeholder(tf.int32, shape=(None,), name='label')
prob, cost = get_model(input_var, label_var)
input_vars = [image_var, label_var]
train_op = optimizer.minimize(cost)
for v in input_vars:
G.add_to_collection(INPUT_VARS_KEY, v)
output_vars, cost_var = get_model(input_vars)
for v in output_vars:
G.add_to_collection(OUTPUT_VARS_KEY, v)
for ext in extensions:
ext.init()
train_op = optimizer.minimize(cost_var)
summary_op = tf.merge_all_summaries()
sess = tf.Session(config=sess_config)
sess.run(tf.initialize_all_variables())
summary_writer = tf.train.SummaryWriter(LOG_DIR, graph_def=sess.graph_def)
keep_prob = G.get_tensor_by_name('dropout_prob:0')
with sess.as_default():
for epoch in count(1):
running_cost = StatCounter()
for (img, label) in dataset_train.get_data():
feed = {input_var: img,
label_var: label,
keep_prob: 0.5}
for ext in callbacks:
ext.before_train()
_, cost_value = sess.run([train_op, cost], feed_dict=feed)
running_cost.feed(cost_value)
print('Epoch %d: avg cost = %.2f' % (epoch, running_cost.average))
summary_str = summary_op.eval(feed_dict=feed)
summary_writer.add_summary(summary_str, epoch)
for ext in extensions:
ext.trigger()
keep_prob_var = G.get_tensor_by_name(DROPOUT_PROB_VAR_NAME)
for epoch in count(1):
for dp in dataset_train.get_data():
feed = {keep_prob_var: 0.5}
feed.update(dict(zip(input_vars, dp)))
results = sess.run(
[train_op, cost_var] + output_vars, feed_dict=feed)
cost = results[1]
outputs = results[2:]
assert len(outputs) == len(output_vars)
for cb in callbacks:
cb.trigger_step(dp, outputs, cost)
for cb in callbacks:
cb.trigger_epoch()
summary_writer.close()
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: callback.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf
import sys
import numpy as np
import os
from abc import abstractmethod
from .stat import *
from .utils import *
from .naming import *
class Callback(object):
    """Base class for hooks invoked during training."""

    def before_train(self):
        """Cache the default session/graph, then run subclass setup."""
        self.sess = tf.get_default_session()
        self.graph = tf.get_default_graph()
        self._before_train()

    def _before_train(self):
        """Subclass hook, called once before training starts."""

    def trigger_step(self, dp, outputs, cost):
        """Called after every training step.

        Args:
            dp: the input fed into the graph for this step
            outputs: list of output values after running this dp
            cost: the cost value after running this dp
        """

    def trigger_epoch(self):
        """Called after every training epoch."""
class PeriodicCallback(Callback):
    """A callback whose action fires once every `period` epochs."""

    def __init__(self, period):
        # name-mangled: private to this base class
        self.__period = period
        self.epoch_num = 0

    def trigger_epoch(self):
        self.epoch_num += 1
        if not self.epoch_num % self.__period:
            self._trigger()

    @abstractmethod
    def _trigger(self):
        """Periodic action implemented by subclasses."""
class AccuracyValidation(PeriodicCallback):
    """
    Validate accuracy and average cost on a dataset every `period` epochs.

    Expects the graph to define:
        correct_var: integer tensor, number of correct samples in a batch
        cost_var: scalar cost tensor
    and `ds` to yield batched datapoints matching the INPUT_VARS_KEY
    collection order.
    """
    def __init__(self, ds, prefix,
                 period=1,
                 correct_var_name='correct:0',
                 cost_var_name='cost:0'):
        """
        Args:
            ds: batched dataset to validate on
            prefix: tag prefix for the emitted summaries
            period: run validation every `period` epochs
            correct_var_name: name of the per-batch correct-count tensor
            cost_var_name: name of the scalar cost tensor
        """
        super(AccuracyValidation, self).__init__(period)
        self.ds = ds
        self.prefix = prefix
        self.correct_var_name = correct_var_name
        self.cost_var_name = cost_var_name

    def get_tensor(self, name):
        """Look up a tensor by name in the cached graph."""
        return self.graph.get_tensor_by_name(name)

    def _before_train(self):
        self.input_vars = self.graph.get_collection(INPUT_VARS_KEY)
        self.dropout_var = self.get_tensor(DROPOUT_PROB_VAR_NAME)
        self.correct_var = self.get_tensor(self.correct_var_name)
        self.cost_var = self.get_tensor(self.cost_var_name)
        try:
            # the SummaryWriter callback registers the writer in this collection
            self.writer = tf.get_collection(SUMMARY_WRITER_COLLECTION_KEY)[0]
        except IndexError:
            # narrow except; single-arg print() is valid on both py2 and py3
            print("SummaryWriter should be the first callback!")
            raise

    def _trigger(self):
        cnt = 0
        correct_stat = Accuracy()
        cost_sum = 0
        for dp in self.ds.get_data():
            feed = {self.dropout_var: 1.0}  # disable dropout for validation
            feed.update(dict(zip(self.input_vars, dp)))
            batch_size = dp[0].shape[0]  # assume batched input
            cnt += batch_size
            correct, cost = self.sess.run(
                [self.correct_var, self.cost_var], feed_dict=feed)
            correct_stat.feed(correct, batch_size)
            # each batch might not have the same size in validation,
            # so weight the per-batch cost by its batch size
            cost_sum += cost * batch_size
        cost_avg = cost_sum / cnt
        self.writer.add_summary(
            create_summary('{} accuracy'.format(self.prefix),
                           correct_stat.accuracy),
            self.epoch_num)
        self.writer.add_summary(
            create_summary('{} cost'.format(self.prefix),
                           cost_avg),
            self.epoch_num)
        print("{} validation after epoch {}: acc={}, cost={}".format(
            self.prefix, self.epoch_num, correct_stat.accuracy, cost_avg))
class TrainingAccuracy(Callback):
    """
    Track running training accuracy and cost, reporting per-epoch averages.
    Requires the 'correct' variable to be among the model outputs.
    """
    def __init__(self, correct_var_name='correct:0'):
        """
        Args:
            correct_var_name: name of the output variable holding the number
                of correct samples in a batch
        """
        self.correct_var_name = correct_var_name
        self.epoch_num = 0

    def _before_train(self):
        try:
            # the SummaryWriter callback registers the writer in this collection
            self.writer = tf.get_collection(SUMMARY_WRITER_COLLECTION_KEY)[0]
        except IndexError:
            # narrow except; single-arg print() is valid on both py2 and py3
            print("SummaryWriter should be the first callback!")
            raise
        output_vars = self.graph.get_collection(OUTPUT_VARS_KEY)
        for idx, var in enumerate(output_vars):
            if var.name == self.correct_var_name:
                # remember where the 'correct' value sits in the outputs list
                self.correct_output_idx = idx
                break
        else:
            raise RuntimeError(
                "'correct' variable must be in the model outputs to use TrainingAccuracy")
        self.running_cost = StatCounter()
        self.running_acc = Accuracy()

    def trigger_step(self, inputs, outputs, cost):
        self.running_cost.feed(cost)
        self.running_acc.feed(
            outputs[self.correct_output_idx],
            inputs[0].shape[0])  # assume batched input

    def trigger_epoch(self):
        self.epoch_num += 1
        print('Training average in Epoch {}: cost={}, acc={}'.format(
            self.epoch_num, self.running_cost.average,
            self.running_acc.accuracy))
        self.writer.add_summary(
            create_summary('training average accuracy',
                           self.running_acc.accuracy),
            self.epoch_num)
        self.writer.add_summary(
            create_summary('training average cost',
                           self.running_cost.average),
            self.epoch_num)
        # reset the running statistics for the next epoch
        self.running_cost.reset()
        self.running_acc.reset()
class PeriodicSaver(PeriodicCallback):
    """Save a model checkpoint under `log_dir` every `period` epochs."""

    def __init__(self, log_dir, period=1):
        super(PeriodicSaver, self).__init__(period)
        self.path = os.path.join(log_dir, 'model')

    def _before_train(self):
        # keep effectively unlimited checkpoint history
        self.saver = tf.train.Saver(max_to_keep=99999)

    def _trigger(self):
        self.saver.save(
            tf.get_default_session(), self.path,
            global_step=self.epoch_num, latest_filename='latest')
class SummaryWriter(Callback):
    """Owns the tf.train.SummaryWriter: creates it before training, registers
    it in a graph collection for other callbacks, and evaluates the merged
    summary op once per epoch."""

    def __init__(self, log_dir):
        """
        Args:
            log_dir: directory the event files are written to
        """
        self.log_dir = log_dir
        self.epoch_num = 0

    def _before_train(self):
        sess = tf.get_default_session()
        graph = tf.get_default_graph()
        self.writer = tf.train.SummaryWriter(
            self.log_dir, graph_def=sess.graph_def)
        # expose the writer so other callbacks (e.g. validation) can reuse it
        graph.add_to_collection(SUMMARY_WRITER_COLLECTION_KEY, self.writer)
        self.summary_op = tf.merge_all_summaries()

    def trigger_step(self, dp, outputs, cost):
        # remember the most recent datapoint; used to feed the epoch summary
        self.last_dp = dp

    def trigger_epoch(self):
        # check if there is any summary
        if self.summary_op is None:
            return
        # NOTE(review): `last_dp` is whatever trigger_step received; the base
        # class documents it as a feed dict, but the training loop appears to
        # pass the raw datapoint list — confirm it is a valid feed_dict here.
        # Also assumes at least one trigger_step ran before the first epoch.
        summary_str = self.summary_op.eval(self.last_dp)
        self.epoch_num += 1
        self.writer.add_summary(summary_str, self.epoch_num)
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: extension.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf
import sys
import numpy as np
import os
from abc import abstractmethod
class Extension(object):
    """Base class for per-epoch training extensions."""

    def init(self):
        """Optional setup before training; default does nothing."""

    @abstractmethod
    def trigger(self):
        """Invoked after every training epoch."""
class PeriodicExtension(Extension):
    """An extension whose action runs once every `period` epochs."""

    def __init__(self, period):
        # name-mangled: private to this base class
        self.__period = period
        self.epoch_num = 0

    def init(self):
        pass

    def trigger(self):
        self.epoch_num += 1
        if not self.epoch_num % self.__period:
            self._trigger()

    @abstractmethod
    def _trigger(self):
        """Periodic work implemented by the subclass."""
class OnehotClassificationValidation(PeriodicExtension):
    """
    Periodically validate one-hot classification accuracy and cost.

    Use with:
        output: bxn probability
        label: (b,) integer vector
    """
    def __init__(self, ds, prefix,
                 period=1,
                 input_var_name='input:0',
                 label_var_name='label:0',
                 output_var_name='output:0'):
        """
        Args:
            ds: batched dataset to validate on
            prefix: currently unused tag prefix (kept for interface parity)
            period: validate every `period` epochs
        """
        super(OnehotClassificationValidation, self).__init__(period)
        self.ds = ds
        self.input_var_name = input_var_name
        self.output_var_name = output_var_name
        self.label_var_name = label_var_name

    def init(self):
        self.graph = tf.get_default_graph()
        with tf.name_scope('validation'):
            self.input_var = self.graph.get_tensor_by_name(self.input_var_name)
            self.label_var = self.graph.get_tensor_by_name(self.label_var_name)
            self.output_var = self.graph.get_tensor_by_name(self.output_var_name)
            self.dropout_var = self.graph.get_tensor_by_name('dropout_prob:0')
            # count correct predictions: argmax over class probabilities
            correct = tf.equal(
                tf.cast(tf.argmax(self.output_var, 1), tf.int32),
                self.label_var)
            self.nr_correct_var = tf.reduce_sum(tf.cast(correct, tf.int32))
            self.cost_var = self.graph.get_tensor_by_name('cost:0')

    def _trigger(self):
        # Accuracy lives in .stat, which this module does not import at
        # module level — import locally so this method actually runs
        from .stat import Accuracy
        cnt = 0
        correct_stat = Accuracy()
        sess = tf.get_default_session()
        cost_sum = 0
        for (img, label) in self.ds.get_data():
            feed = {self.input_var: img,
                    self.label_var: label,
                    self.dropout_var: 1.0}  # disable dropout for validation
            batch_size = img.shape[0]
            cnt += batch_size
            correct, cost = sess.run([self.nr_correct_var, self.cost_var],
                                     feed_dict=feed)
            # was feed(correct, cnt): feeding the *cumulative* count inflated
            # the denominator; each batch contributes only its own size
            correct_stat.feed(correct, batch_size)
            # was cost * cnt: weight each batch's cost by its own size,
            # since validation batches may differ in size
            cost_sum += cost * batch_size
        cost_avg = cost_sum / cnt
        # TODO write to summary?
        print("After epoch {}: acc={}, cost={}".format(
            self.epoch_num, correct_stat.accuracy, cost_avg))
class PeriodicSaver(PeriodicExtension):
    """Checkpoint the session under `log_dir` every `period` epochs."""

    def __init__(self, log_dir, period=1):
        super(PeriodicSaver, self).__init__(period)
        self.path = os.path.join(log_dir, 'model')

    def init(self):
        # keep effectively unlimited checkpoint history
        self.saver = tf.train.Saver(max_to_keep=99999)

    def _trigger(self):
        self.saver.save(
            tf.get_default_session(), self.path,
            global_step=self.epoch_num, latest_filename='latest')
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: naming.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Canonical names for graph tensors/ops and collection keys shared
# between the model definition and the callbacks.
DROPOUT_PROB_OP_NAME = 'dropout_prob'
DROPOUT_PROB_VAR_NAME = 'dropout_prob:0'
SUMMARY_WRITER_COLLECTION_KEY = 'summary_writer'
MERGE_SUMMARY_OP_NAME = 'MergeSummary/MergeSummary:0'
INPUT_VARS_KEY = 'INPUT_VARIABLES'
OUTPUT_VARS_KEY = 'OUTPUT_VARIABLES'

# Export every ALL-CAPS name defined above. Snapshot the names with list()
# (py2 returns a list, py3 a live view) and avoid leaking a temporary
# variable into the module namespace.
__all__ = [x for x in list(locals()) if x.upper() == x]
......@@ -4,14 +4,18 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import numpy as np
__all__ = ['StatCounter', 'Accuracy']
class StatCounter(object):
def __init__(self):
self.values = []
self.reset()
def feed(self, v):
self.values.append(v)
def reset(self):
self.values = []
@property
def average(self):
return np.mean(self.values)
......@@ -22,6 +26,9 @@ class StatCounter(object):
class Accuracy(object):
def __init__(self):
self.reset()
def reset(self):
self.tot = 0
self.corr = 0
......
......@@ -5,3 +5,17 @@
import tensorflow as tf
__all__ = ['create_summary']
def create_summary(name, v):
    """
    Create a tf.Summary protobuf holding a single scalar value.

    Args:
        name: tag for the scalar (a string)
        v: a scalar value (anything float() accepts)
    Returns:
        a tf.Summary protobuf
    """
    # TODO support image or histogram
    try:
        string_types = basestring  # py2: accept both str and unicode
    except NameError:
        string_types = str  # py3: basestring no longer exists
    assert isinstance(name, string_types), type(name)
    v = float(v)
    s = tf.Summary()
    s.value.add(tag=name, simple_value=v)
    return s
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment