Commit 33b212c0 authored by Yuxin Wu's avatar Yuxin Wu

print summary

parent a976e871
...@@ -134,7 +134,7 @@ def get_config(): ...@@ -134,7 +134,7 @@ def get_config():
dataset=dataset_train, dataset=dataset_train,
optimizer=tf.train.GradientDescentOptimizer(lr), optimizer=tf.train.GradientDescentOptimizer(lr),
callbacks=Callbacks([ callbacks=Callbacks([
SummaryWriter(), SummaryWriter(print_tag=['train_cost', 'train_error']),
PeriodicSaver(), PeriodicSaver(),
ValidationError(dataset_test, prefix='test'), ValidationError(dataset_test, prefix='test'),
]), ]),
......
...@@ -126,7 +126,7 @@ def get_config(): ...@@ -126,7 +126,7 @@ def get_config():
dataset=dataset_train, dataset=dataset_train,
optimizer=tf.train.AdamOptimizer(lr), optimizer=tf.train.AdamOptimizer(lr),
callbacks=Callbacks([ callbacks=Callbacks([
SummaryWriter(), SummaryWriter(print_tag=['train_cost', 'train_error']),
PeriodicSaver(), PeriodicSaver(),
ValidationError(dataset_test, prefix='test'), ValidationError(dataset_test, prefix='test'),
]), ]),
......
...@@ -30,8 +30,9 @@ class PeriodicSaver(PeriodicCallback): ...@@ -30,8 +30,9 @@ class PeriodicSaver(PeriodicCallback):
global_step=self.global_step) global_step=self.global_step)
class SummaryWriter(Callback): class SummaryWriter(Callback):
def __init__(self): def __init__(self, print_tag=None):
self.log_dir = logger.LOG_DIR self.log_dir = logger.LOG_DIR
self.print_tag = print_tag if print_tag else ['train_cost']
def _before_train(self): def _before_train(self):
self.writer = tf.train.SummaryWriter( self.writer = tf.train.SummaryWriter(
...@@ -44,5 +45,11 @@ class SummaryWriter(Callback): ...@@ -44,5 +45,11 @@ class SummaryWriter(Callback):
if self.summary_op is None: if self.summary_op is None:
return return
summary_str = self.summary_op.eval() summary_str = self.summary_op.eval()
self.writer.add_summary(summary_str, get_global_step()) summary = tf.Summary.FromString(summary_str)
for val in summary.value:
if val.tag in self.print_tag:
assert val.WhichOneof('value') == 'simple_value', \
'Cannot print summary {}: not a simple_value summary!'.format(val.tag)
logger.info('{}: {:.4f}'.format(val.tag, val.simple_value))
self.writer.add_summary(summary, get_global_step())
...@@ -58,9 +58,9 @@ class CallbackTimeLogger(object): ...@@ -58,9 +58,9 @@ class CallbackTimeLogger(object):
msgs = [] msgs = []
for name, t in self.times: for name, t in self.times:
if t / self.tot > 0.3 and t > 1: if t / self.tot > 0.3 and t > 1:
msgs.append("{}:{}sec".format(name, t)) msgs.append("{}:{:.3f}sec".format(name, t))
logger.info( logger.info(
"Callbacks took {} sec in total. {}".format( "Callbacks took {:.3f} sec in total. {}".format(
self.tot, ' '.join(msgs))) self.tot, ' '.join(msgs)))
......
...@@ -60,14 +60,9 @@ class ValidationError(PeriodicCallback): ...@@ -60,14 +60,9 @@ class ValidationError(PeriodicCallback):
pbar.update() pbar.update()
cost_avg = cost_sum / cnt cost_avg = cost_sum / cnt
self.writer.add_summary( self.writer.add_summary(create_summary(
create_summary('{}_error'.format(self.prefix), '{}_error'.format(self.prefix), err_stat.accuracy), self.global_step)
err_stat.accuracy), self.writer.add_summary(create_summary(
self.global_step) '{}_cost'.format(self.prefix), cost_avg), self.global_step)
self.writer.add_summary( logger.info("{}_cost: {:.4f}".format(self.prefix, cost_avg))
create_summary('{}_cost'.format(self.prefix), logger.info("{}_error: {:.4f}".format(self.prefix, err_stat.accuracy))
cost_avg),
self.global_step)
logger.info(
"{} validation after epoch{},step{}: err={:.4f}, cost={:.3f}".format(
self.prefix, self.epoch_num, self.global_step, err_stat.accuracy, cost_avg))
...@@ -56,7 +56,10 @@ def summary_moving_average(cost_var): ...@@ -56,7 +56,10 @@ def summary_moving_average(cost_var):
tf.get_collection(SUMMARY_VARS_KEY) + \ tf.get_collection(SUMMARY_VARS_KEY) + \
tf.get_collection(COST_VARS_KEY) tf.get_collection(COST_VARS_KEY)
avg_maintain_op = averager.apply(vars_to_summary) avg_maintain_op = averager.apply(vars_to_summary)
for c in vars_to_summary: for idx, c in enumerate(vars_to_summary):
tf.scalar_summary(c.op.name, averager.average(c)) name = c.op.name
if idx == 0:
name = 'train_cost'
tf.scalar_summary(name, averager.average(c))
return avg_maintain_op return avg_maintain_op
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment