Commit 9aa2b11c authored by ppwwyyxx

running-average summary

parent 031dd698
...
@@ -59,17 +59,24 @@ def get_model(inputs):
     y = one_hot(label, 10)
     cost = tf.nn.softmax_cross_entropy_with_logits(fc1, y)
-    cost = tf.reduce_mean(cost)
+    cost = tf.reduce_mean(cost, name='cross_entropy_cost')
+    tf.add_to_collection(COST_VARS_KEY, cost)

     # compute the number of correctly classified samples, for ValidationAccuracy to use
     correct = tf.equal(
         tf.cast(tf.argmax(prob, 1), tf.int32), label)
-    correct = tf.reduce_sum(tf.cast(correct, tf.int32), name='correct')
+    correct = tf.cast(correct, tf.float32)
+    nr_correct = tf.reduce_sum(correct, name='correct')
+    tf.add_to_collection(SUMMARY_VARS_KEY,
+                         tf.reduce_mean(correct, name='training_accuracy'))

     # weight decay on all W of fc layers
-    wd_cost = 1e-4 * regularize_cost('fc.*/W', tf.nn.l2_loss)
+    wd_cost = tf.mul(1e-4,
+                     regularize_cost('fc.*/W', tf.nn.l2_loss),
+                     name='regularize_loss')
+    tf.add_to_collection(COST_VARS_KEY, wd_cost)

-    return [prob, correct], tf.add(cost, wd_cost, name='cost')
+    return [prob, nr_correct], tf.add_n(tf.get_collection(COST_VARS_KEY), name='cost')

 def main():
     BATCH_SIZE = 128
...
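The model hunk above moves loss bookkeeping into a graph collection: each term registers itself under COST_VARS_KEY, and the total cost becomes a single tf.add_n over the collection instead of a hand-written sum. A minimal standalone sketch of the pattern, using the TF 0.x-era API this commit targets (the tensor and variable names here are illustrative, not the repo's):

import tensorflow as tf

COST_VARS_KEY = 'COST_VARIABLES'    # same key the commit adds to the constants module

def add_cost(t):
    # register one scalar loss term; training code sums the collection later
    tf.add_to_collection(COST_VARS_KEY, t)
    return t

logits = tf.placeholder(tf.float32, [None, 10])
labels = tf.placeholder(tf.float32, [None, 10])
ce = add_cost(tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(logits, labels),
    name='cross_entropy_cost'))
w = tf.Variable(tf.zeros([10]), name='fc0/W')
wd = add_cost(tf.mul(1e-4, tf.nn.l2_loss(w), name='regularize_loss'))

# one fused sum over every registered term, named 'cost' like the commit's return value
total_cost = tf.add_n(tf.get_collection(COST_VARS_KEY), name='cost')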
@@ -12,11 +12,11 @@ def regularize_cost(regex, func):
     G = tf.get_default_graph()
     params = G.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
-    cost = 0
+    costs = []
     for p in params:
         name = p.name
         if re.search(regex, name):
             print("Weight decay for {}".format(name))
-            cost += func(p)
-    return cost
+            costs.append(func(p))
+    return tf.add_n(costs)
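One behavioral detail of this rewrite: tf.add_n fuses the sum into a single op, but it raises an error on an empty list, so a regex that matches no variables now fails loudly instead of returning the integer 0. A defensive variant (a sketch, not the commit's code) would guard that case:

import re
import tensorflow as tf

def regularize_cost(regex, func):
    # sum func(p) over every trainable variable whose name matches regex
    params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
    costs = [func(p) for p in params if re.search(regex, p.name)]
    if not costs:
        # tf.add_n([]) raises; an explicit zero keeps the graph buildable
        return tf.constant(0.0, name='empty_regularize_cost')
    return tf.add_n(costs)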
...
@@ -40,7 +40,21 @@ def start_train(config):
     for v in output_vars:
         G.add_to_collection(OUTPUT_VARS_KEY, v)
-    train_op = optimizer.minimize(cost_var)
+    # add some summary ops to the graph
+    cost_avg = tf.train.ExponentialMovingAverage(0.9, name='avg')
+    # TODO global step
+    cost_vars = tf.get_collection(COST_VARS_KEY)
+    summary_vars = tf.get_collection(SUMMARY_VARS_KEY)
+    vars_to_summary = cost_vars + [cost_var] + summary_vars
+    cost_avg_maintain_op = cost_avg.apply(vars_to_summary)
+    for c in vars_to_summary:
+        #tf.scalar_summary(c.op.name +' (raw)', c)
+        tf.scalar_summary(c.op.name, cost_avg.average(c))
+
+    # maintain average in each step
+    with tf.control_dependencies([cost_avg_maintain_op]):
+        grads = optimizer.compute_gradients(cost_var)
+        train_op = optimizer.apply_gradients(grads)    # TODO global_step

     sess = tf.Session(config=sess_config)
     # start training
...
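The start_train hunk is the heart of the commit: an ExponentialMovingAverage keeps a decayed running average of every cost and summary tensor, the scalar summaries read the smoothed values, and tf.control_dependencies makes each training step refresh the averages first. A self-contained sketch of that wiring with a one-variable dummy model (TF 0.x-era API; w and the 0.1 learning rate are placeholders for illustration):

import tensorflow as tf

w = tf.Variable(2.0, name='w')
cost = tf.square(w, name='cost')    # stands in for the model's total cost

# shadow variables track cost with decay 0.9, as in the commit
ema = tf.train.ExponentialMovingAverage(0.9, name='avg')
maintain_op = ema.apply([cost])

# summarize the running average rather than the noisy per-step value
tf.scalar_summary(cost.op.name, ema.average(cost))

optimizer = tf.train.GradientDescentOptimizer(0.1)
# any op built in this context runs maintain_op first, so one
# session.run(train_op) both trains and updates the averages
with tf.control_dependencies([maintain_op]):
    grads = optimizer.compute_gradients(cost)
    train_op = optimizer.apply_gradients(grads)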
@@ -30,13 +30,11 @@ class Callback(object):
             outputs: list of output values after running this dp
             cost: the cost value after running this dp
         """
-        pass

     def trigger_epoch(self):
         """
         Callback to be triggered after every epoch (full iteration of input dataset)
         """
-        pass

 class PeriodicCallback(Callback):
     def __init__(self, period):
@@ -73,12 +71,12 @@ class SummaryWriter(Callback):
     def _before_train(self):
         self.writer = tf.train.SummaryWriter(
             self.log_dir, graph_def=self.sess.graph_def)
-        self.graph.add_to_collection(SUMMARY_WRITER_COLLECTION_KEY, self.writer)
+        tf.add_to_collection(SUMMARY_WRITER_COLLECTION_KEY, self.writer)
         # create some summary
         if self.hist_regex is not None:
             import re
-            params = self.graph.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
+            params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
             for p in params:
                 name = p.name
                 if re.search(self.hist_regex, name):
...
@@ -11,6 +11,8 @@ MERGE_SUMMARY_OP_NAME = 'MergeSummary/MergeSummary:0'
 INPUT_VARS_KEY = 'INPUT_VARIABLES'
 OUTPUT_VARS_KEY = 'OUTPUT_VARIABLES'
+COST_VARS_KEY = 'COST_VARIABLES'
+SUMMARY_VARS_KEY = 'SUMMARY_VARIABLES'    # define extra variable to summarize

 # export all upper case variables
 all_local_names = locals().keys()
...
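The "# export all upper case variables" line under the new keys is the usual build-__all__-from-locals idiom; the diff cuts off right after it, so the continuation below is one plausible completion for illustration, not the repo's exact code:

all_local_names = locals().keys()
# assumed continuation: export only the upper-case constants on `import *`
__all__ = [name for name in all_local_names if name.isupper()]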
@@ -31,7 +31,7 @@ class ValidationAccuracy(PeriodicCallback):
         return self.graph.get_tensor_by_name(name)

     def _before_train(self):
-        self.input_vars = self.graph.get_collection(INPUT_VARS_KEY)
+        self.input_vars = tf.get_collection(INPUT_VARS_KEY)
         self.dropout_var = self.get_tensor(DROPOUT_PROB_VAR_NAME)
         self.correct_var = self.get_tensor(self.correct_var_name)
         self.cost_var = self.get_tensor(self.cost_var_name)
@@ -56,59 +56,60 @@ class ValidationAccuracy(PeriodicCallback):
         cost_avg = cost_sum / cnt
         self.writer.add_summary(
-            create_summary('{} accuracy'.format(self.prefix),
+            create_summary('{}_accuracy'.format(self.prefix),
                            correct_stat.accuracy),
             self.epoch_num)
         self.writer.add_summary(
-            create_summary('{} cost'.format(self.prefix),
+            create_summary('{}_cost'.format(self.prefix),
                            cost_avg),
             self.epoch_num)
         print "{} validation after epoch {}: acc={}, cost={}".format(
             self.prefix, self.epoch_num, correct_stat.accuracy, cost_avg)

-class TrainingAccuracy(Callback):
-    """
-    Record the accuracy and cost during each step of trianing.
-    The result is a running average, thus not directly comparable with ValidationAccuracy
-    """
-    def __init__(self, batch_size, correct_var_name='correct:0'):
-        """
-        correct_var: number of correct sample in this batch
-        """
-        self.correct_var_name = correct_var_name
-        self.batch_size = batch_size
-        self.epoch_num = 0
-
-    def _before_train(self):
-        self.writer = tf.get_collection(SUMMARY_WRITER_COLLECTION_KEY)[0]
-        output_vars = self.graph.get_collection(OUTPUT_VARS_KEY)
-        for idx, var in enumerate(output_vars):
-            if var.name == self.correct_var_name:
-                self.correct_output_idx = idx
-                break
-        else:
-            raise RuntimeError(
-                "'correct' variable must be in the model outputs to use TrainingAccuracy")
-        self.running_cost = StatCounter()
-        self.running_acc = Accuracy()
-
-    def trigger_step(self, inputs, outputs, cost):
-        self.running_cost.feed(cost)
-        self.running_acc.feed(
-            outputs[self.correct_output_idx],
-            self.batch_size)    # assume batch input
-
-    def trigger_epoch(self):
-        self.epoch_num += 1
-        print('Training average in Epoch {}: cost={}, acc={}'.format
-              (self.epoch_num, self.running_cost.average,
-               self.running_acc.accuracy))
-        self.writer.add_summary(
-            create_summary('training average accuracy', self.running_acc.accuracy),
-            self.epoch_num)
-        self.writer.add_summary(
-            create_summary('training average cost', self.running_cost.average),
-            self.epoch_num)
-        self.running_cost.reset()
-        self.running_acc.reset()
+# use SUMMARY_VARIABLES instead
+#class TrainingAccuracy(Callback):
+    #"""
+    #Record the accuracy and cost during each step of trianing.
+    #The result is a running average, thus not directly comparable with ValidationAccuracy
+    #"""
+    #def __init__(self, batch_size, correct_var_name='correct:0'):
+        #"""
+        #correct_var: number of correct sample in this batch
+        #"""
+        #self.correct_var_name = correct_var_name
+        #self.batch_size = batch_size
+        #self.epoch_num = 0
+
+    #def _before_train(self):
+        #self.writer = tf.get_collection(SUMMARY_WRITER_COLLECTION_KEY)[0]
+        #output_vars = tf.get_collection(OUTPUT_VARS_KEY)
+        #for idx, var in enumerate(output_vars):
+            #if var.name == self.correct_var_name:
+                #self.correct_output_idx = idx
+                #break
+        #else:
+            #raise RuntimeError(
+                #"'correct' variable must be in the model outputs to use TrainingAccuracy")
+        #self.running_cost = StatCounter()
+        #self.running_acc = Accuracy()
+
+    #def trigger_step(self, inputs, outputs, cost):
+        #self.running_cost.feed(cost)
+        #self.running_acc.feed(
+            #outputs[self.correct_output_idx],
+            #self.batch_size)    # assume batch input
+
+    #def trigger_epoch(self):
+        #self.epoch_num += 1
+        #print('Training average in Epoch {}: cost={}, acc={}'.format
+            #(self.epoch_num, self.running_cost.average,
+            #self.running_acc.accuracy))
+        #self.writer.add_summary(
+            #create_summary('training average accuracy', self.running_acc.accuracy),
+            #self.epoch_num)
+        #self.writer.add_summary(
+            #create_summary('training average cost', self.running_cost.average),
+            #self.epoch_num)
+        #self.running_cost.reset()
+        #self.running_acc.reset()
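create_summary, used by ValidationAccuracy above, is a repo helper whose definition is not part of this diff; a minimal version consistent with how it is called (a tag string plus a scalar, passed to SummaryWriter.add_summary) could wrap the value in a tf.Summary proto, as in this sketch:

import tensorflow as tf

def create_summary(tag, value):
    # wrap one scalar into a Summary proto; add_summary(proto, step) can log it
    s = tf.Summary()
    s.value.add(tag=tag, simple_value=float(value))
    return s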