Commit 9aa2b11c authored by ppwwyyxx

running-average summary

parent 031dd698
@@ -59,17 +59,24 @@ def get_model(inputs):
     y = one_hot(label, 10)
     cost = tf.nn.softmax_cross_entropy_with_logits(fc1, y)
-    cost = tf.reduce_mean(cost)
+    cost = tf.reduce_mean(cost, name='cross_entropy_cost')
+    tf.add_to_collection(COST_VARS_KEY, cost)
     # compute the number of correctly classified samples, for ValidationAccuracy to use
     correct = tf.equal(
         tf.cast(tf.argmax(prob, 1), tf.int32), label)
-    correct = tf.reduce_sum(tf.cast(correct, tf.int32), name='correct')
+    correct = tf.cast(correct, tf.float32)
+    nr_correct = tf.reduce_sum(correct, name='correct')
+    tf.add_to_collection(SUMMARY_VARS_KEY,
+                         tf.reduce_mean(correct, name='training_accuracy'))
     # weight decay on all W of fc layers
-    wd_cost = 1e-4 * regularize_cost('fc.*/W', tf.nn.l2_loss)
+    wd_cost = tf.mul(1e-4,
+                     regularize_cost('fc.*/W', tf.nn.l2_loss),
+                     name='regularize_loss')
+    tf.add_to_collection(COST_VARS_KEY, wd_cost)
-    return [prob, correct], tf.add(cost, wd_cost, name='cost')
+    return [prob, nr_correct], tf.add_n(tf.get_collection(COST_VARS_KEY), name='cost')
 def main():
     BATCH_SIZE = 128
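The first hunk replaces the hand-written `cost + wd_cost` sum with a graph collection: each loss term registers itself under COST_VARS_KEY, and the final cost is `tf.add_n` over whatever was registered, so adding a new loss term no longer requires touching the return statement. A minimal sketch of the pattern, assuming the TF 0.x API used throughout this commit and hypothetical stand-in tensors for the real model:

import tensorflow as tf

COST_VARS_KEY = 'COST_VARIABLES'   # same key as in the constants module

# stand-ins for the real model tensors (hypothetical values)
logits = tf.constant([[2.0, 0.5], [0.1, 1.5]])
labels = tf.constant([[1.0, 0.0], [0.0, 1.0]])
W = tf.Variable(tf.ones([4, 2]))

# each loss term registers itself in the collection
ce_cost = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(logits, labels),
    name='cross_entropy_cost')
tf.add_to_collection(COST_VARS_KEY, ce_cost)

wd_cost = tf.mul(1e-4, tf.nn.l2_loss(W), name='regularize_loss')
tf.add_to_collection(COST_VARS_KEY, wd_cost)

# the trainer sums whatever was registered, without a hard-coded list
total_cost = tf.add_n(tf.get_collection(COST_VARS_KEY), name='cost')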
@@ -12,11 +12,11 @@ def regularize_cost(regex, func):
     G = tf.get_default_graph()
     params = G.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
-    cost = 0
+    costs = []
     for p in params:
         name = p.name
         if re.search(regex, name):
             print("Weight decay for {}".format(name))
-            cost += func(p)
-    return cost
+            costs.append(func(p))
+    return tf.add_n(costs)
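For reference, the helper after this change reads as below, with the imports it needs (a sketch; the module-level `tf.get_collection` shorthand seen elsewhere in this commit stands in for the explicit graph handle). `tf.add_n` builds a single n-ary add op and raises a ValueError if the regex matched nothing, whereas the old `cost += func(p)` silently built a chain of binary adds starting from the Python int 0:

import re
import tensorflow as tf

def regularize_cost(regex, func):
    # apply `func` (e.g. tf.nn.l2_loss) to every trainable variable
    # whose name matches `regex`; return the summed cost tensor
    params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
    costs = []
    for p in params:
        if re.search(regex, p.name):
            print("Weight decay for {}".format(p.name))
            costs.append(func(p))
    return tf.add_n(costs)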
@@ -40,7 +40,21 @@ def start_train(config):
     for v in output_vars:
         G.add_to_collection(OUTPUT_VARS_KEY, v)
-    train_op = optimizer.minimize(cost_var)
+    # add some summary ops to the graph
+    cost_avg = tf.train.ExponentialMovingAverage(0.9, name='avg')
+    # TODO global step
+    cost_vars = tf.get_collection(COST_VARS_KEY)
+    summary_vars = tf.get_collection(SUMMARY_VARS_KEY)
+    vars_to_summary = cost_vars + [cost_var] + summary_vars
+    cost_avg_maintain_op = cost_avg.apply(vars_to_summary)
+    for c in vars_to_summary:
+        #tf.scalar_summary(c.op.name + ' (raw)', c)
+        tf.scalar_summary(c.op.name, cost_avg.average(c))
+    # maintain average in each step
+    with tf.control_dependencies([cost_avg_maintain_op]):
+        grads = optimizer.compute_gradients(cost_var)
+        train_op = optimizer.apply_gradients(grads)  # TODO global_step
     sess = tf.Session(config=sess_config)
     # start training
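This hunk is the heart of the commit: instead of summarizing raw per-step values, the trainer maintains an ExponentialMovingAverage over every cost and summary variable, writes the smoothed value to the scalar summary, and hooks the EMA update onto the training step through a control dependency, so no separate session run is needed. A condensed, self-contained sketch of the same pattern (TF 0.x names; the toy `cost_var` is an assumption):

import tensorflow as tf

x = tf.Variable(1.0)
cost_var = tf.square(x, name='cost')   # toy cost standing in for the model's
optimizer = tf.train.GradientDescentOptimizer(0.1)

# decay 0.9: each update keeps 90% of the old average
avg = tf.train.ExponentialMovingAverage(0.9, name='avg')
maintain_op = avg.apply([cost_var])    # op that refreshes the shadow average

# summarize the smoothed value rather than the noisy per-step cost
tf.scalar_summary(cost_var.op.name, avg.average(cost_var))

# ops created inside this block run only after maintain_op,
# so every train_op step also refreshes the average
with tf.control_dependencies([maintain_op]):
    grads = optimizer.compute_gradients(cost_var)
    train_op = optimizer.apply_gradients(grads)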
@@ -30,13 +30,11 @@ class Callback(object):
            outputs: list of output values after running this dp
            cost: the cost value after running this dp
        """
-        pass
    def trigger_epoch(self):
        """
        Callback to be triggered after every epoch (full iteration of input dataset)
        """
-        pass
class PeriodicCallback(Callback):
    def __init__(self, period):
@@ -73,12 +71,12 @@ class SummaryWriter(Callback):
     def _before_train(self):
         self.writer = tf.train.SummaryWriter(
             self.log_dir, graph_def=self.sess.graph_def)
-        self.graph.add_to_collection(SUMMARY_WRITER_COLLECTION_KEY, self.writer)
+        tf.add_to_collection(SUMMARY_WRITER_COLLECTION_KEY, self.writer)
         # create some summary
         if self.hist_regex is not None:
             import re
-            params = self.graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
+            params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
             for p in params:
                 name = p.name
                 if re.search(self.hist_regex, name):
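Both replacements here (and the matching one in ValidationAccuracy below) drop the explicit `self.graph` handle in favor of the module-level `tf.get_collection` / `tf.add_to_collection`, which operate on the default graph. The two forms are interchangeable as long as the callbacks run with that graph as the default, which a quick check illustrates:

import tensorflow as tf

g = tf.get_default_graph()
tf.add_to_collection('SOME_KEY', tf.constant(1.0))
# the module-level call reads from the default graph
assert tf.get_collection('SOME_KEY') == g.get_collection('SOME_KEY')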
@@ -11,6 +11,8 @@ MERGE_SUMMARY_OP_NAME = 'MergeSummary/MergeSummary:0'
 INPUT_VARS_KEY = 'INPUT_VARIABLES'
 OUTPUT_VARS_KEY = 'OUTPUT_VARIABLES'
 COST_VARS_KEY = 'COST_VARIABLES'
+SUMMARY_VARS_KEY = 'SUMMARY_VARIABLES'   # extra variables to summarize
 # export all upper case variables
 all_local_names = locals().keys()
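The constants file is truncated right after `all_local_names = locals().keys()`; the comment suggests it exports every upper-case name from the module. A hypothetical completion of that idiom (not shown in the hunk, included only for illustration):

all_local_names = locals().keys()
__all__ = [x for x in all_local_names if x.isupper()]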
@@ -31,7 +31,7 @@ class ValidationAccuracy(PeriodicCallback):
         return self.graph.get_tensor_by_name(name)
     def _before_train(self):
-        self.input_vars = self.graph.get_collection(INPUT_VARS_KEY)
+        self.input_vars = tf.get_collection(INPUT_VARS_KEY)
         self.dropout_var = self.get_tensor(DROPOUT_PROB_VAR_NAME)
         self.correct_var = self.get_tensor(self.correct_var_name)
         self.cost_var = self.get_tensor(self.cost_var_name)
@@ -56,59 +56,60 @@ class ValidationAccuracy(PeriodicCallback):
         cost_avg = cost_sum / cnt
         self.writer.add_summary(
-            create_summary('{} accuracy'.format(self.prefix),
+            create_summary('{}_accuracy'.format(self.prefix),
                            correct_stat.accuracy),
             self.epoch_num)
         self.writer.add_summary(
-            create_summary('{} cost'.format(self.prefix),
+            create_summary('{}_cost'.format(self.prefix),
                            cost_avg),
             self.epoch_num)
         print "{} validation after epoch {}: acc={}, cost={}".format(
             self.prefix, self.epoch_num, correct_stat.accuracy, cost_avg)
-class TrainingAccuracy(Callback):
-    """
-    Record the accuracy and cost during each step of training.
-    The result is a running average, thus not directly comparable with ValidationAccuracy
-    """
-    def __init__(self, batch_size, correct_var_name='correct:0'):
-        """
-        correct_var: number of correct samples in this batch
-        """
-        self.correct_var_name = correct_var_name
-        self.batch_size = batch_size
-        self.epoch_num = 0
+# use SUMMARY_VARIABLES instead
+#class TrainingAccuracy(Callback):
+    #"""
+    #Record the accuracy and cost during each step of training.
+    #The result is a running average, thus not directly comparable with ValidationAccuracy
+    #"""
+    #def __init__(self, batch_size, correct_var_name='correct:0'):
+        #"""
+        #correct_var: number of correct samples in this batch
+        #"""
+        #self.correct_var_name = correct_var_name
+        #self.batch_size = batch_size
+        #self.epoch_num = 0
-    def _before_train(self):
-        self.writer = tf.get_collection(SUMMARY_WRITER_COLLECTION_KEY)[0]
-        output_vars = self.graph.get_collection(OUTPUT_VARS_KEY)
-        for idx, var in enumerate(output_vars):
-            if var.name == self.correct_var_name:
-                self.correct_output_idx = idx
-                break
-        else:
-            raise RuntimeError(
-                "'correct' variable must be in the model outputs to use TrainingAccuracy")
-        self.running_cost = StatCounter()
-        self.running_acc = Accuracy()
+    #def _before_train(self):
+        #self.writer = tf.get_collection(SUMMARY_WRITER_COLLECTION_KEY)[0]
+        #output_vars = tf.get_collection(OUTPUT_VARS_KEY)
+        #for idx, var in enumerate(output_vars):
+            #if var.name == self.correct_var_name:
+                #self.correct_output_idx = idx
+                #break
+        #else:
+            #raise RuntimeError(
+                #"'correct' variable must be in the model outputs to use TrainingAccuracy")
+        #self.running_cost = StatCounter()
+        #self.running_acc = Accuracy()
-    def trigger_step(self, inputs, outputs, cost):
-        self.running_cost.feed(cost)
-        self.running_acc.feed(
-            outputs[self.correct_output_idx],
-            self.batch_size)    # assume batch input
+    #def trigger_step(self, inputs, outputs, cost):
+        #self.running_cost.feed(cost)
+        #self.running_acc.feed(
+            #outputs[self.correct_output_idx],
+            #self.batch_size)    # assume batch input
-    def trigger_epoch(self):
-        self.epoch_num += 1
-        print('Training average in Epoch {}: cost={}, acc={}'.format(
-            self.epoch_num, self.running_cost.average,
-            self.running_acc.accuracy))
-        self.writer.add_summary(
-            create_summary('training average accuracy', self.running_acc.accuracy),
-            self.epoch_num)
-        self.writer.add_summary(
-            create_summary('training average cost', self.running_cost.average),
-            self.epoch_num)
+    #def trigger_epoch(self):
+        #self.epoch_num += 1
+        #print('Training average in Epoch {}: cost={}, acc={}'.format(
+            #self.epoch_num, self.running_cost.average,
+            #self.running_acc.accuracy))
+        #self.writer.add_summary(
+            #create_summary('training average accuracy', self.running_acc.accuracy),
+            #self.epoch_num)
+        #self.writer.add_summary(
+            #create_summary('training average cost', self.running_cost.average),
+            #self.epoch_num)
-        self.running_cost.reset()
-        self.running_acc.reset()
+        #self.running_cost.reset()
+        #self.running_acc.reset()
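With TrainingAccuracy retired, per-step training accuracy reaches the summaries through the same moving-average machinery: the model registers a `training_accuracy` scalar under SUMMARY_VARS_KEY (first hunk) and start_train folds it into `vars_to_summary` (third hunk). A sketch of the two halves together, with hypothetical stand-in tensors:

import tensorflow as tf

SUMMARY_VARS_KEY = 'SUMMARY_VARIABLES'

# model side: register the per-batch accuracy (stand-in tensors)
label = tf.constant([0, 1], dtype=tf.int32)
prob = tf.constant([[0.9, 0.1], [0.2, 0.8]])
correct = tf.cast(
    tf.equal(tf.cast(tf.argmax(prob, 1), tf.int32), label), tf.float32)
tf.add_to_collection(SUMMARY_VARS_KEY,
                     tf.reduce_mean(correct, name='training_accuracy'))

# trainer side: every registered scalar gets an EMA-smoothed summary
summary_vars = tf.get_collection(SUMMARY_VARS_KEY)
avg = tf.train.ExponentialMovingAverage(0.9, name='avg')
maintain_op = avg.apply(summary_vars)   # to be run with each training step
for v in summary_vars:
    tf.scalar_summary(v.op.name, avg.average(v))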