Commit a4d51a2d authored by Yuxin Wu's avatar Yuxin Wu

stat holder and summary writer

parent 51c58dfa
...@@ -92,7 +92,7 @@ def get_config(): ...@@ -92,7 +92,7 @@ def get_config():
dataset_train = BatchData(dataset.Mnist('train'), 128) dataset_train = BatchData(dataset.Mnist('train'), 128)
dataset_test = BatchData(dataset.Mnist('test'), 256, remainder=True) dataset_test = BatchData(dataset.Mnist('test'), 256, remainder=True)
step_per_epoch = dataset_train.size() step_per_epoch = dataset_train.size()
step_per_epoch = 3 #step_per_epoch = 20
# prepare session # prepare session
sess_config = get_default_sess_config() sess_config = get_default_sess_config()
......
...@@ -10,7 +10,7 @@ import re ...@@ -10,7 +10,7 @@ import re
from .base import Callback, PeriodicCallback from .base import Callback, PeriodicCallback
from ..utils import * from ..utils import *
__all__ = ['PeriodicSaver', 'SummaryWriter'] __all__ = ['PeriodicSaver']
class PeriodicSaver(PeriodicCallback): class PeriodicSaver(PeriodicCallback):
def __init__(self, period=1, keep_recent=10, keep_freq=0.5): def __init__(self, period=1, keep_recent=10, keep_freq=0.5):
...@@ -30,39 +30,5 @@ class PeriodicSaver(PeriodicCallback): ...@@ -30,39 +30,5 @@ class PeriodicSaver(PeriodicCallback):
self.path, self.path,
global_step=self.global_step) global_step=self.global_step)
class SummaryWriter(Callback): class MinSaver(Callback):
def __init__(self, print_tag=None): pass
""" if None, print all scalar summary"""
self.log_dir = logger.LOG_DIR
self.print_tag = print_tag
def _before_train(self):
self.writer = tf.train.SummaryWriter(
self.log_dir, graph_def=self.sess.graph_def)
tf.add_to_collection(SUMMARY_WRITER_COLLECTION_KEY, self.writer)
self.summary_op = tf.merge_all_summaries()
self.epoch_num = 0
def _trigger_epoch(self):
self.epoch_num += 1
# check if there is any summary to write
if self.summary_op is None:
return
summary_str = self.summary_op.eval()
summary = tf.Summary.FromString(summary_str)
printed_tag = set()
for val in summary.value:
if val.WhichOneof('value') == 'simple_value':
val.tag = re.sub('tower[0-9]*/', '', val.tag)
if self.print_tag is None or val.tag in self.print_tag:
logger.info('{}: {:.4f}'.format(val.tag, val.simple_value))
printed_tag.add(val.tag)
self.writer.add_summary(summary, get_global_step())
if self.print_tag is not None and self.epoch_num == 1:
if len(printed_tag) != len(self.print_tag):
logger.warn("Tags to print not found in Summary Writer: {}".format(
", ".join([k for k in self.print_tag if k not in printed_tag])))
def _after_train(self):
self.writer.close()
...@@ -7,7 +7,7 @@ import tensorflow as tf ...@@ -7,7 +7,7 @@ import tensorflow as tf
from contextlib import contextmanager from contextlib import contextmanager
from .base import Callback from .base import Callback
from .common import * from .summary import *
from ..utils import * from ..utils import *
__all__ = ['Callbacks'] __all__ = ['Callbacks']
...@@ -57,18 +57,18 @@ class CallbackTimeLogger(object): ...@@ -57,18 +57,18 @@ class CallbackTimeLogger(object):
class TrainCallbacks(Callback): class TrainCallbacks(Callback):
def __init__(self, callbacks): def __init__(self, callbacks):
self.cbs = callbacks self.cbs = callbacks
# put SummaryWriter to the first
for idx, cb in enumerate(self.cbs): for idx, cb in enumerate(self.cbs):
# put SummaryWriter to the beginning
if type(cb) == SummaryWriter: if type(cb) == SummaryWriter:
self.cbs.insert(0, self.cbs.pop(idx)) self.cbs.insert(0, self.cbs.pop(idx))
break break
else: else:
raise ValueError("Callbacks must contain a SummaryWriter!") logger.warn("SummaryWriter must be used! Insert a default one automatically.")
self.cbs.insert(0, SummaryWriter())
def _before_train(self): def _before_train(self):
for cb in self.cbs: for cb in self.cbs:
cb.before_train() cb.before_train()
self.writer = tf.get_collection(SUMMARY_WRITER_COLLECTION_KEY)[0]
def _after_train(self): def _after_train(self):
for cb in self.cbs: for cb in self.cbs:
...@@ -84,7 +84,6 @@ class TrainCallbacks(Callback): ...@@ -84,7 +84,6 @@ class TrainCallbacks(Callback):
s = time.time() s = time.time()
cb.trigger_epoch() cb.trigger_epoch()
tm.add(type(cb).__name__, time.time() - s) tm.add(type(cb).__name__, time.time() - s)
self.writer.flush()
tm.log() tm.log()
class TestCallbacks(Callback): class TestCallbacks(Callback):
...@@ -97,13 +96,11 @@ class TestCallbacks(Callback): ...@@ -97,13 +96,11 @@ class TestCallbacks(Callback):
self.cbs = callbacks self.cbs = callbacks
def _before_train(self): def _before_train(self):
self.writer = tf.get_collection(SUMMARY_WRITER_COLLECTION_KEY)[0]
with create_test_session() as sess: with create_test_session() as sess:
self.sess = sess self.sess = sess
self.graph = sess.graph self.graph = sess.graph
self.saver = tf.train.Saver() self.saver = tf.train.Saver()
tf.add_to_collection(SUMMARY_WRITER_COLLECTION_KEY, self.writer)
for cb in self.cbs: for cb in self.cbs:
cb.before_train() cb.before_train()
...@@ -130,7 +127,6 @@ class TestCallbacks(Callback): ...@@ -130,7 +127,6 @@ class TestCallbacks(Callback):
s = time.time() s = time.time()
cb.trigger_epoch() cb.trigger_epoch()
tm.add(type(cb).__name__, time.time() - s) tm.add(type(cb).__name__, time.time() - s)
self.writer.flush()
tm.log() tm.log()
class Callbacks(Callback): class Callbacks(Callback):
...@@ -161,6 +157,7 @@ class Callbacks(Callback): ...@@ -161,6 +157,7 @@ class Callbacks(Callback):
self.train.after_train() self.train.after_train()
if self.test: if self.test:
self.test.after_train() self.test.after_train()
logger.writer.close()
def trigger_step(self): def trigger_step(self):
self.train.trigger_step() self.train.trigger_step()
...@@ -168,6 +165,7 @@ class Callbacks(Callback): ...@@ -168,6 +165,7 @@ class Callbacks(Callback):
def _trigger_epoch(self): def _trigger_epoch(self):
self.train.trigger_epoch() self.train.trigger_epoch()
# TODO test callbacks can be run async?
if self.test: if self.test:
self.test.trigger_epoch() self.test.trigger_epoch()
logger.writer.flush()
logger.stat_holder.finalize()
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
# File: summary.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import tensorflow as tf
import re
import os
import operator
import cPickle as pickle
from .base import Callback, PeriodicCallback
from ..utils import *
__all__ = ['SummaryWriter']
class StatHolder(object):
    """ Holder for scalar training statistics, kept aside from
    tensorflow summaries.

    Stats for the current epoch accumulate in ``stat_now``; ``finalize``
    prints them, appends them to ``stat_history``, and persists the whole
    history to ``<log_dir>/stat.pkl`` so it survives restarts.
    """
    def __init__(self, log_dir, print_tag=None):
        """
        log_dir: directory where the stats are persisted.
        print_tag: a collection of stat names to print on ``finalize``.
            If None, every stat is printed.
        """
        self.print_tag = None if print_tag is None else set(print_tag)
        self.stat_now = {}
        self.log_dir = log_dir
        self.filename = os.path.join(log_dir, 'stat.pkl')
        if os.path.isfile(self.filename):
            logger.info("Loading stats from {}...".format(self.filename))
            # BUGFIX: open in binary mode -- the file is written by
            # pickle.dump through 'wb' in _write_stat, so a text-mode
            # read corrupts the stream on platforms with newline
            # translation.
            with open(self.filename, 'rb') as f:
                self.stat_history = pickle.load(f)
        else:
            self.stat_history = []

    def add_stat(self, k, v):
        """ Record the value ``v`` of stat ``k`` for the current epoch. """
        self.stat_now[k] = v

    def finalize(self):
        """ Print the current epoch's stats, archive them into the
        history, reset the current dict, and write the history to disk. """
        self._print_stat()
        self.stat_history.append(self.stat_now)
        self.stat_now = {}
        self._write_stat()

    def _print_stat(self):
        # iterate in sorted-by-name order for stable, readable log output
        for k, v in sorted(self.stat_now.items(), key=operator.itemgetter(0)):
            if self.print_tag is None or k in self.print_tag:
                logger.info('{}: {:.4f}'.format(k, v))

    def _write_stat(self):
        # write to a temp file then rename, so a crash mid-write never
        # leaves a truncated/corrupted stat.pkl behind
        tmp_filename = self.filename + '.tmp'
        with open(tmp_filename, 'wb') as f:
            pickle.dump(self.stat_history, f)
        os.rename(tmp_filename, self.filename)
class SummaryWriter(Callback):
    """ Evaluate all merged summaries each epoch, write them to the
    event file, and feed every scalar summary into the global
    ``logger.stat_holder``. """

    def __init__(self, print_tag=None):
        """
        print_tag: tags of scalar summaries to print each epoch;
            if None, all scalar tags are printed.
            NOTE(review): a previous docstring called this "a list of
            regex", but StatHolder appears to match tags by exact
            name -- confirm before relying on regex semantics.
        """
        self.log_dir = logger.LOG_DIR
        # install the global stat holder used by other callbacks
        logger.stat_holder = StatHolder(self.log_dir, print_tag)

    def _before_train(self):
        # publish the writer globally so other callbacks can add summaries
        logger.writer = tf.train.SummaryWriter(
            self.log_dir, graph_def=self.sess.graph_def)
        self.summary_op = tf.merge_all_summaries()

    def _trigger_epoch(self):
        if self.summary_op is None:
            # the graph defines no summaries at all -- nothing to write
            return
        summary = tf.Summary.FromString(self.summary_op.eval())
        for entry in summary.value:
            if entry.WhichOneof('value') != 'simple_value':
                continue
            # strip the multi-gpu tower prefix so tags are device-agnostic
            entry.tag = re.sub('tower[0-9]*/', '', entry.tag)
            logger.stat_holder.add_stat(entry.tag, entry.simple_value)
        logger.writer.add_summary(summary, self.global_step)
...@@ -28,7 +28,6 @@ class ValidationCallback(PeriodicCallback): ...@@ -28,7 +28,6 @@ class ValidationCallback(PeriodicCallback):
def _before_train(self): def _before_train(self):
self.input_vars = tf.get_collection(MODEL_KEY)[0].get_input_vars() self.input_vars = tf.get_collection(MODEL_KEY)[0].get_input_vars()
self.cost_var = self.get_tensor(self.cost_var_name) self.cost_var = self.get_tensor(self.cost_var_name)
self.writer = tf.get_collection(SUMMARY_WRITER_COLLECTION_KEY)[0]
self._find_output_vars() self._find_output_vars()
def get_tensor(self, name): def get_tensor(self, name):
...@@ -64,9 +63,9 @@ class ValidationCallback(PeriodicCallback): ...@@ -64,9 +63,9 @@ class ValidationCallback(PeriodicCallback):
pbar.update() pbar.update()
cost_avg = cost_sum / cnt cost_avg = cost_sum / cnt
self.writer.add_summary(create_summary( logger.writer.add_summary(create_summary(
'{}_cost'.format(self.prefix), cost_avg), self.global_step) '{}_cost'.format(self.prefix), cost_avg), self.global_step)
logger.info("{}_cost: {:.4f}".format(self.prefix, cost_avg)) logger.stat_holder.add_stat("{}_cost".format(self.prefix), cost_avg)
def _trigger_periodic(self): def _trigger_periodic(self):
for dp, outputs in self._run_validation(): for dp, outputs in self._run_validation():
...@@ -102,6 +101,6 @@ class ValidationError(ValidationCallback): ...@@ -102,6 +101,6 @@ class ValidationError(ValidationCallback):
wrong = outputs[0] wrong = outputs[0]
err_stat.feed(wrong, batch_size) err_stat.feed(wrong, batch_size)
self.writer.add_summary(create_summary( logger.writer.add_summary(create_summary(
'{}_error'.format(self.prefix), err_stat.accuracy), self.global_step) '{}_error'.format(self.prefix), err_stat.accuracy), self.global_step)
logger.info("{}_error: {:.4f}".format(self.prefix, err_stat.accuracy)) logger.stat_holder.add_stat("{}_error".format(self.prefix), err_stat.accuracy)
...@@ -61,3 +61,10 @@ def set_logger_file(filename): ...@@ -61,3 +61,10 @@ def set_logger_file(filename):
mkdir_p(os.path.dirname(LOG_FILE)) mkdir_p(os.path.dirname(LOG_FILE))
set_file(LOG_FILE) set_file(LOG_FILE)
# global logger:
# a SummaryWriter
writer = None
# a StatHolder
stat_holder = None
...@@ -6,11 +6,10 @@ ...@@ -6,11 +6,10 @@
GLOBAL_STEP_OP_NAME = 'global_step' GLOBAL_STEP_OP_NAME = 'global_step'
GLOBAL_STEP_VAR_NAME = 'global_step:0' GLOBAL_STEP_VAR_NAME = 'global_step:0'
SUMMARY_WRITER_COLLECTION_KEY = 'summary_writer' # extra variables to summarize during training in a moving-average way
MOVING_SUMMARY_VARS_KEY = 'MOVING_SUMMARY_VARIABLES'
MOVING_SUMMARY_VARS_KEY = 'MOVING_SUMMARY_VARIABLES' # extra variables to summarize during training
MODEL_KEY = 'MODEL' MODEL_KEY = 'MODEL'
# export all upper case variables # export all upper case variables
all_local_names = locals().keys() all_local_names = locals().keys()
__all__ = [x for x in all_local_names if x.upper() == x] __all__ = [x for x in all_local_names if x.isupper()]
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment