Commit 884af444 authored by Yuxin Wu

better train/test callback management

parent d731cf7b
...@@ -11,14 +11,20 @@ from abc import abstractmethod, ABCMeta ...@@ -11,14 +11,20 @@ from abc import abstractmethod, ABCMeta
from ..utils import * from ..utils import *
__all__ = ['Callback', 'PeriodicCallback'] __all__ = ['Callback', 'PeriodicCallback', 'TrainCallback', 'TestCallback']
class TrainCallback(object):
pass
class TestCallback(object):
pass
class Callback(object): class Callback(object):
__metaclass__ = ABCMeta __metaclass__ = ABCMeta
running_graph = 'train' type = TrainCallback()
""" The graph that this callback should run on. """ The graph that this callback should run on.
Either 'train' or 'test' Either TrainCallback or TestCallback
""" """
def before_train(self): def before_train(self):
......
...@@ -31,4 +31,8 @@ class PeriodicSaver(PeriodicCallback): ...@@ -31,4 +31,8 @@ class PeriodicSaver(PeriodicCallback):
global_step=self.global_step) global_step=self.global_step)
class MinSaver(Callback): class MinSaver(Callback):
pass def __init__(self, monitor_stat):
self.monitor_stat = monitor_stat
def _trigger_epoch(self):
pass
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
import tensorflow as tf import tensorflow as tf
from contextlib import contextmanager from contextlib import contextmanager
from .base import Callback from .base import Callback, TrainCallback, TestCallback
from .summary import * from .summary import *
from ..utils import * from ..utils import *
...@@ -41,6 +41,12 @@ class CallbackTimeLogger(object): ...@@ -41,6 +41,12 @@ class CallbackTimeLogger(object):
self.tot += time self.tot += time
self.times.append((name, time)) self.times.append((name, time))
@contextmanager
def timed_callback(self, name):
s = time.time()
yield
self.add(name, time.time() - s)
def log(self): def log(self):
""" log the time of some heavy callbacks """ """ log the time of some heavy callbacks """
if self.tot < 3: if self.tot < 3:
...@@ -53,119 +59,98 @@ class CallbackTimeLogger(object): ...@@ -53,119 +59,98 @@ class CallbackTimeLogger(object):
"Callbacks took {:.3f} sec in total. {}".format( "Callbacks took {:.3f} sec in total. {}".format(
self.tot, ' '.join(msgs))) self.tot, ' '.join(msgs)))
class TestCallbackContext(object):
class TrainCallbacks(Callback):
def __init__(self, callbacks):
self.cbs = callbacks
for idx, cb in enumerate(self.cbs):
# put SummaryWriter to the beginning
if type(cb) == SummaryWriter:
self.cbs.insert(0, self.cbs.pop(idx))
break
else:
logger.warn("SummaryWriter must be used! Insert a default one automatically.")
self.cbs.insert(0, SummaryWriter())
def _before_train(self):
for cb in self.cbs:
cb.before_train()
def _after_train(self):
for cb in self.cbs:
cb.after_train()
def trigger_step(self):
for cb in self.cbs:
cb.trigger_step()
def _trigger_epoch(self):
tm = CallbackTimeLogger()
for cb in self.cbs:
s = time.time()
cb.trigger_epoch()
tm.add(type(cb).__name__, time.time() - s)
tm.log()
class TestCallbacks(Callback):
""" """
Hold callbacks to be run in testing graph. A class holding the context needed for running TestCallback
Will set a context with testing graph and testing session, for
each test-time callback to run
""" """
def __init__(self, callbacks): def __init__(self):
self.cbs = callbacks self.sess = None
def _before_train(self): def _init_test_sess(self):
with create_test_session() as sess: with create_test_session() as sess:
self.sess = sess self.sess = sess
self.graph = sess.graph self.graph = sess.graph
self.saver = tf.train.Saver() self.saver = tf.train.Saver()
for cb in self.cbs:
cb.before_train()
def _after_train(self): @contextmanager
for cb in self.cbs: def before_train_context(self):
cb.after_train() if self.sess is None:
self._init_test_sess()
with self.graph.as_default(), self.sess.as_default():
yield
def _trigger_epoch(self): # TODO also do this for after_train?
if not self.cbs:
return def restore_checkpoint(self):
tm = CallbackTimeLogger() ckpt = tf.train.get_checkpoint_state(logger.LOG_DIR)
if ckpt is None:
raise RuntimeError(
"Cannot find a checkpoint state. Do you forget to use PeriodicSaver before any TestCallback?")
logger.info(
"Restore checkpoint from {}".format(ckpt.model_checkpoint_path))
self.saver.restore(self.sess, ckpt.model_checkpoint_path)
@contextmanager
def trigger_epoch_context(self):
with self.graph.as_default(), self.sess.as_default(): with self.graph.as_default(), self.sess.as_default():
s = time.time() yield
ckpt = tf.train.get_checkpoint_state(logger.LOG_DIR)
if ckpt is None:
logger.error(
"Cannot find a checkpoint state. Do you forget to use PeriodicSaver?")
return
logger.info(
"Restore checkpoint from {}".format(ckpt.model_checkpoint_path))
self.saver.restore(self.sess, ckpt.model_checkpoint_path)
tm.add('restore session', time.time() - s)
for cb in self.cbs:
s = time.time()
cb.trigger_epoch()
tm.add(type(cb).__name__, time.time() - s)
tm.log()
class Callbacks(Callback): class Callbacks(Callback):
def __init__(self, cbs): def __init__(self, cbs):
train_cbs = [] # check type
test_cbs = []
for cb in cbs: for cb in cbs:
assert isinstance(cb, Callback), cb.__class__ assert isinstance(cb, Callback), cb.__class__
if cb.running_graph == 'test': if not isinstance(cb.type, (TrainCallback, TestCallback)):
test_cbs.append(cb)
elif cb.running_graph == 'train':
train_cbs.append(cb)
else:
raise ValueError( raise ValueError(
"Unknown callback running graph {}!".format(cb.running_graph)) "Unknown callback running graph {}!".format(str(cb.type)))
self.train = TrainCallbacks(train_cbs)
if test_cbs: # ensure a SummaryWriter
self.test = TestCallbacks(test_cbs) for idx, cb in enumerate(cbs):
if type(cb) == SummaryWriter:
cbs.insert(0, cbs.pop(idx))
break
else: else:
self.test = None logger.warn("SummaryWriter must be used! Insert a default one automatically.")
cbs.insert(0, SummaryWriter())
self.cbs = cbs
self.test_callback_context = TestCallbackContext()
def _before_train(self): def _before_train(self):
self.train.before_train() for cb in self.cbs:
if self.test: if isinstance(cb.type, TrainCallback):
self.test.before_train() cb.before_train()
else:
with self.test_callback_context.before_train_context():
cb.before_train()
def _after_train(self): def _after_train(self):
self.train.after_train() for cb in self.cbs:
if self.test: cb.after_train()
self.test.after_train()
logger.writer.close() logger.writer.close()
def trigger_step(self): def trigger_step(self):
self.train.trigger_step() for cb in self.cbs:
if isinstance(cb.type, TrainCallback):
cb.trigger_step()
# test callback don't have trigger_step # test callback don't have trigger_step
def _trigger_epoch(self): def _trigger_epoch(self):
self.train.trigger_epoch() tm = CallbackTimeLogger()
if self.test:
self.test.trigger_epoch() test_sess_restored = False
for cb in self.cbs:
if isinstance(cb.type, TrainCallback):
with tm.timed_callback(type(cb).__name__):
cb.trigger_epoch()
else:
if not test_sess_restored:
with tm.timed_callback('restore checkpoint'):
self.test_callback_context.restore_checkpoint()
test_sess_restored = True
with self.test_callback_context.trigger_epoch_context(), \
tm.timed_callback(type(cb).__name__):
cb.trigger_epoch()
tm.log()
logger.writer.flush() logger.writer.flush()
logger.stat_holder.finalize() logger.stat_holder.finalize()
...@@ -10,12 +10,12 @@ from tqdm import tqdm ...@@ -10,12 +10,12 @@ from tqdm import tqdm
from ..utils import * from ..utils import *
from ..utils.stat import * from ..utils.stat import *
from ..utils.summary import * from ..utils.summary import *
from .base import PeriodicCallback, Callback from .base import PeriodicCallback, Callback, TestCallback
__all__ = ['ValidationError', 'ValidationCallback'] __all__ = ['ValidationError', 'ValidationCallback']
class ValidationCallback(PeriodicCallback): class ValidationCallback(PeriodicCallback):
running_graph = 'test' type = TestCallback()
""" """
Basic routine for validation callbacks. Basic routine for validation callbacks.
""" """
...@@ -48,7 +48,7 @@ class ValidationCallback(PeriodicCallback): ...@@ -48,7 +48,7 @@ class ValidationCallback(PeriodicCallback):
output_vars = self._get_output_vars() output_vars = self._get_output_vars()
output_vars.append(self.cost_var) output_vars.append(self.cost_var)
with tqdm(total=self.ds.size()) as pbar: with tqdm(total=self.ds.size(), ascii=True) as pbar:
for dp in self.ds.get_data(): for dp in self.ds.get_data():
feed = dict(itertools.izip(self.input_vars, dp)) feed = dict(itertools.izip(self.input_vars, dp))
......
...@@ -12,7 +12,9 @@ from ..utils.symbolic_functions import * ...@@ -12,7 +12,9 @@ from ..utils.symbolic_functions import *
__all__ = ['FullyConnected'] __all__ = ['FullyConnected']
@layer_register(summary_activation=True) @layer_register(summary_activation=True)
def FullyConnected(x, out_dim, W_init=None, b_init=None, nl=tf.nn.relu): def FullyConnected(x, out_dim,
W_init=None, b_init=None,
nl=tf.nn.relu, use_bias=True):
x = batch_flatten(x) x = batch_flatten(x)
in_dim = x.get_shape().as_list()[1] in_dim = x.get_shape().as_list()[1]
...@@ -20,9 +22,11 @@ def FullyConnected(x, out_dim, W_init=None, b_init=None, nl=tf.nn.relu): ...@@ -20,9 +22,11 @@ def FullyConnected(x, out_dim, W_init=None, b_init=None, nl=tf.nn.relu):
#W_init = tf.truncated_normal_initializer(stddev=1 / math.sqrt(float(in_dim))) #W_init = tf.truncated_normal_initializer(stddev=1 / math.sqrt(float(in_dim)))
W_init = tf.uniform_unit_scaling_initializer() W_init = tf.uniform_unit_scaling_initializer()
if b_init is None: if b_init is None:
b_init = tf.constant_initializer(0.0) b_init = tf.constant_initializer()
with tf.device('/cpu:0'): with tf.device('/cpu:0'):
W = tf.get_variable('W', [in_dim, out_dim], initializer=W_init) W = tf.get_variable('W', [in_dim, out_dim], initializer=W_init)
b = tf.get_variable('b', [out_dim], initializer=b_init) if use_bias:
return nl(tf.nn.xw_plus_b(x, W, b), name=tf.get_variable_scope().name + '_output') b = tf.get_variable('b', [out_dim], initializer=b_init)
prod = tf.nn.xw_plus_b(x, W, b) if use_bias else tf.matmul(x, W)
return nl(prod, name=tf.get_variable_scope().name + '_output')
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment