Commit a976e871 authored by Yuxin Wu

split-out callbacks dir

parent eea48e2e
@@ -13,8 +13,7 @@ from tensorpack.models import *
 from tensorpack.utils import *
 from tensorpack.utils.symbolic_functions import *
 from tensorpack.utils.summary import *
-from tensorpack.utils.callback import *
-from tensorpack.utils.validation_callback import *
+from tensorpack.callbacks import *
 from tensorpack.dataflow import *
 from tensorpack.dataflow import imgaug
@@ -15,8 +15,7 @@ from tensorpack.models import *
 from tensorpack.utils import *
 from tensorpack.utils.symbolic_functions import *
 from tensorpack.utils.summary import *
-from tensorpack.utils.callback import *
-from tensorpack.utils.validation_callback import *
+from tensorpack.callbacks import *
 from tensorpack.dataflow import *

 BATCH_SIZE = 128
New file tensorpack/callbacks/__init__.py:

#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: __init__.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>

from pkgutil import walk_packages
import os

def global_import(name):
    p = __import__(name, globals(), locals())
    lst = p.__all__ if '__all__' in dir(p) else dir(p)
    for k in lst:
        globals()[k] = p.__dict__[k]

for _, module_name, _ in walk_packages(
        [os.path.dirname(__file__)]):
    if not module_name.startswith('_'):
        global_import(module_name)
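The loop above walks every non-underscore module in the package and copies its public names into the package namespace, so a single star import exposes all callbacks. A usage sketch (the names come from base.py and common.py, listed below):

from tensorpack.callbacks import *

# Callback and PeriodicCallback come from base.py; PeriodicSaver and
# SummaryWriter from common.py -- all re-exported by the loop above.
cbs = [SummaryWriter(), PeriodicSaver(period=1)]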
New file tensorpack/callbacks/base.py:

#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: base.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>

import tensorflow as tf
import sys
import os
import time
from abc import abstractmethod, ABCMeta

from ..utils import *

__all__ = ['Callback', 'PeriodicCallback']

class Callback(object):
    __metaclass__ = ABCMeta
    running_graph = 'train'
    """ The graph that this callback should run on.
        Either 'train' or 'test'
    """

    def before_train(self):
        self.graph = tf.get_default_graph()
        self.sess = tf.get_default_session()
        self._before_train()

    def _before_train(self):
        """
        Called before starting iterative training
        """

    def trigger_step(self, inputs, outputs, cost):
        """
        Callback to be triggered after every step (every backpropagation)

        Args:
            inputs: the list of input values
            outputs: list of output values after running these inputs
            cost: the cost value after running this input
        """

    def trigger_epoch(self):
        """
        Callback to be triggered after every epoch (a full iteration over the input dataset)
        """

class PeriodicCallback(Callback):
    def __init__(self, period):
        self.__period = period
        self.epoch_num = 0

    def trigger_epoch(self):
        self.epoch_num += 1
        if self.epoch_num % self.__period == 0:
            self.global_step = get_global_step()
            self._trigger()

    @abstractmethod
    def _trigger(self):
        pass
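A minimal sketch of how PeriodicCallback is meant to be subclassed; EpochTimer is a hypothetical example, not part of this commit:

import time

class EpochTimer(PeriodicCallback):
    """ Hypothetical callback: print wall time every `period` epochs. """
    def __init__(self, period=1):
        super(EpochTimer, self).__init__(period)
        self.last = time.time()

    def _trigger(self):
        # trigger_epoch() has already refreshed self.global_step
        now = time.time()
        print('step {}: {:.1f}s since last trigger'.format(
            self.global_step, now - self.last))
        self.last = now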
New file tensorpack/callbacks/common.py:

#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: common.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>

import tensorflow as tf
import os

from .base import Callback, PeriodicCallback
from ..utils import *

__all__ = ['PeriodicSaver', 'SummaryWriter']

class PeriodicSaver(PeriodicCallback):
    def __init__(self, period=1, keep_recent=10, keep_freq=0.5):
        super(PeriodicSaver, self).__init__(period)
        self.path = os.path.join(logger.LOG_DIR, 'model')
        self.keep_recent = keep_recent
        self.keep_freq = keep_freq

    def _before_train(self):
        self.saver = tf.train.Saver(
            max_to_keep=self.keep_recent,
            keep_checkpoint_every_n_hours=self.keep_freq)

    def _trigger(self):
        self.saver.save(
            tf.get_default_session(),
            self.path,
            global_step=self.global_step)

class SummaryWriter(Callback):
    def __init__(self):
        self.log_dir = logger.LOG_DIR

    def _before_train(self):
        self.writer = tf.train.SummaryWriter(
            self.log_dir, graph_def=self.sess.graph_def)
        tf.add_to_collection(SUMMARY_WRITER_COLLECTION_KEY, self.writer)
        self.summary_op = tf.merge_all_summaries()

    def trigger_epoch(self):
        # check if there is any summary
        if self.summary_op is None:
            return
        summary_str = self.summary_op.eval()
        self.writer.add_summary(summary_str, get_global_step())
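How these two callbacks are meant to be combined, as a sketch: the Callbacks container defined in group.py below groups them, and its TrainCallbacks constituent requires a SummaryWriter to be present. This assumes Callbacks forwards before_train() to its members the way it forwards trigger_epoch():

callbacks = Callbacks([
    SummaryWriter(),                          # required; moved to the front of the list
    PeriodicSaver(period=1, keep_recent=10),  # checkpoint every epoch
])
callbacks.before_train()   # builds the tf.train.Saver and the summary writer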
tensorpack/utils/callback.py renamed to tensorpack/callbacks/group.py:

 #!/usr/bin/env python2
 # -*- coding: UTF-8 -*-
-# File: callback.py
+# File: group.py
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>

 import tensorflow as tf
-import sys
-import numpy as np
-import os
-import time
-from abc import abstractmethod, ABCMeta
-from . import create_test_session, get_global_step
-from .naming import *
-import logger
-
-class Callback(object):
-    __metaclass__ = ABCMeta
-    running_graph = 'train'
-    """ The graph that this callback should run on.
-        Either 'train' or 'test'
-    """
-
-    def before_train(self):
-        self.graph = tf.get_default_graph()
-        self.sess = tf.get_default_session()
-        self._before_train()
-
-    def _before_train(self):
-        """
-        Called before starting iterative training
-        """
-
-    def trigger_step(self, inputs, outputs, cost):
-        """
-        Callback to be triggered after every step (every backpropagation)
-        Args:
-            inputs: the list of input values
-            outputs: list of output values after running this inputs
-            cost: the cost value after running this input
-        """
-
-    def trigger_epoch(self):
-        """
-        Callback to be triggered after every epoch (full iteration of input dataset)
-        """
-
-class PeriodicCallback(Callback):
-    def __init__(self, period):
-        self.__period = period
-        self.epoch_num = 0
-
-    def trigger_epoch(self):
-        self.epoch_num += 1
-        if self.epoch_num % self.__period == 0:
-            self.global_step = get_global_step()
-            self._trigger()
-
-    @abstractmethod
-    def _trigger(self):
-        pass
-
-class PeriodicSaver(PeriodicCallback):
-    def __init__(self, period=1, keep_recent=10, keep_freq=0.5):
-        super(PeriodicSaver, self).__init__(period)
-        self.path = os.path.join(logger.LOG_DIR, 'model')
-        self.keep_recent = keep_recent
-        self.keep_freq = keep_freq
-
-    def _before_train(self):
-        self.saver = tf.train.Saver(
-            max_to_keep=self.keep_recent,
-            keep_checkpoint_every_n_hours=self.keep_freq)
-
-    def _trigger(self):
-        self.saver.save(
-            tf.get_default_session(),
-            self.path,
-            global_step=self.global_step)
-
-class SummaryWriter(Callback):
-    def __init__(self):
-        self.log_dir = logger.LOG_DIR
-
-    def _before_train(self):
-        self.writer = tf.train.SummaryWriter(
-            self.log_dir, graph_def=self.sess.graph_def)
-        tf.add_to_collection(SUMMARY_WRITER_COLLECTION_KEY, self.writer)
-        self.summary_op = tf.merge_all_summaries()
-
-    def trigger_epoch(self):
-        # check if there is any summary
-        if self.summary_op is None:
-            return
-        summary_str = self.summary_op.eval()
-        self.writer.add_summary(summary_str, get_global_step())
+from contextlib import contextmanager
+
+from .base import Callback
+from .common import *
+from ..utils import *
+
+__all__ = ['Callbacks']
+
+@contextmanager
+def create_test_graph():
+    G = tf.get_default_graph()
+    input_vars_train = G.get_collection(INPUT_VARS_KEY)
+    forward_func = G.get_collection(FORWARD_FUNC_KEY)[0]
+    with tf.Graph().as_default() as Gtest:
+        # create a global step var in test graph
+        global_step_var = tf.Variable(
+            0, trainable=False, name=GLOBAL_STEP_OP_NAME)
+        input_vars = []
+        for v in input_vars_train:
+            name = v.name
+            assert name.endswith(':0'), "I think placeholder variable should all ends with ':0'"
+            name = name[:-2]
+            input_vars.append(tf.placeholder(
+                v.dtype, shape=v.get_shape(), name=name
+            ))
+        for v in input_vars:
+            Gtest.add_to_collection(INPUT_VARS_KEY, v)
+        output_vars, cost = forward_func(input_vars, is_training=False)
+        for v in output_vars:
+            Gtest.add_to_collection(OUTPUT_VARS_KEY, v)
+        yield Gtest
+
+@contextmanager
+def create_test_session():
+    with create_test_graph():
+        with tf.Session() as sess:
+            yield sess

 class CallbackTimeLogger(object):
     def __init__(self):
@@ -126,7 +73,7 @@ class TrainCallbacks(Callback):
                 self.cbs.insert(0, self.cbs.pop(idx))
                 break
         else:
-            raise RuntimeError("Callbacks must contain a SummaryWriter!")
+            raise ValueError("Callbacks must contain a SummaryWriter!")

     def before_train(self):
         for cb in self.cbs:
@@ -199,7 +146,7 @@ class Callbacks(Callback):
             elif cb.running_graph == 'train':
                 train_cbs.append(cb)
             else:
-                raise RuntimeError(
+                raise ValueError(
                     "Unknown callback running graph {}!".format(cb.running_graph))
         self.train = TrainCallbacks(train_cbs)
         self.test = TestCallbacks(test_cbs)
@@ -216,4 +163,3 @@ class Callbacks(Callback):
         self.train.trigger_epoch()
         # TODO test callbacks can be run async?
         self.test.trigger_epoch()
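create_test_graph() above rebuilds the network in a fresh graph purely from two graph collections, so the training code must have registered its placeholders and forward function beforehand. A hypothetical sketch of that registration (INPUT_VARS_KEY and FORWARD_FUNC_KEY are the naming constants the function reads; forward_func and the placeholder names are assumptions):

G = tf.get_default_graph()
for v in [image_var, label_var]:      # the training placeholders (hypothetical names)
    G.add_to_collection(INPUT_VARS_KEY, v)
# forward_func(input_vars, is_training) must return (output_vars, cost),
# matching how create_test_graph() calls it above
G.add_to_collection(FORWARD_FUNC_KEY, forward_func)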
@@ -6,11 +6,12 @@
 import tensorflow as tf
 from tqdm import tqdm

-from .stat import *
-from .callback import PeriodicCallback, Callback
-from .naming import *
-from .summary import *
-import logger
+from ..utils import *
+from ..utils.stat import *
+from ..utils.summary import *
+from .base import PeriodicCallback, Callback
+
+__all__ = ['ValidationError']

 class ValidationError(PeriodicCallback):
     running_graph = 'test'
@@ -18,7 +18,7 @@ def Conv2D(x, out_channel, kernel_shape,
         kernel_shape: (h, w) or a int
         stride: (h, w) or a int
         padding: 'valid' or 'same'
-        split: split channels. used in alexnet
+        split: split channels. used in Alexnet
     """
     in_shape = x.get_shape().as_list()
     in_channel = in_shape[-1]
@@ -31,7 +31,7 @@ def Conv2D(x, out_channel, kernel_shape,
     stride = shape4d(stride)

     if W_init is None:
-        W_init = tf.truncated_normal_initializer(stddev=1e-4)
+        W_init = tf.truncated_normal_initializer(stddev=4e-3)
     if b_init is None:
         b_init = tf.constant_initializer()
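For reference, a call sketch against the documented signature (image is assumed to be an NHWC input tensor; the keyword names follow the docstring shown above):

l = Conv2D(image, out_channel=32, kernel_shape=3, stride=1, padding='same')
# with W_init left as None, the weights now default to
# tf.truncated_normal_initializer(stddev=4e-3) instead of stddev=1e-4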
@@ -10,7 +10,7 @@ import argparse
 import tqdm
 from utils import *
 from utils.concurrency import EnqueueThread,coordinator_guard
-from utils.callback import Callbacks
+from callbacks import *
 from utils.summary import summary_moving_average
 from utils.modelutils import describe_model
 from utils import logger
@@ -31,36 +31,6 @@ def timed_operation(msg, log_start=False):
     logger.info('finished {}, time={:.2f}sec.'.format(
         msg, time.time() - start))

-@contextmanager
-def create_test_graph():
-    G = tf.get_default_graph()
-    input_vars_train = G.get_collection(INPUT_VARS_KEY)
-    forward_func = G.get_collection(FORWARD_FUNC_KEY)[0]
-    with tf.Graph().as_default() as Gtest:
-        # create a global step var in test graph
-        global_step_var = tf.Variable(
-            0, trainable=False, name=GLOBAL_STEP_OP_NAME)
-        input_vars = []
-        for v in input_vars_train:
-            name = v.name
-            assert name.endswith(':0'), "I think placeholder variable should all ends with ':0'"
-            name = name[:-2]
-            input_vars.append(tf.placeholder(
-                v.dtype, shape=v.get_shape(), name=name
-            ))
-        for v in input_vars:
-            Gtest.add_to_collection(INPUT_VARS_KEY, v)
-        output_vars, cost = forward_func(input_vars, is_training=False)
-        for v in output_vars:
-            Gtest.add_to_collection(OUTPUT_VARS_KEY, v)
-        yield Gtest
-
-@contextmanager
-def create_test_session():
-    with create_test_graph():
-        with tf.Session() as sess:
-            yield sess

 def get_default_sess_config():
     """
     Return a better config to use as default.