Commit 341a5d43 authored by ppwwyyxx

inputs, outputs, and better ext

parent dea224ac
@@ -24,18 +24,22 @@ NUM_CLASS = 10
 batch_size = 128
 LOG_DIR = 'train_log'
 
-def get_model(input, label):
+def get_model(inputs):
     """
     Args:
-        input: bx28x28
-        label: bx1 integer
+        inputs: a list of input variables,
+            e.g.: [input, label] with:
+                input: bx28x28
+                label: bx1 integer
     Returns:
-        (output, cost)
-        output: variable
+        (outputs, cost)
+        outputs: a list of output variables
         cost: scalar variable
     """
     # use this dropout variable! it will be set to 1 at test time
-    keep_prob = tf.placeholder(tf.float32, shape=tuple(), name='dropout_prob')
+    keep_prob = tf.placeholder(tf.float32, shape=tuple(), name=DROPOUT_PROB_OP_NAME)
+    input, label = inputs
 
     input = tf.reshape(input, [-1, IMAGE_SIZE, IMAGE_SIZE, 1])
     conv0 = Conv2D('conv0', input, out_channel=32, kernel_shape=5,
@@ -62,17 +66,23 @@ def get_model(input, label):
     cost = tf.nn.softmax_cross_entropy_with_logits(fc1, y)
     cost = tf.reduce_mean(cost, name='cost')
 
-    tf.scalar_summary(cost.op.name, cost)
-    return prob, cost
+    # number of correctly classified samples
+    correct = tf.equal(
+        tf.cast(tf.argmax(prob, 1), tf.int32), label)
+    correct = tf.reduce_sum(tf.cast(correct, tf.int32), name='correct')
+    return [prob, correct], cost
 
 def main():
     dataset_train = BatchData(Mnist('train'), batch_size)
     dataset_test = BatchData(Mnist('test'), batch_size, remainder=True)
-    extensions = [
-        OnehotClassificationValidation(
+    callbacks = [
+        SummaryWriter(LOG_DIR),
+        AccuracyValidation(
             dataset_test,
-            prefix='test', period=2),
-        PeriodicSaver(LOG_DIR, period=2)
+            prefix='test', period=1),
+        TrainingAccuracy(),
+        PeriodicSaver(LOG_DIR, period=1)
     ]
     optimizer = tf.train.AdamOptimizer(1e-4)
     sess_config = tf.ConfigProto()
@@ -80,38 +90,42 @@ def main():
     with tf.Graph().as_default():
         G = tf.get_default_graph()
 
-        input_var = tf.placeholder(tf.float32, shape=(None, IMAGE_SIZE, IMAGE_SIZE), name='input')
+        image_var = tf.placeholder(tf.float32, shape=(None, IMAGE_SIZE, IMAGE_SIZE), name='input')
         label_var = tf.placeholder(tf.int32, shape=(None,), name='label')
-        prob, cost = get_model(input_var, label_var)
-        train_op = optimizer.minimize(cost)
+        input_vars = [image_var, label_var]
+        for v in input_vars:
+            G.add_to_collection(INPUT_VARS_KEY, v)
+        output_vars, cost_var = get_model(input_vars)
+        for v in output_vars:
+            G.add_to_collection(OUTPUT_VARS_KEY, v)
 
-        for ext in extensions:
-            ext.init()
-        summary_op = tf.merge_all_summaries()
+        train_op = optimizer.minimize(cost_var)
 
         sess = tf.Session(config=sess_config)
         sess.run(tf.initialize_all_variables())
-        summary_writer = tf.train.SummaryWriter(LOG_DIR, graph_def=sess.graph_def)
-        keep_prob = G.get_tensor_by_name('dropout_prob:0')
 
         with sess.as_default():
-            for epoch in count(1):
-                running_cost = StatCounter()
-                for (img, label) in dataset_train.get_data():
-                    feed = {input_var: img,
-                            label_var: label,
-                            keep_prob: 0.5}
-                    _, cost_value = sess.run([train_op, cost], feed_dict=feed)
-                    running_cost.feed(cost_value)
-                print('Epoch %d: avg cost = %.2f' % (epoch, running_cost.average))
-                summary_str = summary_op.eval(feed_dict=feed)
-                summary_writer.add_summary(summary_str, epoch)
-                for ext in extensions:
-                    ext.trigger()
-            summary_writer.close()
+            for ext in callbacks:
+                ext.before_train()
+            keep_prob_var = G.get_tensor_by_name(DROPOUT_PROB_VAR_NAME)
+            for epoch in count(1):
+                for dp in dataset_train.get_data():
+                    feed = {keep_prob_var: 0.5}
+                    feed.update(dict(zip(input_vars, dp)))
+                    results = sess.run(
+                        [train_op, cost_var] + output_vars, feed_dict=feed)
+                    cost = results[1]
+                    outputs = results[2:]
+                    assert len(outputs) == len(output_vars)
+                    for cb in callbacks:
+                        cb.trigger_step(dp, outputs, cost)
+                for cb in callbacks:
+                    cb.trigger_epoch()
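The new inner loop builds the feed positionally, by zipping the registered input variables with each datapoint. A minimal sketch of what one step's feed amounts to, assuming a datapoint dp = [img_batch, label_batch] that matches input_vars = [image_var, label_var]:

    feed = {keep_prob_var: 0.5}
    feed.update(dict(zip(input_vars, dp)))
    # equivalent to:
    # feed = {keep_prob_var: 0.5, image_var: img_batch, label_var: label_batch}

This keeps the training loop model-agnostic: any model that registers its placeholders under INPUT_VARS_KEY can be driven by the same loop.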
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: callback.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf
import sys
import numpy as np
import os
from abc import abstractmethod
from .stat import *
from .utils import *
from .naming import *
class Callback(object):
    def before_train(self):
        self.graph = tf.get_default_graph()
        self.sess = tf.get_default_session()
        self._before_train()

    def _before_train(self):
        """
        Called before training starts
        """

    # triggered after every step
    def trigger_step(self, dp, outputs, cost):
        """
        Args:
            dp: the datapoint (list of input values) fed into the graph for this step
            outputs: list of output values after running this dp
            cost: the cost value after running this dp
        """
        pass

    # triggered after every epoch
    def trigger_epoch(self):
        pass
class PeriodicCallback(Callback):
    def __init__(self, period):
        self.__period = period
        self.epoch_num = 0

    def trigger_epoch(self):
        self.epoch_num += 1
        if self.epoch_num % self.__period == 0:
            self._trigger()

    @abstractmethod
    def _trigger(self):
        pass
class AccuracyValidation(PeriodicCallback):
    """
    Validate accuracy and cost on a batched dataset `ds`, using existing
    variables in the graph:
        correct_var: integer, the number of correctly classified samples in this batch
        cost_var: scalar, the cost of this batch
    """
    def __init__(self, ds, prefix,
                 period=1,
                 correct_var_name='correct:0',
                 cost_var_name='cost:0'):
        super(AccuracyValidation, self).__init__(period)
        self.ds = ds
        self.prefix = prefix
        self.correct_var_name = correct_var_name
        self.cost_var_name = cost_var_name

    def get_tensor(self, name):
        return self.graph.get_tensor_by_name(name)

    def _before_train(self):
        self.input_vars = self.graph.get_collection(INPUT_VARS_KEY)
        self.dropout_var = self.get_tensor(DROPOUT_PROB_VAR_NAME)
        self.correct_var = self.get_tensor(self.correct_var_name)
        self.cost_var = self.get_tensor(self.cost_var_name)
        try:
            self.writer = tf.get_collection(SUMMARY_WRITER_COLLECTION_KEY)[0]
        except Exception:
            print "SummaryWriter must be the first callback!"
            raise

    def _trigger(self):
        cnt = 0
        correct_stat = Accuracy()
        cost_sum = 0
        for dp in self.ds.get_data():
            feed = {self.dropout_var: 1.0}
            feed.update(dict(zip(self.input_vars, dp)))

            batch_size = dp[0].shape[0]   # assume batched input
            cnt += batch_size
            correct, cost = self.sess.run(
                [self.correct_var, self.cost_var], feed_dict=feed)
            correct_stat.feed(correct, batch_size)
            # each batch might not have the same size in validation
            cost_sum += cost * batch_size
        cost_avg = cost_sum / cnt

        self.writer.add_summary(
            create_summary('{} accuracy'.format(self.prefix),
                           correct_stat.accuracy),
            self.epoch_num)
        self.writer.add_summary(
            create_summary('{} cost'.format(self.prefix),
                           cost_avg),
            self.epoch_num)
        print "{} validation after epoch {}: acc={}, cost={}".format(
            self.prefix, self.epoch_num, correct_stat.accuracy, cost_avg)
class TrainingAccuracy(Callback):
    def __init__(self, correct_var_name='correct:0'):
        """
        correct_var: the variable counting the number of correctly
            classified samples in this batch
        """
        self.correct_var_name = correct_var_name
        self.epoch_num = 0

    def _before_train(self):
        try:
            self.writer = tf.get_collection(SUMMARY_WRITER_COLLECTION_KEY)[0]
        except Exception:
            print "SummaryWriter must be the first callback!"
            raise
        output_vars = self.graph.get_collection(OUTPUT_VARS_KEY)
        for idx, var in enumerate(output_vars):
            if var.name == self.correct_var_name:
                self.correct_output_idx = idx
                break
        else:
            raise RuntimeError(
                "'correct' variable must be in the model outputs to use TrainingAccuracy")
        self.running_cost = StatCounter()
        self.running_acc = Accuracy()

    def trigger_step(self, dp, outputs, cost):
        self.running_cost.feed(cost)
        self.running_acc.feed(
            outputs[self.correct_output_idx],
            dp[0].shape[0])   # assume batched input

    def trigger_epoch(self):
        self.epoch_num += 1
        print('Training average in epoch {}: cost={}, acc={}'.format(
            self.epoch_num, self.running_cost.average,
            self.running_acc.accuracy))
        self.writer.add_summary(
            create_summary('training average accuracy', self.running_acc.accuracy),
            self.epoch_num)
        self.writer.add_summary(
            create_summary('training average cost', self.running_cost.average),
            self.epoch_num)
        self.running_cost.reset()
        self.running_acc.reset()
class PeriodicSaver(PeriodicCallback):
    def __init__(self, log_dir, period=1):
        super(PeriodicSaver, self).__init__(period)
        self.path = os.path.join(log_dir, 'model')

    def _before_train(self):
        self.saver = tf.train.Saver(max_to_keep=99999)

    def _trigger(self):
        self.saver.save(tf.get_default_session(), self.path,
                        global_step=self.epoch_num, latest_filename='latest')
class SummaryWriter(Callback):
    def __init__(self, log_dir):
        self.log_dir = log_dir
        self.epoch_num = 0

    def _before_train(self):
        self.writer = tf.train.SummaryWriter(
            self.log_dir, graph_def=self.sess.graph_def)
        self.graph.add_to_collection(SUMMARY_WRITER_COLLECTION_KEY, self.writer)
        self.input_vars = self.graph.get_collection(INPUT_VARS_KEY)
        self.dropout_var = self.graph.get_tensor_by_name(DROPOUT_PROB_VAR_NAME)
        self.summary_op = tf.merge_all_summaries()

    def trigger_step(self, dp, outputs, cost):
        # remember the last datapoint, to evaluate summaries with at epoch end
        self.last_dp = dp

    def trigger_epoch(self):
        # check if there is any summary to write
        if self.summary_op is None:
            return
        # last_dp is a list of values, so build a proper feed_dict from it
        # (dropout is set to 1 at evaluation time, as get_model() instructs)
        feed = {self.dropout_var: 1.0}
        feed.update(dict(zip(self.input_vars, self.last_dp)))
        summary_str = self.summary_op.eval(feed_dict=feed)
        self.epoch_num += 1
        self.writer.add_summary(summary_str, self.epoch_num)
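A new callback only needs to subclass Callback and override the hooks it cares about. A minimal sketch of a hypothetical user-defined callback against this interface (StepCounter is illustrative, not part of the commit):

    class StepCounter(Callback):
        # hypothetical example: count the number of training steps per epoch
        def _before_train(self):
            self.step = 0

        def trigger_step(self, dp, outputs, cost):
            self.step += 1

        def trigger_epoch(self):
            print "steps this epoch: {}".format(self.step)
            self.step = 0

It would be appended to the callbacks list in main() like any other entry.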
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: extension.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf
import sys
import numpy as np
import os
from abc import abstractmethod

class Extension(object):
    def init(self):
        pass

    @abstractmethod
    def trigger(self):
        pass

class PeriodicExtension(Extension):
    def __init__(self, period):
        self.__period = period
        self.epoch_num = 0

    def init(self):
        pass

    def trigger(self):
        self.epoch_num += 1
        if self.epoch_num % self.__period == 0:
            self._trigger()

    @abstractmethod
    def _trigger(self):
        pass

class OnehotClassificationValidation(PeriodicExtension):
    """
    use with output: bxn probability
    and label: (b,) vector
    """
    def __init__(self, ds, prefix,
                 period=1,
                 input_var_name='input:0',
                 label_var_name='label:0',
                 output_var_name='output:0'):
        super(OnehotClassificationValidation, self).__init__(period)
        self.ds = ds
        self.input_var_name = input_var_name
        self.output_var_name = output_var_name
        self.label_var_name = label_var_name

    def init(self):
        self.graph = tf.get_default_graph()
        with tf.name_scope('validation'):
            self.input_var = self.graph.get_tensor_by_name(self.input_var_name)
            self.label_var = self.graph.get_tensor_by_name(self.label_var_name)
            self.output_var = self.graph.get_tensor_by_name(self.output_var_name)
            self.dropout_var = self.graph.get_tensor_by_name('dropout_prob:0')
            correct = tf.equal(tf.cast(tf.argmax(self.output_var, 1), tf.int32),
                               self.label_var)
            self.nr_correct_var = tf.reduce_sum(tf.cast(correct, tf.int32))
            self.cost_var = self.graph.get_tensor_by_name('cost:0')

    def _trigger(self):
        cnt = 0
        correct_stat = Accuracy()
        sess = tf.get_default_session()
        cost_sum = 0
        for (img, label) in self.ds.get_data():
            feed = {self.input_var: img,
                    self.label_var: label,
                    self.dropout_var: 1.0}
            cnt += img.shape[0]
            correct, cost = sess.run([self.nr_correct_var, self.cost_var],
                                     feed_dict=feed)
            correct_stat.feed(correct, cnt)
            cost_sum += cost * cnt
        cost_sum /= cnt
        # TODO write to summary?
        print "After epoch {}: acc={}, cost={}".format(
            self.epoch_num, correct_stat.accuracy, cost_sum)

class PeriodicSaver(PeriodicExtension):
    def __init__(self, log_dir, period=1):
        super(PeriodicSaver, self).__init__(period)
        self.path = os.path.join(log_dir, 'model')

    def init(self):
        self.saver = tf.train.Saver(max_to_keep=99999)

    def _trigger(self):
        self.saver.save(tf.get_default_session(), self.path,
                        global_step=self.epoch_num, latest_filename='latest')
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: naming.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
DROPOUT_PROB_OP_NAME = 'dropout_prob'
DROPOUT_PROB_VAR_NAME = 'dropout_prob:0'
SUMMARY_WRITER_COLLECTION_KEY = 'summary_writer'
MERGE_SUMMARY_OP_NAME = 'MergeSummary/MergeSummary:0'
INPUT_VARS_KEY = 'INPUT_VARIABLES'
OUTPUT_VARS_KEY = 'OUTPUT_VARIABLES'
# export all upper case variables
all_local_names = locals().keys()
__all__ = [x for x in all_local_names if x.upper() == x]
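The *_VAR_NAME constants rely on TensorFlow's tensor naming convention: output i of an op named 'name' is the tensor 'name:i'. A minimal sketch of why the placeholder created with DROPOUT_PROB_OP_NAME can later be fetched by DROPOUT_PROB_VAR_NAME:

    keep_prob = tf.placeholder(tf.float32, shape=tuple(), name=DROPOUT_PROB_OP_NAME)
    assert keep_prob.name == DROPOUT_PROB_VAR_NAME   # 'dropout_prob' + ':0'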
@@ -4,14 +4,18 @@
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>
 import numpy as np
 
+__all__ = ['StatCounter', 'Accuracy']
+
 class StatCounter(object):
     def __init__(self):
-        self.values = []
+        self.reset()
 
     def feed(self, v):
         self.values.append(v)
 
+    def reset(self):
+        self.values = []
+
     @property
     def average(self):
         return np.mean(self.values)
@@ -22,6 +26,9 @@ class StatCounter(object):
 class Accuracy(object):
     def __init__(self):
+        self.reset()
+
+    def reset(self):
         self.tot = 0
         self.corr = 0
@@ -5,3 +5,17 @@
 import tensorflow as tf
 
+__all__ = ['create_summary']
+
+def create_summary(name, v):
+    # TODO support image or histogram
+    """
+    Args:
+        v: a scalar value
+    """
+    assert isinstance(name, basestring), type(name)
+    v = float(v)
+    s = tf.Summary()
+    s.value.add(tag=name, simple_value=v)
+    return s
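Because create_summary builds the Summary protobuf directly, values computed outside the graph (such as validation accuracy) can be logged without defining a summary op. A hedged usage sketch, assuming an existing tf.train.SummaryWriter:

    writer = tf.train.SummaryWriter('train_log')
    writer.add_summary(create_summary('test accuracy', 0.97), 5)   # tag/value, step
    writer.close()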