Commit a976e871 authored by Yuxin Wu's avatar Yuxin Wu

split-out callbacks dir

parent eea48e2e
......@@ -13,8 +13,7 @@ from tensorpack.models import *
from tensorpack.utils import *
from tensorpack.utils.symbolic_functions import *
from tensorpack.utils.summary import *
from tensorpack.utils.callback import *
from tensorpack.utils.validation_callback import *
from tensorpack.callbacks import *
from tensorpack.dataflow import *
from tensorpack.dataflow import imgaug
......
......@@ -15,8 +15,7 @@ from tensorpack.models import *
from tensorpack.utils import *
from tensorpack.utils.symbolic_functions import *
from tensorpack.utils.summary import *
from tensorpack.utils.callback import *
from tensorpack.utils.validation_callback import *
from tensorpack.callbacks import *
from tensorpack.dataflow import *
BATCH_SIZE = 128
......
# !/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: __init__.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
from pkgutil import walk_packages
import os
def global_import(name):
p = __import__(name, globals(), locals())
lst = p.__all__ if '__all__' in dir(p) else dir(p)
for k in lst:
globals()[k] = p.__dict__[k]
for _, module_name, _ in walk_packages(
[os.path.dirname(__file__)]):
if not module_name.startswith('_'):
global_import(module_name)
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: base.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf
import sys
import os
import time
from abc import abstractmethod, ABCMeta
from ..utils import *
__all__ = ['Callback', 'PeriodicCallback']
class Callback(object):
__metaclass__ = ABCMeta
running_graph = 'train'
""" The graph that this callback should run on.
Either 'train' or 'test'
"""
def before_train(self):
self.graph = tf.get_default_graph()
self.sess = tf.get_default_session()
self._before_train()
def _before_train(self):
"""
Called before starting iterative training
"""
def trigger_step(self, inputs, outputs, cost):
"""
Callback to be triggered after every step (every backpropagation)
Args:
inputs: the list of input values
outputs: list of output values after running this inputs
cost: the cost value after running this input
"""
def trigger_epoch(self):
"""
Callback to be triggered after every epoch (full iteration of input dataset)
"""
class PeriodicCallback(Callback):
def __init__(self, period):
self.__period = period
self.epoch_num = 0
def trigger_epoch(self):
self.epoch_num += 1
if self.epoch_num % self.__period == 0:
self.global_step = get_global_step()
self._trigger()
@abstractmethod
def _trigger(self):
pass
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: common.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf
import os
from .base import Callback, PeriodicCallback
from ..utils import *
__all__ = ['PeriodicSaver', 'SummaryWriter']
class PeriodicSaver(PeriodicCallback):
def __init__(self, period=1, keep_recent=10, keep_freq=0.5):
super(PeriodicSaver, self).__init__(period)
self.path = os.path.join(logger.LOG_DIR, 'model')
self.keep_recent = keep_recent
self.keep_freq = keep_freq
def _before_train(self):
self.saver = tf.train.Saver(
max_to_keep=self.keep_recent,
keep_checkpoint_every_n_hours=self.keep_freq)
def _trigger(self):
self.saver.save(
tf.get_default_session(),
self.path,
global_step=self.global_step)
class SummaryWriter(Callback):
def __init__(self):
self.log_dir = logger.LOG_DIR
def _before_train(self):
self.writer = tf.train.SummaryWriter(
self.log_dir, graph_def=self.sess.graph_def)
tf.add_to_collection(SUMMARY_WRITER_COLLECTION_KEY, self.writer)
self.summary_op = tf.merge_all_summaries()
def trigger_epoch(self):
# check if there is any summary
if self.summary_op is None:
return
summary_str = self.summary_op.eval()
self.writer.add_summary(summary_str, get_global_step())
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: callback.py
# File: group.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf
import sys
import numpy as np
import os
import time
from abc import abstractmethod, ABCMeta
from . import create_test_session, get_global_step
from .naming import *
import logger
class Callback(object):
__metaclass__ = ABCMeta
running_graph = 'train'
""" The graph that this callback should run on.
Either 'train' or 'test'
"""
def before_train(self):
self.graph = tf.get_default_graph()
self.sess = tf.get_default_session()
self._before_train()
def _before_train(self):
"""
Called before starting iterative training
"""
def trigger_step(self, inputs, outputs, cost):
"""
Callback to be triggered after every step (every backpropagation)
Args:
inputs: the list of input values
outputs: list of output values after running this inputs
cost: the cost value after running this input
"""
def trigger_epoch(self):
"""
Callback to be triggered after every epoch (full iteration of input dataset)
"""
class PeriodicCallback(Callback):
def __init__(self, period):
self.__period = period
self.epoch_num = 0
def trigger_epoch(self):
self.epoch_num += 1
if self.epoch_num % self.__period == 0:
self.global_step = get_global_step()
self._trigger()
@abstractmethod
def _trigger(self):
pass
class PeriodicSaver(PeriodicCallback):
def __init__(self, period=1, keep_recent=10, keep_freq=0.5):
super(PeriodicSaver, self).__init__(period)
self.path = os.path.join(logger.LOG_DIR, 'model')
self.keep_recent = keep_recent
self.keep_freq = keep_freq
def _before_train(self):
self.saver = tf.train.Saver(
max_to_keep=self.keep_recent,
keep_checkpoint_every_n_hours=self.keep_freq)
def _trigger(self):
self.saver.save(
tf.get_default_session(),
self.path,
global_step=self.global_step)
class SummaryWriter(Callback):
def __init__(self):
self.log_dir = logger.LOG_DIR
def _before_train(self):
self.writer = tf.train.SummaryWriter(
self.log_dir, graph_def=self.sess.graph_def)
tf.add_to_collection(SUMMARY_WRITER_COLLECTION_KEY, self.writer)
self.summary_op = tf.merge_all_summaries()
def trigger_epoch(self):
# check if there is any summary
if self.summary_op is None:
return
summary_str = self.summary_op.eval()
self.writer.add_summary(summary_str, get_global_step())
from contextlib import contextmanager
from .base import Callback
from .common import *
from ..utils import *
__all__ = ['Callbacks']
@contextmanager
def create_test_graph():
G = tf.get_default_graph()
input_vars_train = G.get_collection(INPUT_VARS_KEY)
forward_func = G.get_collection(FORWARD_FUNC_KEY)[0]
with tf.Graph().as_default() as Gtest:
# create a global step var in test graph
global_step_var = tf.Variable(
0, trainable=False, name=GLOBAL_STEP_OP_NAME)
input_vars = []
for v in input_vars_train:
name = v.name
assert name.endswith(':0'), "I think placeholder variable should all ends with ':0'"
name = name[:-2]
input_vars.append(tf.placeholder(
v.dtype, shape=v.get_shape(), name=name
))
for v in input_vars:
Gtest.add_to_collection(INPUT_VARS_KEY, v)
output_vars, cost = forward_func(input_vars, is_training=False)
for v in output_vars:
Gtest.add_to_collection(OUTPUT_VARS_KEY, v)
yield Gtest
@contextmanager
def create_test_session():
with create_test_graph():
with tf.Session() as sess:
yield sess
class CallbackTimeLogger(object):
def __init__(self):
......@@ -126,7 +73,7 @@ class TrainCallbacks(Callback):
self.cbs.insert(0, self.cbs.pop(idx))
break
else:
raise RuntimeError("Callbacks must contain a SummaryWriter!")
raise ValueError("Callbacks must contain a SummaryWriter!")
def before_train(self):
for cb in self.cbs:
......@@ -199,7 +146,7 @@ class Callbacks(Callback):
elif cb.running_graph == 'train':
train_cbs.append(cb)
else:
raise RuntimeError(
raise ValueError(
"Unknown callback running graph {}!".format(cb.running_graph))
self.train = TrainCallbacks(train_cbs)
self.test = TestCallbacks(test_cbs)
......@@ -216,4 +163,3 @@ class Callbacks(Callback):
self.train.trigger_epoch()
# TODO test callbacks can be run async?
self.test.trigger_epoch()
......@@ -6,11 +6,12 @@
import tensorflow as tf
from tqdm import tqdm
from .stat import *
from .callback import PeriodicCallback, Callback
from .naming import *
from .summary import *
import logger
from ..utils import *
from ..utils.stat import *
from ..utils.summary import *
from .base import PeriodicCallback, Callback
__all__ = ['ValidationError']
class ValidationError(PeriodicCallback):
running_graph = 'test'
......
......@@ -18,7 +18,7 @@ def Conv2D(x, out_channel, kernel_shape,
kernel_shape: (h, w) or a int
stride: (h, w) or a int
padding: 'valid' or 'same'
split: split channels. used in alexnet
split: split channels. used in Alexnet
"""
in_shape = x.get_shape().as_list()
in_channel = in_shape[-1]
......@@ -31,7 +31,7 @@ def Conv2D(x, out_channel, kernel_shape,
stride = shape4d(stride)
if W_init is None:
W_init = tf.truncated_normal_initializer(stddev=1e-4)
W_init = tf.truncated_normal_initializer(stddev=4e-3)
if b_init is None:
b_init = tf.constant_initializer()
......
......@@ -10,7 +10,7 @@ import argparse
import tqdm
from utils import *
from utils.concurrency import EnqueueThread,coordinator_guard
from utils.callback import Callbacks
from callbacks import *
from utils.summary import summary_moving_average
from utils.modelutils import describe_model
from utils import logger
......
......@@ -31,36 +31,6 @@ def timed_operation(msg, log_start=False):
logger.info('finished {}, time={:.2f}sec.'.format(
msg, time.time() - start))
@contextmanager
def create_test_graph():
G = tf.get_default_graph()
input_vars_train = G.get_collection(INPUT_VARS_KEY)
forward_func = G.get_collection(FORWARD_FUNC_KEY)[0]
with tf.Graph().as_default() as Gtest:
# create a global step var in test graph
global_step_var = tf.Variable(
0, trainable=False, name=GLOBAL_STEP_OP_NAME)
input_vars = []
for v in input_vars_train:
name = v.name
assert name.endswith(':0'), "I think placeholder variable should all ends with ':0'"
name = name[:-2]
input_vars.append(tf.placeholder(
v.dtype, shape=v.get_shape(), name=name
))
for v in input_vars:
Gtest.add_to_collection(INPUT_VARS_KEY, v)
output_vars, cost = forward_func(input_vars, is_training=False)
for v in output_vars:
Gtest.add_to_collection(OUTPUT_VARS_KEY, v)
yield Gtest
@contextmanager
def create_test_session():
with create_test_graph():
with tf.Session() as sess:
yield sess
def get_default_sess_config():
"""
Return a better config to use as default.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment