Commit 51c58dfa authored by Yuxin Wu

summary writer print all scalar by default

parent e6a136ce
@@ -109,7 +109,7 @@ def get_config():
         dataset=dataset_train,
         optimizer=tf.train.AdamOptimizer(lr),
         callbacks=Callbacks([
-            SummaryWriter(print_tag=['train_cost', 'train_error']),
+            SummaryWriter(),
             PeriodicSaver(),
             ValidationError(dataset_test, prefix='validation'),
         ]),
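The hunk above drops the explicit tag list from the example config, since the callback now prints every scalar summary when no print_tag is given. A rough usage sketch of the two modes (the import path is an assumption, not shown in this diff):

    from tensorpack.callbacks import SummaryWriter  # hypothetical import path

    # After this commit: no print_tag means every scalar summary is logged each epoch.
    print_everything = SummaryWriter()

    # An explicit list still restricts console output to those tags;
    # all summaries are written to the event file either way.
    print_selected = SummaryWriter(print_tag=['train_cost', 'train_error'])
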
@@ -32,8 +32,9 @@ class PeriodicSaver(PeriodicCallback):

 class SummaryWriter(Callback):
     def __init__(self, print_tag=None):
+        """ if None, print all scalar summary"""
         self.log_dir = logger.LOG_DIR
-        self.print_tag = print_tag if print_tag else ['train_cost']
+        self.print_tag = print_tag

     def _before_train(self):
         self.writer = tf.train.SummaryWriter(
@@ -44,22 +45,20 @@ class SummaryWriter(Callback):

     def _trigger_epoch(self):
         self.epoch_num += 1
-        # check if there is any summary
+        # check if there is any summary to write
         if self.summary_op is None:
             return
         summary_str = self.summary_op.eval()
         summary = tf.Summary.FromString(summary_str)
         printed_tag = set()
         for val in summary.value:
-            #print val.tag
-            val.tag = re.sub('tower[0-9]*/', '', val.tag)
-            if val.tag in self.print_tag:
-                assert val.WhichOneof('value') == 'simple_value', \
-                    'Cannot print summary {}: not a simple_value summary!'.format(val.tag)
-                logger.info('{}: {:.4f}'.format(val.tag, val.simple_value))
-                printed_tag.add(val.tag)
+            if val.WhichOneof('value') == 'simple_value':
+                val.tag = re.sub('tower[0-9]*/', '', val.tag)
+                if self.print_tag is None or val.tag in self.print_tag:
+                    logger.info('{}: {:.4f}'.format(val.tag, val.simple_value))
+                    printed_tag.add(val.tag)
         self.writer.add_summary(summary, get_global_step())
-        if self.epoch_num == 1:
+        if self.print_tag is not None and self.epoch_num == 1:
             if len(printed_tag) != len(self.print_tag):
                 logger.warn("Tags to print not found in Summary Writer: {}".format(
                     ", ".join([k for k in self.print_tag if k not in printed_tag])))
@@ -72,6 +72,7 @@ def average_grads(tower_grads):
 def summary_grads(grads):
     for grad, var in grads:
         if grad:
+            # TODO also summary RMS and print
             tf.histogram_summary(var.op.name + '/gradients', grad)

 def check_grads(grads):
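The TODO in the last hunk points at also recording a scalar gradient statistic next to the histogram. A possible follow-up sketch, not part of this commit, using the same TF 0.x-era summary ops already present here (tf.scalar_summary alongside tf.histogram_summary); the RMS choice is an assumption based on the comment:

    import tensorflow as tf

    def summary_grads(grads):
        for grad, var in grads:
            if grad is not None:
                tf.histogram_summary(var.op.name + '/gradients', grad)
                # Hypothetical extension for the TODO: gradient root-mean-square as a scalar.
                rms = tf.sqrt(tf.reduce_mean(tf.square(grad)))
                tf.scalar_summary(var.op.name + '/gradients/rms', rms)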