Commit a4d51a2d authored by Yuxin Wu

stat holder and summary writer

parent 51c58dfa
@@ -92,7 +92,7 @@ def get_config():
dataset_train = BatchData(dataset.Mnist('train'), 128)
dataset_test = BatchData(dataset.Mnist('test'), 256, remainder=True)
step_per_epoch = dataset_train.size()
step_per_epoch = 3
#step_per_epoch = 20
# prepare session
sess_config = get_default_sess_config()
......
@@ -10,7 +10,7 @@ import re
from .base import Callback, PeriodicCallback
from ..utils import *
__all__ = ['PeriodicSaver', 'SummaryWriter']
__all__ = ['PeriodicSaver']
class PeriodicSaver(PeriodicCallback):
def __init__(self, period=1, keep_recent=10, keep_freq=0.5):
@@ -30,39 +30,5 @@ class PeriodicSaver(PeriodicCallback):
self.path,
global_step=self.global_step)
class SummaryWriter(Callback):
def __init__(self, print_tag=None):
""" if None, print all scalar summary"""
self.log_dir = logger.LOG_DIR
self.print_tag = print_tag
def _before_train(self):
self.writer = tf.train.SummaryWriter(
self.log_dir, graph_def=self.sess.graph_def)
tf.add_to_collection(SUMMARY_WRITER_COLLECTION_KEY, self.writer)
self.summary_op = tf.merge_all_summaries()
self.epoch_num = 0
def _trigger_epoch(self):
self.epoch_num += 1
# check if there is any summary to write
if self.summary_op is None:
return
summary_str = self.summary_op.eval()
summary = tf.Summary.FromString(summary_str)
printed_tag = set()
for val in summary.value:
if val.WhichOneof('value') == 'simple_value':
val.tag = re.sub('tower[0-9]*/', '', val.tag)
if self.print_tag is None or val.tag in self.print_tag:
logger.info('{}: {:.4f}'.format(val.tag, val.simple_value))
printed_tag.add(val.tag)
self.writer.add_summary(summary, get_global_step())
if self.print_tag is not None and self.epoch_num == 1:
if len(printed_tag) != len(self.print_tag):
logger.warn("Tags to print not found in Summary Writer: {}".format(
", ".join([k for k in self.print_tag if k not in printed_tag])))
def _after_train(self):
self.writer.close()
class MinSaver(Callback):
pass
@@ -7,7 +7,7 @@ import tensorflow as tf
from contextlib import contextmanager
from .base import Callback
from .common import *
from .summary import *
from ..utils import *
__all__ = ['Callbacks']
@@ -57,18 +57,18 @@ class CallbackTimeLogger(object):
class TrainCallbacks(Callback):
def __init__(self, callbacks):
self.cbs = callbacks
# put SummaryWriter first
for idx, cb in enumerate(self.cbs):
# put SummaryWriter to the beginning
if type(cb) == SummaryWriter:
self.cbs.insert(0, self.cbs.pop(idx))
break
else:
raise ValueError("Callbacks must contain a SummaryWriter!")
logger.warn("SummaryWriter must be used! Insert a default one automatically.")
self.cbs.insert(0, SummaryWriter())
def _before_train(self):
for cb in self.cbs:
cb.before_train()
self.writer = tf.get_collection(SUMMARY_WRITER_COLLECTION_KEY)[0]
def _after_train(self):
for cb in self.cbs:
@@ -84,7 +84,6 @@ class TrainCallbacks(Callback):
s = time.time()
cb.trigger_epoch()
tm.add(type(cb).__name__, time.time() - s)
self.writer.flush()
tm.log()
class TestCallbacks(Callback):
@@ -97,13 +96,11 @@ class TestCallbacks(Callback):
self.cbs = callbacks
def _before_train(self):
self.writer = tf.get_collection(SUMMARY_WRITER_COLLECTION_KEY)[0]
with create_test_session() as sess:
self.sess = sess
self.graph = sess.graph
self.saver = tf.train.Saver()
tf.add_to_collection(SUMMARY_WRITER_COLLECTION_KEY, self.writer)
for cb in self.cbs:
cb.before_train()
@@ -130,7 +127,6 @@ class TestCallbacks(Callback):
s = time.time()
cb.trigger_epoch()
tm.add(type(cb).__name__, time.time() - s)
self.writer.flush()
tm.log()
class Callbacks(Callback):
@@ -161,6 +157,7 @@ class Callbacks(Callback):
self.train.after_train()
if self.test:
self.test.after_train()
logger.writer.close()
def trigger_step(self):
self.train.trigger_step()
@@ -168,6 +165,7 @@ class Callbacks(Callback):
def _trigger_epoch(self):
self.train.trigger_epoch()
# TODO: can test callbacks be run async?
if self.test:
self.test.trigger_epoch()
logger.writer.flush()
logger.stat_holder.finalize()
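
A minimal usage sketch of the ordering rule in TrainCallbacks above. The import paths are assumptions based on this diff, and logger.LOG_DIR is assumed to already point at an existing log directory (SummaryWriter needs it to build its StatHolder):

# Hedged sketch: TrainCallbacks moves an existing SummaryWriter to the front,
# or inserts a default one (with a warning) when none was given.
from tensorpack.callbacks.group import TrainCallbacks      # module path assumed
from tensorpack.callbacks.saver import PeriodicSaver       # module path assumed
from tensorpack.callbacks.summary import SummaryWriter     # module path assumed

cbs = TrainCallbacks([PeriodicSaver(), SummaryWriter(print_tag=['cost'])])
assert type(cbs.cbs[0]) == SummaryWriter    # moved to the front

cbs = TrainCallbacks([PeriodicSaver()])
assert type(cbs.cbs[0]) == SummaryWriter    # default one inserted, warning logged
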
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
# File: summary.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

import tensorflow as tf
import re
import os
import operator
import cPickle as pickle

from .base import Callback, PeriodicCallback
from ..utils import *

__all__ = ['SummaryWriter']

class StatHolder(object):
    def __init__(self, log_dir, print_tag=None):
        self.print_tag = None if print_tag is None else set(print_tag)
        self.stat_now = {}

        self.log_dir = log_dir
        self.filename = os.path.join(log_dir, 'stat.pkl')
        if os.path.isfile(self.filename):
            logger.info("Loading stats from {}...".format(self.filename))
            with open(self.filename, 'rb') as f:
                self.stat_history = pickle.load(f)
        else:
            self.stat_history = []

    def add_stat(self, k, v):
        self.stat_now[k] = v

    def finalize(self):
        self._print_stat()
        self.stat_history.append(self.stat_now)
        self.stat_now = {}
        self._write_stat()

    def _print_stat(self):
        for k, v in sorted(self.stat_now.items(), key=operator.itemgetter(0)):
            if self.print_tag is None or k in self.print_tag:
                logger.info('{}: {:.4f}'.format(k, v))

    def _write_stat(self):
        # write to a temporary file first, then rename, so stat.pkl is never half-written
        tmp_filename = self.filename + '.tmp'
        with open(tmp_filename, 'wb') as f:
            pickle.dump(self.stat_history, f)
        os.rename(tmp_filename, self.filename)

class SummaryWriter(Callback):
    def __init__(self, print_tag=None):
        """ print_tag: a list of scalar summary tags to print;
            if None, print all scalar tags
        """
        self.log_dir = logger.LOG_DIR
        logger.stat_holder = StatHolder(self.log_dir, print_tag)

    def _before_train(self):
        logger.writer = tf.train.SummaryWriter(
            self.log_dir, graph_def=self.sess.graph_def)
        self.summary_op = tf.merge_all_summaries()

    def _trigger_epoch(self):
        # check if there is any summary to write
        if self.summary_op is None:
            return
        summary_str = self.summary_op.eval()
        summary = tf.Summary.FromString(summary_str)
        for val in summary.value:
            if val.WhichOneof('value') == 'simple_value':
                # strip the tower prefix so multi-GPU tags stay consistent
                val.tag = re.sub('tower[0-9]*/', '', val.tag)
                logger.stat_holder.add_stat(val.tag, val.simple_value)
        logger.writer.add_summary(summary, self.global_step)
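
StatHolder persists stat_history as a pickled list with one {tag: value} dict per epoch, written atomically through a .tmp file, so it can be inspected offline. A small sketch of reading it back; 'train_log/stat.pkl' is a placeholder for whatever LOG_DIR is actually in use:

# Hedged sketch: inspect the statistics written by StatHolder._write_stat.
import cPickle as pickle

with open('train_log/stat.pkl', 'rb') as f:     # placeholder path
    history = pickle.load(f)                    # list of per-epoch {tag: value} dicts
for epoch, stats in enumerate(history, 1):
    print epoch, sorted(stats.items())
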
@@ -28,7 +28,6 @@ class ValidationCallback(PeriodicCallback):
def _before_train(self):
self.input_vars = tf.get_collection(MODEL_KEY)[0].get_input_vars()
self.cost_var = self.get_tensor(self.cost_var_name)
self.writer = tf.get_collection(SUMMARY_WRITER_COLLECTION_KEY)[0]
self._find_output_vars()
def get_tensor(self, name):
@@ -64,9 +63,9 @@ class ValidationCallback(PeriodicCallback):
pbar.update()
cost_avg = cost_sum / cnt
self.writer.add_summary(create_summary(
logger.writer.add_summary(create_summary(
'{}_cost'.format(self.prefix), cost_avg), self.global_step)
logger.info("{}_cost: {:.4f}".format(self.prefix, cost_avg))
logger.stat_holder.add_stat("{}_cost".format(self.prefix), cost_avg)
def _trigger_periodic(self):
for dp, outputs in self._run_validation():
@@ -102,6 +101,6 @@ class ValidationError(ValidationCallback):
wrong = outputs[0]
err_stat.feed(wrong, batch_size)
self.writer.add_summary(create_summary(
logger.writer.add_summary(create_summary(
'{}_error'.format(self.prefix), err_stat.accuracy), self.global_step)
logger.info("{}_error: {:.4f}".format(self.prefix, err_stat.accuracy))
logger.stat_holder.add_stat("{}_error".format(self.prefix), err_stat.accuracy)
@@ -61,3 +61,10 @@ def set_logger_file(filename):
mkdir_p(os.path.dirname(LOG_FILE))
set_file(LOG_FILE)
# global state attached to the logger module:
# a SummaryWriter
writer = None
# a StatHolder
stat_holder = None
@@ -6,11 +6,10 @@
GLOBAL_STEP_OP_NAME = 'global_step'
GLOBAL_STEP_VAR_NAME = 'global_step:0'
SUMMARY_WRITER_COLLECTION_KEY = 'summary_writer'
MOVING_SUMMARY_VARS_KEY = 'MOVING_SUMMARY_VARIABLES' # extra variables to summarize during training
# extra variables to summarize during training in a moving-average way
MOVING_SUMMARY_VARS_KEY = 'MOVING_SUMMARY_VARIABLES'
MODEL_KEY = 'MODEL'
# export all upper case variables
all_local_names = locals().keys()
__all__ = [x for x in all_local_names if x.upper() == x]
__all__ = [x for x in all_local_names if x.isupper()]
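
The switch from x.upper() == x to x.isupper() only changes the outcome for names that contain no cased character at all: the old test let them through, the new one rejects them. A quick illustration:

# Both tests agree on ordinary names...
assert 'MOVING_SUMMARY_VARS_KEY'.isupper()
assert not 'all_local_names'.isupper()
# ...but they differ on names with no cased characters.
assert '_'.upper() == '_'       # old test would export '_'
assert not '_'.isupper()        # new test does not
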