Commit 585f0837 authored by ppwwyyxx

is_training as a bool

parent dac78238
......@@ -24,22 +24,21 @@ from utils.concurrency import *
from dataflow.dataset import Mnist
from dataflow import *
def get_model(inputs):
# TODO is_training as a python variable
def get_model(inputs, is_training):
"""
Args:
inputs: a list of input variables,
e.g.: [image_var, label_var] with:
image_var: bx28x28
label_var: bx1 integer
is_training: a python bool variable
Returns:
(outputs, cost)
outputs: a list of output variables
cost: scalar variable
cost: the cost to minimize; a scalar variable
"""
is_training = tf.get_default_graph().get_tensor_by_name(IS_TRAINING_VAR_NAME)
keep_prob = control_flow_ops.cond(
is_training, lambda: tf.constant(0.5), lambda: tf.constant(1.0), name='dropout_prob')
is_training = bool(is_training)
keep_prob = tf.constant(0.5 if is_training else 1.0)
image, label = inputs
image = tf.expand_dims(image, 3) # add a single channel
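For context, a minimal sketch of how the new signature is used: the same model function can build either the training or the test graph from a plain Python bool, so the dropout strength is decided at graph-construction time instead of through a tf.cond at run time (the two-graph setup below is illustrative; the diff does this in start_train and create_test_graph):

# training graph: keep_prob is baked in as the constant 0.5
train_outputs, train_cost = get_model(input_vars, is_training=True)

# test graph (built separately, as in create_test_graph): keep_prob becomes
# the constant 1.0, which disables dropout without any feed_dict or cond op
test_outputs, test_cost = get_model(input_vars, is_training=False)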
......@@ -77,19 +76,22 @@ def get_model(inputs):
name='regularize_loss')
tf.add_to_collection(COST_VARS_KEY, wd_cost)
add_histogram_summary('.*/W') # monitor histogram of all W
# this won't work with multigpu
#return [prob, nr_wrong], tf.add_n(tf.get_collection(COST_VARS_KEY), name='cost')
return [prob, nr_wrong], tf.add_n([wd_cost, cost], name='cost')
def get_config():
IMAGE_SIZE = 28
log_dir = os.path.join('train_log', os.path.basename(__file__)[:-3])
logger.set_logger_dir(log_dir)
IMAGE_SIZE = 28
BATCH_SIZE = 128
dataset_train = BatchData(Mnist('train'), BATCH_SIZE)
#dataset_train = FixedSizeData(dataset_train, 20)
dataset_test = BatchData(Mnist('test'), 256, remainder=True)
dataset_train = FixedSizeData(dataset_train, 20)
dataset_test = FixedSizeData(dataset_test, 20)
sess_config = tf.ConfigProto()
sess_config.device_count['GPU'] = 1
......@@ -98,14 +100,15 @@ def get_config():
sess_config.allow_soft_placement = True
# prepare model
image_var = tf.placeholder(
tf.float32, shape=(None, IMAGE_SIZE, IMAGE_SIZE), name='input')
label_var = tf.placeholder(
input_vars = [
tf.placeholder(
tf.float32, shape=(None, IMAGE_SIZE, IMAGE_SIZE), name='input'),
tf.placeholder(
tf.int32, shape=(None,), name='label')
input_vars = [image_var, label_var]
input_queue = tf.RandomShuffleQueue(100, 50, ['float32', 'int32'], name='queue')
]
input_queue = tf.RandomShuffleQueue(
100, 50, [x.dtype for x in input_vars], name='queue')
add_histogram_summary('.*/W') # monitor histogram of all W
global_step_var = tf.get_default_graph().get_tensor_by_name(GLOBAL_STEP_VAR_NAME)
lr = tf.train.exponential_decay(
learning_rate=1e-4,
......
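For reference, a short sketch of how this queue is wired up by the training code further down (consistent with the hunks below; no new API): the dtypes come straight from the placeholders, and because dequeue() returns tensors with unknown static shapes, the placeholder shapes are copied back onto the dequeued tensors before the model is built.

# capacity=100, min_after_dequeue=50; dtypes follow the placeholders
input_queue = tf.RandomShuffleQueue(
    100, 50, [x.dtype for x in input_vars], name='queue')
enqueue_op = input_queue.enqueue(tuple(input_vars))
model_inputs = input_queue.dequeue()
# restore the static shapes lost by dequeue()
for qv, v in zip(model_inputs, input_vars):
    qv.set_shape(v.get_shape())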
......@@ -10,6 +10,9 @@ from utils import logger
__all__ = ['regularize_cost']
def regularize_cost(regex, func):
"""
Apply a regularizer to every trainable variable matching the regex
"""
G = tf.get_default_graph()
params = G.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
......@@ -17,7 +20,7 @@ def regularize_cost(regex, func):
for p in params:
name = p.name
if re.search(regex, name):
logger.info("Weight decay for {}".format(name))
logger.info("Apply regularizer for {}".format(name))
costs.append(func(p))
return tf.add_n(costs)
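A hypothetical usage of regularize_cost, applying L2 weight decay to every trainable variable whose name matches the regex (the regex and the 1e-4 factor are illustrative, not part of this commit):

wd_cost = regularize_cost('.*/W', lambda p: 1e-4 * tf.nn.l2_loss(p))
tf.add_to_collection(COST_VARS_KEY, wd_cost)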
......@@ -13,9 +13,6 @@ from itertools import count
import argparse
def prepare():
is_training = tf.constant(True, name=IS_TRAINING_OP_NAME)
#keep_prob = tf.placeholder(
#tf.float32, shape=tuple(), name=DROPOUT_PROB_OP_NAME)
global_step_var = tf.Variable(
0, trainable=False, name=GLOBAL_STEP_OP_NAME)
......@@ -40,7 +37,7 @@ def start_train(config):
sess_config = config.get('session_config', None)
assert isinstance(sess_config, tf.ConfigProto), sess_config.__class__
# a list of input/output variables
# input/output variables
input_vars = config['inputs']
input_queue = config['input_queue']
get_model_func = config['get_model_func']
......@@ -49,9 +46,10 @@ def start_train(config):
enqueue_op = input_queue.enqueue(tuple(input_vars))
model_inputs = input_queue.dequeue()
# set dequeue shape
for qv, v in zip(model_inputs, input_vars):
qv.set_shape(v.get_shape())
output_vars, cost_var = get_model_func(model_inputs)
output_vars, cost_var = get_model_func(model_inputs, is_training=True)
# build graph
G = tf.get_default_graph()
......@@ -84,19 +82,19 @@ def start_train(config):
# a thread that keeps filling the queue
th = EnqueueThread(sess, coord, enqueue_op, dataset_train)
with sess.as_default(), \
coordinator_context(
coordinator_guard(
sess, coord, th, input_queue):
callbacks.before_train()
for epoch in xrange(1, max_epoch):
with timed_operation('epoch {}'.format(epoch)):
for step in xrange(dataset_train.size()):
# TODO eval dequeue to get dp
fetches = [train_op, cost_var] + output_vars
feed = {IS_TRAINING_VAR_NAME: True}
results = sess.run(fetches, feed_dict=feed)
fetches = [train_op, cost_var] + output_vars + model_inputs
results = sess.run(fetches)
cost = results[1]
outputs = results[2:]
# TODO trigger_step
outputs = results[2:2 + len(output_vars)]
inputs = results[-len(model_inputs):]
callbacks.trigger_step(inputs, outputs, cost)
# note that summary_op will take a datapoint from the queue.
callbacks.trigger_epoch()
sess.close()
......
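For clarity, the fetch list and the result slicing above line up as follows (a worked example assuming, say, two output variables and two model inputs):

# fetches = [train_op, cost_var] + output_vars + model_inputs
# results[0]                       -> train_op (no useful value)
# results[1]                       -> cost
# results[2:2 + len(output_vars)]  -> output values, e.g. results[2:4]
# results[-len(model_inputs):]     -> dequeued input values, e.g. results[-2:]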
......@@ -57,9 +57,7 @@ def create_test_graph():
))
for v in input_vars:
Gtest.add_to_collection(INPUT_VARS_KEY, v)
is_training = tf.constant(False, name=IS_TRAINING_OP_NAME)
output_vars, cost = forward_func(input_vars)
output_vars, cost = forward_func(input_vars, is_training=False)
for v in output_vars:
Gtest.add_to_collection(OUTPUT_VARS_KEY, v)
yield Gtest
......
......@@ -34,9 +34,9 @@ class Callback(object):
"""
Callback to be triggered after every step (every backpropagation)
Args:
inputs: the input dict fed into the graph
outputs: list of output values after running this dp
cost: the cost value after running this dp
inputs: the list of input values
outputs: list of output values after running these inputs
cost: the cost value after running these inputs
"""
def trigger_epoch(self):
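As a sketch, a hypothetical callback consuming the new trigger_step signature could look like this (class name and body are illustrative):

class StepStatPrinter(Callback):
    def trigger_step(self, inputs, outputs, cost):
        # inputs: list of input values dequeued for this step
        # outputs: list of output values; cost: scalar cost of the step
        logger.info("cost after this step: {}".format(cost))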
......@@ -85,9 +85,7 @@ class SummaryWriter(Callback):
# check if there is any summary
if self.summary_op is None:
return
feed = {IS_TRAINING_VAR_NAME: True}
summary_str = self.summary_op.eval(feed_dict=feed)
summary_str = self.summary_op.eval()
self.epoch_num += 1
self.writer.add_summary(summary_str, self.epoch_num)
......@@ -102,9 +100,7 @@ class CallbackTimeLogger(object):
self.times.append((name, time))
def log(self):
"""
log the time of some heavy callbacks
"""
""" log the time of some heavy callbacks """
if self.tot < 3:
return
msgs = []
......@@ -162,12 +158,15 @@ class TestCallbacks(Callback):
def trigger_epoch(self):
tm = CallbackTimeLogger()
with self.graph.as_default():
with self.sess.as_default():
with self.graph.as_default(), self.sess.as_default():
s = time.time()
ckpt = tf.train.get_checkpoint_state(logger.LOG_DIR)
if ckpt is None:
from IPython import embed; embed()
logger.error(
"Cannot find a checkpoint state. Do you forget to use PeriodicSaver?")
return
logger.info(
"Restore checkpoint from {}".format(ckpt.model_checkpoint_path))
self.saver.restore(self.sess, ckpt.model_checkpoint_path)
tm.add('restore session', time.time() - s)
for cb in self.cbs:
......@@ -198,7 +197,7 @@ class Callbacks(Callback):
self.test.before_train()
def trigger_step(self, inputs, outputs, cost):
self.train.trigger_step()
self.train.trigger_step(inputs, outputs, cost)
# test callbacks don't have trigger_step
def trigger_epoch(self):
......
......@@ -45,9 +45,11 @@ class EnqueueThread(threading.Thread):
logger.exception("Exception in EnqueueThread:")
@contextmanager
def coordinator_context(sess, coord, thread, queue):
def coordinator_guard(sess, coord, thread, queue):
"""
Context manager to make sure queue is closed and thread is joined
Context manager to make sure that:
queue is closed
thread is joined
"""
thread.start()
try:
......
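A minimal sketch of what such a guard could do, assuming the standard tf.train.Coordinator and queue APIs (illustrative; not necessarily the exact body elided above):

from contextlib import contextmanager

@contextmanager
def coordinator_guard(sess, coord, thread, queue):
    """ make sure the queue is closed and the thread is joined """
    thread.start()
    try:
        yield
    finally:
        coord.request_stop()
        sess.run(queue.close(cancel_pending_enqueues=True))
        coord.join([thread])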
......@@ -19,7 +19,6 @@ def create_summary(name, v):
return s
def add_activation_summary(x, name=None):
# TODO dedup
"""
Summary for an activation tensor x.
If name is None, use x.name
......
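The body of add_activation_summary is not shown in this hunk; purely as an illustration, a typical implementation under the pre-1.0 TF summary API might look roughly like this (an assumption, not the committed code):

def add_activation_summary(x, name=None):
    """ Summary for an activation tensor x. If name is None, use x.name """
    if name is None:
        name = x.name
    tf.histogram_summary(name + '/activations', x)
    tf.scalar_summary(name + '/sparsity', tf.nn.zero_fraction(x))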