Commit 5ec865d8 authored by ppwwyyxx's avatar ppwwyyxx

histogram summary

parent 1f657fcb
...@@ -71,7 +71,7 @@ def get_model(inputs): ...@@ -71,7 +71,7 @@ def get_model(inputs):
# monitor training accuracy # monitor training accuracy
tf.add_to_collection( tf.add_to_collection(
SUMMARY_VARS_KEY, SUMMARY_VARS_KEY,
tf.reduce_mean(correct, name='train_accuracy')) 1 - tf.reduce_mean(correct, name='train_error'))
# weight decay on all W of fc layers # weight decay on all W of fc layers
wd_cost = tf.mul(1e-4, wd_cost = tf.mul(1e-4,
...@@ -81,7 +81,7 @@ def get_model(inputs): ...@@ -81,7 +81,7 @@ def get_model(inputs):
return [prob, nr_correct], tf.add_n(tf.get_collection(COST_VARS_KEY), name='cost') return [prob, nr_correct], tf.add_n(tf.get_collection(COST_VARS_KEY), name='cost')
def main(): def main(argv=None):
BATCH_SIZE = 128 BATCH_SIZE = 128
with tf.Graph().as_default(): with tf.Graph().as_default():
dataset_train = BatchData(Mnist('train'), BATCH_SIZE) dataset_train = BatchData(Mnist('train'), BATCH_SIZE)
...@@ -95,6 +95,7 @@ def main(): ...@@ -95,6 +95,7 @@ def main():
label_var = tf.placeholder(tf.int32, shape=(None,), name='label') label_var = tf.placeholder(tf.int32, shape=(None,), name='label')
input_vars = [image_var, label_var] input_vars = [image_var, label_var]
output_vars, cost_var = get_model(input_vars) output_vars, cost_var = get_model(input_vars)
add_histogram_summary('.*/W') # monitor histogram of all W
config = dict( config = dict(
dataset_train=dataset_train, dataset_train=dataset_train,
...@@ -104,7 +105,7 @@ def main(): ...@@ -104,7 +105,7 @@ def main():
dataset_test, dataset_test,
prefix='test'), prefix='test'),
PeriodicSaver(LOG_DIR, period=1), PeriodicSaver(LOG_DIR, period=1),
SummaryWriter(LOG_DIR, histogram_regex='.*/W'), SummaryWriter(LOG_DIR),
], ],
session_config=sess_config, session_config=sess_config,
inputs=input_vars, inputs=input_vars,
...@@ -115,6 +116,5 @@ def main(): ...@@ -115,6 +116,5 @@ def main():
from train import start_train from train import start_train
start_train(config) start_train(config)
if __name__ == '__main__': if __name__ == '__main__':
main() tf.app.run()
...@@ -55,6 +55,11 @@ def start_train(config): ...@@ -55,6 +55,11 @@ def start_train(config):
# maintain average in each step # maintain average in each step
with tf.control_dependencies([avg_maintain_op]): with tf.control_dependencies([avg_maintain_op]):
grads = optimizer.compute_gradients(cost_var) grads = optimizer.compute_gradients(cost_var)
for grad, var in grads:
if grad:
tf.histogram_summary(var.op.name + '/gradients', grad)
train_op = optimizer.apply_gradients(grads, global_step_var) train_op = optimizer.apply_gradients(grads, global_step_var)
sess = tf.Session(config=sess_config) sess = tf.Session(config=sess_config)
......
...@@ -66,22 +66,11 @@ class SummaryWriter(Callback): ...@@ -66,22 +66,11 @@ class SummaryWriter(Callback):
def __init__(self, log_dir, histogram_regex=None): def __init__(self, log_dir, histogram_regex=None):
self.log_dir = log_dir self.log_dir = log_dir
self.epoch_num = 0 self.epoch_num = 0
self.hist_regex = histogram_regex
def _before_train(self): def _before_train(self):
self.writer = tf.train.SummaryWriter( self.writer = tf.train.SummaryWriter(
self.log_dir, graph_def=self.sess.graph_def) self.log_dir, graph_def=self.sess.graph_def)
tf.add_to_collection(SUMMARY_WRITER_COLLECTION_KEY, self.writer) tf.add_to_collection(SUMMARY_WRITER_COLLECTION_KEY, self.writer)
# create some summary
if self.hist_regex is not None:
import re
params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
for p in params:
name = p.name
if re.search(self.hist_regex, name):
tf.histogram_summary(name, p)
self.summary_op = tf.merge_all_summaries() self.summary_op = tf.merge_all_summaries()
def trigger_step(self, inputs, outputs, cost): def trigger_step(self, inputs, outputs, cost):
......
...@@ -5,11 +5,11 @@ ...@@ -5,11 +5,11 @@
import tensorflow as tf import tensorflow as tf
__all__ = ['create_summary'] __all__ = ['create_summary', 'add_histogram_summary']
def create_summary(name, v): def create_summary(name, v):
# TODO support image or histogram
""" """
Return a tf.Summary object with name and simple value v
Args: v: a value Args: v: a value
""" """
...@@ -19,3 +19,14 @@ def create_summary(name, v): ...@@ -19,3 +19,14 @@ def create_summary(name, v):
s.value.add(tag=name, simple_value=v) s.value.add(tag=name, simple_value=v)
return s return s
def add_histogram_summary(regex):
"""
Add histogram summary for all trainable variables matching the regex
"""
import re
params = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
for p in params:
name = p.name
if re.search(regex, name):
tf.histogram_summary(name, p)
...@@ -56,12 +56,12 @@ class ValidationAccuracy(PeriodicCallback): ...@@ -56,12 +56,12 @@ class ValidationAccuracy(PeriodicCallback):
cost_avg = cost_sum / cnt cost_avg = cost_sum / cnt
self.writer.add_summary( self.writer.add_summary(
create_summary('{}_accuracy'.format(self.prefix), create_summary('{}_error'.format(self.prefix),
correct_stat.accuracy), 1 - correct_stat.accuracy),
self.epoch_num) self.epoch_num)
self.writer.add_summary( self.writer.add_summary(
create_summary('{}_cost'.format(self.prefix), create_summary('{}_cost'.format(self.prefix),
cost_avg), cost_avg),
self.epoch_num) self.epoch_num)
print "{} validation after epoch {}: acc={}, cost={}".format( print "{} validation after epoch {}: err={}, cost={}".format(
self.prefix, self.epoch_num, correct_stat.accuracy, cost_avg) self.prefix, self.epoch_num, 1 - correct_stat.accuracy, cost_avg)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment