Commit 93020942 authored by ppwwyyxx

queue works with validation! double graph!

parent 86370a76
......@@ -8,6 +8,7 @@ from abc import abstractmethod
__all__ = ['DataFlow']
class DataFlow(object):
# TODO private impl
@abstractmethod
def get_data(self):
"""
......
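For context, a minimal sketch of what a concrete DataFlow looks like under this interface: get_data is a generator that yields one datapoint (a list of components) at a time, and size reports how many datapoints one pass produces. The class name here is illustrative, not part of this commit.

    from dataflow import DataFlow

    class FakeData(DataFlow):
        """ Hypothetical dataflow yielding n dummy datapoints per epoch. """
        def __init__(self, n):
            self.n = n

        def size(self):
            return self.n

        def get_data(self):
            for i in range(self.n):
                yield [i]   # a datapoint is a list of components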
......@@ -6,7 +6,7 @@
import numpy as np
from .base import DataFlow
__all__ = ['BatchData']
__all__ = ['BatchData', 'FixedSizeData']
class BatchData(DataFlow):
def __init__(self, ds, batch_size, remainder=False):
......@@ -46,3 +46,21 @@ class BatchData(DataFlow):
np.array([x[k] for x in data_holder],
dtype=data_holder[0][k].dtype))
return tuple(result)
class FixedSizeData(DataFlow):
def __init__(self, ds, size):
self.ds = ds
self._size = size
def size(self):
return self._size
def get_data(self):
cnt = 0
while True:
for dp in self.ds.get_data():
cnt += 1
yield dp
if cnt == self._size:
return
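A hedged usage sketch of the new FixedSizeData wrapper: it makes one epoch of the wrapped dataflow exactly `size` datapoints long, restarting the underlying generator if it runs out early. This is what the commented-out line in get_config below uses to shrink an epoch for quick debugging; the MNIST/BatchData wiring mirrors example_mnist.py.

    from dataflow import BatchData, FixedSizeData
    from dataflow.dataset import Mnist

    dataset_train = BatchData(Mnist('train'), 128)
    # cap one "epoch" at 20 batches, regardless of how many the dataset really has
    dataset_train = FixedSizeData(dataset_train, 20)

    assert dataset_train.size() == 20
    for dp in dataset_train.get_data():
        pass   # dp is one batch; the generator returns after exactly 20 batches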
......@@ -18,11 +18,14 @@ from models import *
from utils import *
from utils.symbolic_functions import *
from utils.summary import *
from utils.callback import *
from utils.validation_callback import *
from utils.concurrency import *
from dataflow.dataset import Mnist
from dataflow import *
def get_model(inputs):
# TODO is_training as a python variable
"""
Args:
inputs: a list of input variable,
......@@ -41,12 +44,12 @@ def get_model(inputs):
image, label = inputs
image = tf.expand_dims(image, 3) # add a single channel
#conv0 = Conv2D('conv0', image, out_channel=32, kernel_shape=5)
#pool0 = MaxPooling('pool0', conv0, 2)
#conv1 = Conv2D('conv1', pool0, out_channel=40, kernel_shape=3)
#pool1 = MaxPooling('pool1', conv1, 2)
conv0 = Conv2D('conv0', image, out_channel=32, kernel_shape=5)
pool0 = MaxPooling('pool0', conv0, 2)
conv1 = Conv2D('conv1', pool0, out_channel=40, kernel_shape=3)
pool1 = MaxPooling('pool1', conv1, 2)
fc0 = FullyConnected('fc0', image, 1024)
fc0 = FullyConnected('fc0', pool1, 1024)
fc0 = tf.nn.dropout(fc0, keep_prob)
# fc will have activation summary by default. disable this for the output layer
......@@ -74,15 +77,18 @@ def get_model(inputs):
name='regularize_loss')
tf.add_to_collection(COST_VARS_KEY, wd_cost)
return [prob, nr_wrong], tf.add_n(tf.get_collection(COST_VARS_KEY), name='cost')
# this won't work with multigpu
#return [prob, nr_wrong], tf.add_n(tf.get_collection(COST_VARS_KEY), name='cost')
return [prob, nr_wrong], tf.add_n([wd_cost, cost], name='cost')
def get_config():
IMAGE_SIZE = 28
LOG_DIR = os.path.join('train_log', os.path.basename(__file__)[:-3])
log_dir = os.path.join('train_log', os.path.basename(__file__)[:-3])
logger.set_logger_dir(log_dir)
BATCH_SIZE = 128
logger.set_file(os.path.join(LOG_DIR, 'training.log'))
dataset_train = BatchData(Mnist('train'), BATCH_SIZE)
#dataset_train = FixedSizeData(dataset_train, 20)
dataset_test = BatchData(Mnist('test'), 256, remainder=True)
sess_config = tf.ConfigProto()
......@@ -111,11 +117,11 @@ def get_config():
return dict(
dataset_train=dataset_train,
optimizer=tf.train.AdamOptimizer(lr),
callbacks=[
SummaryWriter(LOG_DIR),
#ValidationError(dataset_test, prefix='test'),
PeriodicSaver(LOG_DIR),
],
callback=Callbacks([
SummaryWriter(),
PeriodicSaver(),
ValidationError(dataset_test, prefix='test'),
]),
session_config=sess_config,
inputs=input_vars,
input_queue=input_queue,
......
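The config also hands the trainer the raw input placeholders (inputs) and an input queue (input_queue) that EnqueueThread fills; that part of the file is elided in this hunk. A hedged sketch of what such wiring could look like, with the queue type, capacity and shapes as illustrative assumptions rather than values taken from this diff:

    import tensorflow as tf

    # placeholders that EnqueueThread feeds from the training dataflow
    image_var = tf.placeholder(tf.float32, shape=(None, IMAGE_SIZE, IMAGE_SIZE), name='input')
    label_var = tf.placeholder(tf.int32, shape=(None,), name='label')
    input_vars = [image_var, label_var]

    # a FIFO queue of whole batches; the model reads its inputs by dequeueing
    input_queue = tf.FIFOQueue(50, [x.dtype for x in input_vars], name='input_queue')
    enqueue_op = input_queue.enqueue(input_vars)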
......@@ -7,6 +7,9 @@ import tensorflow as tf
from utils.summary import *
from utils import logger
# make sure each layer is only logged once
_layer_logged = set()
def layer_register(summary_activation=False):
"""
summary_activation: default behavior of whether to add a summary of this layer's output
......@@ -19,26 +22,29 @@ def layer_register(summary_activation=False):
do_summary = kwargs.pop(
'summary_activation', summary_activation)
inputs = args[0]
if isinstance(inputs, list):
shape_str = ",".join(
map(str(x.get_shape().as_list()), inputs))
else:
shape_str = str(inputs.get_shape().as_list())
logger.info("{} input: {}".format(name, shape_str))
with tf.variable_scope(name) as scope:
outputs = func(*args, **kwargs)
if isinstance(outputs, list):
shape_str = ",".join(
map(str(x.get_shape().as_list()), outputs))
if do_summary:
for x in outputs:
add_activation_summary(x, scope.name)
else:
shape_str = str(outputs.get_shape().as_list())
if do_summary:
add_activation_summary(outputs, scope.name)
logger.info("{} output: {}".format(name, shape_str))
if name not in _layer_logged:
# log shape info and add activation
if isinstance(inputs, list):
shape_str = ",".join(
map(lambda x: str(x.get_shape().as_list()), inputs))
else:
shape_str = str(inputs.get_shape().as_list())
logger.info("{} input: {}".format(name, shape_str))
if isinstance(outputs, list):
shape_str = ",".join(
map(lambda x: str(x.get_shape().as_list()), outputs))
if do_summary:
for x in outputs:
add_activation_summary(x, scope.name)
else:
shape_str = str(outputs.get_shape().as_list())
if do_summary:
add_activation_summary(outputs, scope.name)
logger.info("{} output: {}".format(name, shape_str))
_layer_logged.add(name)
return outputs
return inner
return wrapper
......
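The _layer_logged set exists because get_model is now run twice, once for the training graph and once for the test graph, and without it every layer would log its shapes and register activation summaries a second time. A simplified standalone sketch of the dedup pattern (not the real decorator, which also handles variable scopes and summaries):

    _layer_logged = set()

    def layer_register():
        def wrapper(func):
            def inner(name, *args, **kwargs):
                outputs = func(name, *args, **kwargs)
                if name not in _layer_logged:
                    # the real decorator logs input/output shapes and adds
                    # activation summaries here, once per layer name
                    print("built layer {}".format(name))
                    _layer_logged.add(name)
                return outputs
            return inner
        return wrapper

    @layer_register()
    def DummyLayer(name, x):
        return x

    DummyLayer('fc0', 1)   # first (training) graph: prints "built layer fc0"
    DummyLayer('fc0', 1)   # second (test) graph: silent, already logged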
......@@ -6,12 +6,14 @@
import tensorflow as tf
from utils import *
from utils.concurrency import *
from utils.callback import *
from utils.summary import *
from dataflow import DataFlow
from itertools import count
import argparse
def prepare():
is_training = tf.placeholder(tf.bool, shape=(), name=IS_TRAINING_OP_NAME)
is_training = tf.constant(True, name=IS_TRAINING_OP_NAME)
#keep_prob = tf.placeholder(
#tf.float32, shape=tuple(), name=DROPOUT_PROB_OP_NAME)
global_step_var = tf.Variable(
......@@ -32,7 +34,7 @@ def start_train(config):
assert isinstance(optimizer, tf.train.Optimizer), optimizer.__class__
# a list of Callback instance
callbacks = Callbacks(config.get('callbacks', []))
callbacks = config['callback']
# a tf.ConfigProto instance
sess_config = config.get('session_config', None)
......@@ -53,6 +55,7 @@ def start_train(config):
# build graph
G = tf.get_default_graph()
G.add_to_collection(FORWARD_FUNC_KEY, get_model_func)
for v in input_vars:
G.add_to_collection(INPUT_VARS_KEY, v)
for v in output_vars:
......@@ -74,30 +77,28 @@ def start_train(config):
train_op = optimizer.apply_gradients(grads, global_step_var)
sess = tf.Session(config=sess_config)
# start training
with sess.as_default():
sess.run(tf.initialize_all_variables())
sess.run(tf.initialize_all_variables())
# start training:
coord = tf.train.Coordinator()
# a thread that keeps filling the queue
th = EnqueueThread(sess, coord, enqueue_op, dataset_train)
with sess.as_default(), \
coordinator_context(
sess, coord, th, input_queue):
callbacks.before_train()
for epoch in xrange(1, max_epoch):
coord = tf.train.Coordinator()
th = EnqueueThread(sess, coord, enqueue_op, dataset_train)
with timed_operation('epoch {}'.format(epoch)), \
coordinator_context(
sess, coord, th, input_queue):
with timed_operation('epoch {}'.format(epoch)):
for step in xrange(dataset_train.size()):
# TODO eval dequeue to get dp
fetches = [train_op, cost_var] + output_vars
results = sess.run(fetches,
feed_dict={IS_TRAINING_VAR_NAME: True})
feed = {IS_TRAINING_VAR_NAME: True}
results = sess.run(fetches, feed_dict=feed)
cost = results[1]
outputs = results[2:]
print tf.train.global_step(sess, global_step_var), cost
# trigger_step
coord.request_stop()
# summary will take a data from the queue
# TODO trigger_step
# note that summary_op will take a datapoint from the queue.
callbacks.trigger_epoch()
print "Finish callback"
sess.close()
def main(get_config_func):
......
......@@ -16,11 +16,7 @@ def global_import(name):
lst = p.__all__ if '__all__' in dir(p) else dir(p)
for k in lst:
globals()[k] = p.__dict__[k]
global_import('naming')
global_import('callback')
global_import('validation_callback')
@contextmanager
def timed_operation(msg, log_start=False):
......@@ -44,3 +40,32 @@ def describe_model():
msg.append("Total dim={}".format(total))
logger.info("Model Params: {}".format('\n'.join(msg)))
# TODO disable shape output in get_model
@contextmanager
def create_test_graph():
G = tf.get_default_graph()
input_vars_train = G.get_collection(INPUT_VARS_KEY)
forward_func = G.get_collection(FORWARD_FUNC_KEY)[0]
with tf.Graph().as_default() as Gtest:
input_vars = []
for v in input_vars_train:
name = v.name
assert name.endswith(':0'), "placeholder variable names should all end with ':0'"
name = name[:-2]
input_vars.append(tf.placeholder(
v.dtype, shape=v.get_shape(), name=name
))
for v in input_vars:
Gtest.add_to_collection(INPUT_VARS_KEY, v)
is_training = tf.constant(False, name=IS_TRAINING_OP_NAME)
output_vars, cost = forward_func(input_vars)
for v in output_vars:
Gtest.add_to_collection(OUTPUT_VARS_KEY, v)
yield Gtest
@contextmanager
def create_test_session():
with create_test_graph():
with tf.Session() as sess:
yield sess
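A hedged usage sketch of the "double graph" mechanism from the commit message: the training graph stores its input placeholders and the model-building function in collections (INPUT_VARS_KEY, FORWARD_FUNC_KEY), and create_test_graph replays that function in a fresh graph with is_training pinned to False. TestCallbacks.before_train below uses it roughly like this; the import lines are assumptions about how the utils package re-exports these names.

    import tensorflow as tf
    from utils import *   # assumed to provide create_test_session and the *_KEY names

    # the training graph must already be built, so that the
    # INPUT_VARS_KEY / FORWARD_FUNC_KEY collections are populated
    with create_test_session() as sess:
        test_graph = sess.graph
        input_vars = test_graph.get_collection(INPUT_VARS_KEY)    # fresh placeholders
        output_vars = test_graph.get_collection(OUTPUT_VARS_KEY)  # rebuilt model outputs
        saver = tf.train.Saver()  # later restored from the latest training checkpoint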
......@@ -10,10 +10,16 @@ import os
import time
from abc import abstractmethod
from . import create_test_session
from .naming import *
import logger
class Callback(object):
running_graph = 'train'
""" The graph that this callback should run on.
Either 'train' or 'test'
"""
def before_train(self):
self.graph = tf.get_default_graph()
self.sess = tf.get_default_session()
......@@ -53,20 +59,20 @@ class PeriodicCallback(Callback):
pass
class PeriodicSaver(PeriodicCallback):
def __init__(self, log_dir, period=1):
def __init__(self, period=1):
super(PeriodicSaver, self).__init__(period)
self.path = os.path.join(log_dir, 'model')
self.path = os.path.join(logger.LOG_DIR, 'model')
def _before_train(self):
self.saver = tf.train.Saver(max_to_keep=99999)
def _trigger(self):
self.saver.save(tf.get_default_session(), self.path,
global_step=self.epoch_num, latest_filename='latest')
global_step=self.epoch_num)
class SummaryWriter(Callback):
def __init__(self, log_dir):
self.log_dir = log_dir
def __init__(self):
self.log_dir = logger.LOG_DIR
self.epoch_num = 0
def _before_train(self):
......@@ -80,52 +86,122 @@ class SummaryWriter(Callback):
if self.summary_op is None:
return
summary_str = self.summary_op.eval(
feed_dict={IS_TRAINING_VAR_NAME: True})
feed = {IS_TRAINING_VAR_NAME: True}
summary_str = self.summary_op.eval(feed_dict=feed)
self.epoch_num += 1
self.writer.add_summary(summary_str, self.epoch_num)
class Callbacks(Callback):
def __init__(self, callbacks):
for cb in callbacks:
assert isinstance(cb, Callback), cb.__class__
class CallbackTimeLogger(object):
def __init__(self):
self.times = []
self.tot = 0
def add(self, name, time):
self.tot += time
self.times.append((name, time))
def log(self):
"""
log the time of some heavy callbacks
"""
if self.tot < 3:
return
msgs = []
for name, t in self.times:
if t / self.tot > 0.3 and t > 1:
msgs.append("{}:{}".format(name, t))
logger.info(
"Callbacks took {} sec. {}".format(
self.tot, ' '.join(msgs)))
class TrainCallbacks(Callback):
def __init__(self, callbacks):
self.cbs = callbacks
# put SummaryWriter to the first
for idx, cb in enumerate(callbacks):
for idx, cb in enumerate(self.cbs):
if type(cb) == SummaryWriter:
callbacks.insert(0, callbacks.pop(idx))
self.cbs.insert(0, self.cbs.pop(idx))
break
else:
raise RuntimeError("callbacks must contain a SummaryWriter!")
self.callbacks = callbacks
raise RuntimeError("Callbacks must contain a SummaryWriter!")
def before_train(self):
for cb in self.callbacks:
for cb in self.cbs:
cb.before_train()
self.writer = tf.get_collection(SUMMARY_WRITER_COLLECTION_KEY)[0]
def trigger_step(self, inputs, outputs, cost):
for cb in self.callbacks:
for cb in self.cbs:
cb.trigger_step(inputs, outputs, cost)
def trigger_epoch(self):
start = time.time()
times = []
for cb in self.callbacks:
tm = CallbackTimeLogger()
for cb in self.cbs:
s = time.time()
cb.trigger_epoch()
times.append(time.time() - s)
tm.add(type(cb).__name__, time.time() - s)
self.writer.flush()
tot = time.time() - start
tm.log()
# log the time of some heavy callbacks
if tot < 3:
return
msgs = []
for idx, t in enumerate(times):
if t / tot > 0.3 and t > 1:
msgs.append("{}:{}".format(
type(self.callbacks[idx]).__name__, t))
logger.info("Callbacks took {} sec. {}".format(tot, ' '.join(msgs)))
class TestCallbacks(Callback):
def __init__(self, callbacks):
self.cbs = callbacks
def before_train(self):
self.writer = tf.get_collection(SUMMARY_WRITER_COLLECTION_KEY)[0]
with create_test_session() as sess:
self.sess = sess
self.graph = sess.graph
self.saver = tf.train.Saver()
tf.add_to_collection(SUMMARY_WRITER_COLLECTION_KEY, self.writer)
for cb in self.cbs:
cb.before_train()
def trigger_epoch(self):
tm = CallbackTimeLogger()
with self.graph.as_default():
with self.sess.as_default():
s = time.time()
ckpt = tf.train.get_checkpoint_state(logger.LOG_DIR)
if ckpt is None:
from IPython import embed; embed()
self.saver.restore(self.sess, ckpt.model_checkpoint_path)
tm.add('restore session', time.time() - s)
for cb in self.cbs:
s = time.time()
cb.trigger_epoch()
tm.add(type(cb).__name__, time.time() - s)
self.writer.flush()
tm.log()
class Callbacks(Callback):
def __init__(self, cbs):
train_cbs = []
test_cbs = []
for cb in cbs:
assert isinstance(cb, Callback), cb.__class__
if cb.running_graph == 'test':
test_cbs.append(cb)
elif cb.running_graph == 'train':
train_cbs.append(cb)
else:
raise RuntimeError(
"Unknown callback running graph {}!".format(cb.running_graph))
self.train = TrainCallbacks(train_cbs)
self.test = TestCallbacks(test_cbs)
def before_train(self):
self.train.before_train()
self.test.before_train()
def trigger_step(self, inputs, outputs, cost):
self.train.trigger_step(inputs, outputs, cost)
# test callbacks don't have trigger_step
def trigger_epoch(self):
self.train.trigger_epoch()
self.test.trigger_epoch()
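Tying this back to the get_config change above: callbacks are now routed by their running_graph attribute, so a hedged usage sketch of the new container looks like the following (dataset_test is the test dataflow from get_config):

    cbs = Callbacks([
        SummaryWriter(),                               # running_graph == 'train'
        PeriodicSaver(),                               # running_graph == 'train'
        ValidationError(dataset_test, prefix='test'),  # running_graph == 'test'
    ])
    # cbs.train holds the first two and runs in the training graph/session;
    # cbs.test holds ValidationError and runs in the rebuilt test graph after
    # restoring the latest checkpoint. Each epoch the trainer calls
    # cbs.trigger_epoch(), which runs the train callbacks, then the test callbacks.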
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: concurrency.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import threading
from contextlib import contextmanager
import tensorflow as tf
from .naming import *
import logger
class StoppableThread(threading.Thread):
def __init__(self):
super(StoppableThread, self).__init__()
self._stop = threading.Event()
def stop(self):
self._stop.set()
def stopped(self):
return self._stop.isSet()
class EnqueueThread(threading.Thread):
def __init__(self, sess, coord, enqueue_op, dataflow):
super(EnqueueThread, self).__init__()
self.sess = sess
self.coord = coord
self.input_vars = sess.graph.get_collection(INPUT_VARS_KEY)
self.dataflow = dataflow
self.op = enqueue_op
def run(self):
try:
while True:
for dp in self.dataflow.get_data():
if self.coord.should_stop():
return
feed = dict(zip(self.input_vars, dp))
self.sess.run([self.op], feed_dict=feed)
except tf.errors.CancelledError as e:
pass
except Exception:
logger.exception("Exception in EnqueueThread:")
@contextmanager
def coordinator_context(sess, coord, thread, queue):
"""
Context manager to make sure queue is closed and thread is joined
"""
thread.start()
try:
yield
except (KeyboardInterrupt, Exception) as e:
raise
finally:
coord.request_stop()
sess.run(
queue.close(cancel_pending_enqueues=True))
coord.join([thread])
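A hedged sketch of how these two pieces are used together by the training loop above: one EnqueueThread keeps pushing datapoints from the dataflow into the input queue while the inner loop runs, and coordinator_context guarantees the queue gets closed and the thread joined even if the loop raises. The variables (sess, enqueue_op, dataset_train, input_queue, train_op) are the ones built in the trainer.

    coord = tf.train.Coordinator()
    th = EnqueueThread(sess, coord, enqueue_op, dataset_train)
    with coordinator_context(sess, coord, th, input_queue):
        for step in xrange(dataset_train.size()):
            # the model's input tensors dequeue from input_queue internally
            sess.run([train_op])
    # on exit: coord.request_stop(), the queue is closed with pending
    # enqueues cancelled, and the thread is joined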
......@@ -50,3 +50,11 @@ def set_file(path):
hdl = logging.FileHandler(
filename=path, encoding='utf-8', mode='w')
logger.addHandler(hdl)
global LOG_DIR
LOG_DIR = "train_log"
def set_logger_dir(dirname):
global LOG_DIR
LOG_DIR = dirname
set_file(os.path.join(LOG_DIR, 'training.log'))
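With this change the log directory becomes module-level state: set_logger_dir records it as logger.LOG_DIR and opens training.log inside it, and SummaryWriter() / PeriodicSaver() read logger.LOG_DIR instead of taking a log_dir argument. A small usage sketch matching example_mnist.py (the import line is an assumption about how the module is exposed):

    import os
    from utils import logger   # assumed import path

    logger.set_logger_dir(os.path.join('train_log', 'example_mnist'))
    # now logger.LOG_DIR == 'train_log/example_mnist' and training.log lives there;
    # SummaryWriter() and PeriodicSaver() pick the directory up implicitly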
......@@ -15,6 +15,7 @@ INPUT_VARS_KEY = 'INPUT_VARIABLES'
OUTPUT_VARS_KEY = 'OUTPUT_VARIABLES'
COST_VARS_KEY = 'COST_VARIABLES' # keep track of each individual cost
SUMMARY_VARS_KEY = 'SUMMARY_VARIABLES' # extra variables to summarize during training
FORWARD_FUNC_KEY = 'FORWARD_FUNCTION'
# export all upper case variables
all_local_names = locals().keys()
......
......@@ -19,6 +19,7 @@ def create_summary(name, v):
return s
def add_activation_summary(x, name=None):
# TODO dedup
"""
Summary for an activation tensor x.
If name is None, use x.name
......
......@@ -11,6 +11,7 @@ from .summary import *
import logger
class ValidationError(PeriodicCallback):
running_graph = 'test'
"""
Validate the accuracy for the given wrong and cost variable
Use under the following setup:
......@@ -33,7 +34,6 @@ class ValidationError(PeriodicCallback):
def _before_train(self):
self.input_vars = tf.get_collection(INPUT_VARS_KEY)
self.is_training_var = self.get_tensor(IS_TRAINING_VAR_NAME)
self.wrong_var = self.get_tensor(self.wrong_var_name)
self.cost_var = self.get_tensor(self.cost_var_name)
self.writer = tf.get_collection(SUMMARY_WRITER_COLLECTION_KEY)[0]
......@@ -43,8 +43,7 @@ class ValidationError(PeriodicCallback):
err_stat = Accuracy()
cost_sum = 0
for dp in self.ds.get_data():
feed = {self.is_training_var: False}
feed.update(dict(zip(self.input_vars, dp)))
feed = dict(zip(self.input_vars, dp))
batch_size = dp[0].shape[0] # assume batched input
......