Commit 1f657fcb authored by ppwwyyxx

monitor training accuracy with summary

parent 9aa2b11c
@@ -59,16 +59,19 @@ def get_model(inputs):
     y = one_hot(label, 10)
     cost = tf.nn.softmax_cross_entropy_with_logits(fc1, y)
-    cost = tf.reduce_mean(cost, name='cross_entropy_cost')
+    cost = tf.reduce_mean(cost, name='cross_entropy_loss')
     tf.add_to_collection(COST_VARS_KEY, cost)

-    # compute the number of correctly classified samples, for ValidationAccuracy to use
+    # compute the number of correctly classified samples, for ValidationAccuracy to use at test time
     correct = tf.equal(
         tf.cast(tf.argmax(prob, 1), tf.int32), label)
     correct = tf.cast(correct, tf.float32)
     nr_correct = tf.reduce_sum(correct, name='correct')
-    tf.add_to_collection(SUMMARY_VARS_KEY,
-                         tf.reduce_mean(correct, name='training_accuracy'))
+    # monitor training accuracy
+    tf.add_to_collection(
+        SUMMARY_VARS_KEY,
+        tf.reduce_mean(correct, name='train_accuracy'))

     # weight decay on all W of fc layers
     wd_cost = tf.mul(1e-4,
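Note: the model now tags its training accuracy into a graph collection rather than computing it in a dedicated callback. A minimal sketch of this collection pattern, assuming the TF 0.x API used in this repo; `prob` and `label` below are hypothetical stand-ins for the model's tensors:

    import tensorflow as tf

    SUMMARY_VARS_KEY = 'SUMMARY_VARIABLES'

    # hypothetical stand-ins for the model's probability output and labels
    prob = tf.placeholder(tf.float32, [None, 10], name='prob')
    label = tf.placeholder(tf.int32, [None], name='label')

    correct = tf.cast(
        tf.equal(tf.cast(tf.argmax(prob, 1), tf.int32), label), tf.float32)
    # tagging the mean makes the tensor discoverable by any code sharing the graph
    tf.add_to_collection(SUMMARY_VARS_KEY,
                         tf.reduce_mean(correct, name='train_accuracy'))

    # the training loop can later collect everything tagged under the key
    for v in tf.get_collection(SUMMARY_VARS_KEY):
        tf.scalar_summary(v.op.name, v)  # TF 0.x summary API, as in this repo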
@@ -97,9 +100,9 @@ def main():
         dataset_train=dataset_train,
         optimizer=tf.train.AdamOptimizer(1e-4),
         callbacks=[
-            TrainingAccuracy(batch_size=BATCH_SIZE),
-            ValidationAccuracy(dataset_test,
-                               prefix='test', period=1),
+            ValidationAccuracy(
+                dataset_test,
+                prefix='test'),
             PeriodicSaver(LOG_DIR, period=1),
             SummaryWriter(LOG_DIR, histogram_regex='.*/W'),
         ],
...
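With TrainingAccuracy removed from the callback list, per-step accuracy flows through the summary machinery instead. For context, a hypothetical sketch of the periodic-callback pattern that ValidationAccuracy and PeriodicSaver appear to share; the base class name exists in this repo, but this implementation is an assumption:

    class PeriodicCallback(object):
        # assumed implementation; the real one in this repo may differ
        def __init__(self, period=1):
            self.period = period
            self.epoch_num = 0

        def trigger_epoch(self):
            # called once per epoch; only fire the real work every `period` epochs
            self.epoch_num += 1
            if self.epoch_num % self.period == 0:
                self._trigger()

        def _trigger(self):
            raise NotImplementedError()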
@@ -40,21 +40,22 @@ def start_train(config):
     for v in output_vars:
         G.add_to_collection(OUTPUT_VARS_KEY, v)
+    global_step_var = tf.Variable(0, trainable=False, name='global_step')

     # add some summary ops to the graph
-    cost_avg = tf.train.ExponentialMovingAverage(0.9, name='avg')
-    # TODO global step
-    cost_vars = tf.get_collection(COST_VARS_KEY)
-    summary_vars = tf.get_collection(SUMMARY_VARS_KEY)
-    vars_to_summary = cost_vars + [cost_var] + summary_vars
-    cost_avg_maintain_op = cost_avg.apply(vars_to_summary)
+    averager = tf.train.ExponentialMovingAverage(
+        0.9, num_updates=global_step_var, name='avg')
+    vars_to_summary = [cost_var] + \
+        tf.get_collection(SUMMARY_VARS_KEY) + \
+        tf.get_collection(COST_VARS_KEY)
+    avg_maintain_op = averager.apply(vars_to_summary)
     for c in vars_to_summary:
-        #tf.scalar_summary(c.op.name +' (raw)', c)
-        tf.scalar_summary(c.op.name, cost_avg.average(c))
+        tf.scalar_summary(c.op.name, averager.average(c))

     # maintain average in each step
-    with tf.control_dependencies([cost_avg_maintain_op]):
+    with tf.control_dependencies([avg_maintain_op]):
         grads = optimizer.compute_gradients(cost_var)
-        train_op = optimizer.apply_gradients(grads) # TODO global_step
+        train_op = optimizer.apply_gradients(grads, global_step_var)

     sess = tf.Session(config=sess_config)
     # start training
...
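The key change in this hunk is that the moving average now advances with the global step: with `num_updates`, the effective decay is min(decay, (1 + step) / (10 + step)), so early summaries are less biased toward the initial value. A minimal self-contained sketch of the same pattern, assuming the TF 0.x API; `cost_var` here is a hypothetical stand-in for the model's total cost:

    import tensorflow as tf

    global_step_var = tf.Variable(0, trainable=False, name='global_step')
    cost_var = tf.Variable(1.0, name='total_cost')  # hypothetical stand-in

    averager = tf.train.ExponentialMovingAverage(
        0.9, num_updates=global_step_var, name='avg')
    avg_maintain_op = averager.apply([cost_var])  # creates the shadow variable
    tf.scalar_summary(cost_var.op.name, averager.average(cost_var))

    # ops created under the control dependency cannot run before the shadow
    # variable is refreshed, so every training step also updates the average
    with tf.control_dependencies([avg_maintain_op]):
        optimizer = tf.train.GradientDescentOptimizer(0.1)
        grads = optimizer.compute_gradients(cost_var)
        # passing global_step_var increments it once per step, which in turn
        # drives the num_updates decay schedule above
        train_op = optimizer.apply_gradients(grads, global_step_var)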
@@ -11,8 +11,8 @@ MERGE_SUMMARY_OP_NAME = 'MergeSummary/MergeSummary:0'
 INPUT_VARS_KEY = 'INPUT_VARIABLES'
 OUTPUT_VARS_KEY = 'OUTPUT_VARIABLES'
-COST_VARS_KEY = 'COST_VARIABLES'
-SUMMARY_VARS_KEY = 'SUMMARY_VARIABLES' # define extra variable to summarize
+COST_VARS_KEY = 'COST_VARIABLES'   # keep track of each individual cost
+SUMMARY_VARS_KEY = 'SUMMARY_VARIABLES'  # extra variables to summarize during training

 # export all upper case variables
 all_local_names = locals().keys()
...
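The "export all upper case variables" idiom at the end of this hunk presumably builds the module's `__all__` from its locals. A sketch of one way it could work; the filter condition is an assumption, not taken from this repo:

    INPUT_VARS_KEY = 'INPUT_VARIABLES'
    OUTPUT_VARS_KEY = 'OUTPUT_VARIABLES'

    # assumed continuation of the idiom: export every all-uppercase name
    all_local_names = locals().keys()
    __all__ = [name for name in all_local_names if name.isupper()]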
@@ -65,51 +65,3 @@ class ValidationAccuracy(PeriodicCallback):
             self.epoch_num)
         print "{} validation after epoch {}: acc={}, cost={}".format(
             self.prefix, self.epoch_num, correct_stat.accuracy, cost_avg)
-
-# use SUMMARY_VARIABLES instead
-#class TrainingAccuracy(Callback):
-    #"""
-    #Record the accuracy and cost during each step of training.
-    #The result is a running average, thus not directly comparable with ValidationAccuracy
-    #"""
-    #def __init__(self, batch_size, correct_var_name='correct:0'):
-        #"""
-        #correct_var: number of correct samples in this batch
-        #"""
-        #self.correct_var_name = correct_var_name
-        #self.batch_size = batch_size
-        #self.epoch_num = 0
-
-    #def _before_train(self):
-        #self.writer = tf.get_collection(SUMMARY_WRITER_COLLECTION_KEY)[0]
-        #output_vars = tf.get_collection(OUTPUT_VARS_KEY)
-        #for idx, var in enumerate(output_vars):
-            #if var.name == self.correct_var_name:
-                #self.correct_output_idx = idx
-                #break
-        #else:
-            #raise RuntimeError(
-                #"'correct' variable must be in the model outputs to use TrainingAccuracy")
-        #self.running_cost = StatCounter()
-        #self.running_acc = Accuracy()
-
-    #def trigger_step(self, inputs, outputs, cost):
-        #self.running_cost.feed(cost)
-        #self.running_acc.feed(
-            #outputs[self.correct_output_idx],
-            #self.batch_size)  # assume batch input
-
-    #def trigger_epoch(self):
-        #self.epoch_num += 1
-        #print('Training average in Epoch {}: cost={}, acc={}'.format(
-            #self.epoch_num, self.running_cost.average,
-            #self.running_acc.accuracy))
-        #self.writer.add_summary(
-            #create_summary('training average accuracy', self.running_acc.accuracy),
-            #self.epoch_num)
-        #self.writer.add_summary(
-            #create_summary('training average cost', self.running_cost.average),
-            #self.epoch_num)
-        #self.running_cost.reset()
-        #self.running_acc.reset()