Commit 51c58dfa authored by Yuxin Wu's avatar Yuxin Wu

summary writer print all scalar by default

parent e6a136ce
...@@ -109,7 +109,7 @@ def get_config(): ...@@ -109,7 +109,7 @@ def get_config():
dataset=dataset_train, dataset=dataset_train,
optimizer=tf.train.AdamOptimizer(lr), optimizer=tf.train.AdamOptimizer(lr),
callbacks=Callbacks([ callbacks=Callbacks([
SummaryWriter(print_tag=['train_cost', 'train_error']), SummaryWriter(),
PeriodicSaver(), PeriodicSaver(),
ValidationError(dataset_test, prefix='validation'), ValidationError(dataset_test, prefix='validation'),
]), ]),
......
...@@ -32,8 +32,9 @@ class PeriodicSaver(PeriodicCallback): ...@@ -32,8 +32,9 @@ class PeriodicSaver(PeriodicCallback):
class SummaryWriter(Callback): class SummaryWriter(Callback):
def __init__(self, print_tag=None): def __init__(self, print_tag=None):
""" if None, print all scalar summary"""
self.log_dir = logger.LOG_DIR self.log_dir = logger.LOG_DIR
self.print_tag = print_tag if print_tag else ['train_cost'] self.print_tag = print_tag
def _before_train(self): def _before_train(self):
self.writer = tf.train.SummaryWriter( self.writer = tf.train.SummaryWriter(
...@@ -44,22 +45,20 @@ class SummaryWriter(Callback): ...@@ -44,22 +45,20 @@ class SummaryWriter(Callback):
def _trigger_epoch(self): def _trigger_epoch(self):
self.epoch_num += 1 self.epoch_num += 1
# check if there is any summary # check if there is any summary to write
if self.summary_op is None: if self.summary_op is None:
return return
summary_str = self.summary_op.eval() summary_str = self.summary_op.eval()
summary = tf.Summary.FromString(summary_str) summary = tf.Summary.FromString(summary_str)
printed_tag = set() printed_tag = set()
for val in summary.value: for val in summary.value:
#print val.tag if val.WhichOneof('value') == 'simple_value':
val.tag = re.sub('tower[0-9]*/', '', val.tag) val.tag = re.sub('tower[0-9]*/', '', val.tag)
if val.tag in self.print_tag: if self.print_tag is None or val.tag in self.print_tag:
assert val.WhichOneof('value') == 'simple_value', \ logger.info('{}: {:.4f}'.format(val.tag, val.simple_value))
'Cannot print summary {}: not a simple_value summary!'.format(val.tag) printed_tag.add(val.tag)
logger.info('{}: {:.4f}'.format(val.tag, val.simple_value))
printed_tag.add(val.tag)
self.writer.add_summary(summary, get_global_step()) self.writer.add_summary(summary, get_global_step())
if self.epoch_num == 1: if self.print_tag is not None and self.epoch_num == 1:
if len(printed_tag) != len(self.print_tag): if len(printed_tag) != len(self.print_tag):
logger.warn("Tags to print not found in Summary Writer: {}".format( logger.warn("Tags to print not found in Summary Writer: {}".format(
", ".join([k for k in self.print_tag if k not in printed_tag]))) ", ".join([k for k in self.print_tag if k not in printed_tag])))
......
...@@ -72,6 +72,7 @@ def average_grads(tower_grads): ...@@ -72,6 +72,7 @@ def average_grads(tower_grads):
def summary_grads(grads): def summary_grads(grads):
for grad, var in grads: for grad, var in grads:
if grad: if grad:
# TODO also summary RMS and print
tf.histogram_summary(var.op.name + '/gradients', grad) tf.histogram_summary(var.op.name + '/gradients', grad)
def check_grads(grads): def check_grads(grads):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment