Commit 585f0837 authored by ppwwyyxx

is_training as a bool

parent dac78238
@@ -24,22 +24,21 @@ from utils.concurrency import *
 from dataflow.dataset import Mnist
 from dataflow import *

-def get_model(inputs):
-    # TODO is_training as a python variable
+def get_model(inputs, is_training):
     """
     Args:
         inputs: a list of input variable,
            e.g.: [image_var, label_var] with:
                image_var: bx28x28
                label_var: bx1 integer
+        is_training: a python bool variable
     Returns:
         (outputs, cost)
         outputs: a list of output variable
-        cost: scalar variable
+        cost: the cost to minimize. scalar variable
     """
-    is_training = tf.get_default_graph().get_tensor_by_name(IS_TRAINING_VAR_NAME)
-    keep_prob = control_flow_ops.cond(
-        is_training, lambda: tf.constant(0.5), lambda: tf.constant(1.0), name='dropout_prob')
+    is_training = bool(is_training)
+    keep_prob = tf.constant(0.5 if is_training else 1.0)

     image, label = inputs
     image = tf.expand_dims(image, 3)    # add a single channel
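This hunk turns is_training into a construction-time Python bool instead of a graph tensor, so the dropout probability is baked into each graph rather than switched through tf.cond at run time. A minimal sketch of the pattern (the helper name dropout_keep_prob is hypothetical; only TF calls already visible in the hunk are used):

import tensorflow as tf

def dropout_keep_prob(is_training):
    # A plain Python bool: the value is fixed when the graph is built,
    # so no feed_dict entry or cond branch is needed at run time.
    is_training = bool(is_training)
    return tf.constant(0.5 if is_training else 1.0)

train_keep_prob = dropout_keep_prob(True)    # constant 0.5 in the training graph
test_keep_prob = dropout_keep_prob(False)    # constant 1.0 in the test graph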
@@ -77,19 +76,22 @@ def get_model(inputs):
                           name='regularize_loss')
     tf.add_to_collection(COST_VARS_KEY, wd_cost)

+    add_histogram_summary('.*/W')   # monitor histogram of all W
     # this won't work with multigpu
     #return [prob, nr_wrong], tf.add_n(tf.get_collection(COST_VARS_KEY), name='cost')
     return [prob, nr_wrong], tf.add_n([wd_cost, cost], name='cost')

 def get_config():
+    IMAGE_SIZE = 28
     log_dir = os.path.join('train_log', os.path.basename(__file__)[:-3])
     logger.set_logger_dir(log_dir)
-    IMAGE_SIZE = 28
     BATCH_SIZE = 128

     dataset_train = BatchData(Mnist('train'), BATCH_SIZE)
-    #dataset_train = FixedSizeData(dataset_train, 20)
     dataset_test = BatchData(Mnist('test'), 256, remainder=True)
+    dataset_train = FixedSizeData(dataset_train, 20)
+    dataset_test = FixedSizeData(dataset_test, 20)

     sess_config = tf.ConfigProto()
     sess_config.device_count['GPU'] = 1
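The two FixedSizeData wrappers cap both datasets at 20 batches per epoch; the previously commented-out line suggests this is a fast-iteration or debugging setting. The wrapper itself is not part of this diff; a hedged sketch of what such a dataflow wrapper plausibly does, assuming the get_data()/size() dataflow interface used elsewhere in the repo:

class FixedSizeDataSketch:
    """Hypothetical stand-in for dataflow's FixedSizeData: yield at most
    `size` datapoints from the wrapped dataflow per epoch."""
    def __init__(self, ds, size):
        self.ds = ds
        self._size = size

    def size(self):
        return self._size

    def get_data(self):
        for cnt, dp in enumerate(self.ds.get_data()):
            if cnt >= self._size:
                return
            yield dp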
@@ -98,14 +100,15 @@ def get_config():
     sess_config.allow_soft_placement = True

     # prepare model
-    image_var = tf.placeholder(
-        tf.float32, shape=(None, IMAGE_SIZE, IMAGE_SIZE), name='input')
-    label_var = tf.placeholder(
-        tf.int32, shape=(None,), name='label')
-    input_vars = [image_var, label_var]
-    input_queue = tf.RandomShuffleQueue(100, 50, ['float32', 'int32'], name='queue')
-    add_histogram_summary('.*/W') # monitor histogram of all W
+    input_vars = [
+        tf.placeholder(
+            tf.float32, shape=(None, IMAGE_SIZE, IMAGE_SIZE), name='input'),
+        tf.placeholder(
+            tf.int32, shape=(None,), name='label')
+    ]
+    input_queue = tf.RandomShuffleQueue(
+        100, 50, [x.dtype for x in input_vars], name='queue')

     global_step_var = tf.get_default_graph().get_tensor_by_name(GLOBAL_STEP_VAR_NAME)
     lr = tf.train.exponential_decay(
         learning_rate=1e-4,
...
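Deriving the queue dtypes from the placeholders with [x.dtype for x in input_vars] keeps the queue definition from drifting out of sync with the inputs. A standalone sketch of the construction, mirroring the hunk above (the positional arguments are capacity and min_after_dequeue in the tf.RandomShuffleQueue signature of this era):

import tensorflow as tf

IMAGE_SIZE = 28
input_vars = [
    tf.placeholder(tf.float32, shape=(None, IMAGE_SIZE, IMAGE_SIZE), name='input'),
    tf.placeholder(tf.int32, shape=(None,), name='label'),
]
# capacity=100, min_after_dequeue=50; dtypes now follow the placeholders
input_queue = tf.RandomShuffleQueue(
    100, 50, [x.dtype for x in input_vars], name='queue')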
@@ -10,6 +10,9 @@ from utils import logger
 __all__ = ['regularize_cost']

 def regularize_cost(regex, func):
+    """
+    Apply a regularizer on every trainable variable matching the regex
+    """
     G = tf.get_default_graph()
     params = G.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
@@ -17,7 +20,7 @@ def regularize_cost(regex, func):
     for p in params:
         name = p.name
         if re.search(regex, name):
-            logger.info("Weight decay for {}".format(name))
+            logger.info("Apply regularizer for {}".format(name))
             costs.append(func(p))
     return tf.add_n(costs)
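A hypothetical usage of regularize_cost as now documented, pairing the '.*/W' pattern seen in this commit with L2 loss (assumes regularize_cost and COST_VARS_KEY are imported as in the model file above):

import tensorflow as tf

# L2 weight decay on every trainable variable named like 'conv1/W'
wd_cost = regularize_cost('.*/W', tf.nn.l2_loss)
tf.add_to_collection(COST_VARS_KEY, wd_cost)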
@@ -13,9 +13,6 @@ from itertools import count
 import argparse

 def prepare():
-    is_training = tf.constant(True, name=IS_TRAINING_OP_NAME)
-    #keep_prob = tf.placeholder(
-        #tf.float32, shape=tuple(), name=DROPOUT_PROB_OP_NAME)
     global_step_var = tf.Variable(
         0, trainable=False, name=GLOBAL_STEP_OP_NAME)
...@@ -40,7 +37,7 @@ def start_train(config): ...@@ -40,7 +37,7 @@ def start_train(config):
sess_config = config.get('session_config', None) sess_config = config.get('session_config', None)
assert isinstance(sess_config, tf.ConfigProto), sess_config.__class__ assert isinstance(sess_config, tf.ConfigProto), sess_config.__class__
# a list of input/output variables # input/output variables
input_vars = config['inputs'] input_vars = config['inputs']
input_queue = config['input_queue'] input_queue = config['input_queue']
get_model_func = config['get_model_func'] get_model_func = config['get_model_func']
@@ -49,9 +46,10 @@ def start_train(config):
     enqueue_op = input_queue.enqueue(tuple(input_vars))
     model_inputs = input_queue.dequeue()
+    # set dequeue shape
     for qv, v in zip(model_inputs, input_vars):
         qv.set_shape(v.get_shape())
-    output_vars, cost_var = get_model_func(model_inputs)
+    output_vars, cost_var = get_model_func(model_inputs, is_training=True)

     # build graph
     G = tf.get_default_graph()
@@ -84,19 +82,19 @@ def start_train(config):
     # a thread that keeps filling the queue
     th = EnqueueThread(sess, coord, enqueue_op, dataset_train)
     with sess.as_default(), \
-            coordinator_context(
+            coordinator_guard(
                 sess, coord, th, input_queue):
         callbacks.before_train()
         for epoch in xrange(1, max_epoch):
             with timed_operation('epoch {}'.format(epoch)):
                 for step in xrange(dataset_train.size()):
-                    # TODO eval dequeue to get dp
-                    fetches = [train_op, cost_var] + output_vars
-                    feed = {IS_TRAINING_VAR_NAME: True}
-                    results = sess.run(fetches, feed_dict=feed)
+                    fetches = [train_op, cost_var] + output_vars + model_inputs
+                    results = sess.run(fetches)
                     cost = results[1]
-                    outputs = results[2:]
-                    # TODO trigger_step
+                    outputs = results[2:2 + len(output_vars)]
+                    inputs = results[-len(model_inputs):]
+                    callbacks.trigger_step(inputs, outputs, cost)
                     # note that summary_op will take a data from the queue.
         callbacks.trigger_epoch()
     sess.close()
...
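The training loop now fetches model_inputs alongside the outputs and slices the run results by length. A plain-Python check of that indexing, with hypothetical stand-in values:

# fetches = [train_op, cost_var] + output_vars + model_inputs
results = ['train_op_out', 'cost', 'out0', 'out1', 'in0', 'in1']
n_outputs, n_inputs = 2, 2

cost = results[1]
outputs = results[2:2 + n_outputs]   # -> ['out0', 'out1']
inputs = results[-n_inputs:]         # -> ['in0', 'in1']
assert outputs == ['out0', 'out1'] and inputs == ['in0', 'in1']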
@@ -57,9 +57,7 @@ def create_test_graph():
         ))
         for v in input_vars:
             Gtest.add_to_collection(INPUT_VARS_KEY, v)
-        is_training = tf.constant(False, name=IS_TRAINING_OP_NAME)
-        output_vars, cost = forward_func(input_vars)
+        output_vars, cost = forward_func(input_vars, is_training=False)
         for v in output_vars:
             Gtest.add_to_collection(OUTPUT_VARS_KEY, v)
         yield Gtest
...
@@ -34,9 +34,9 @@ class Callback(object):
         """
         Callback to be triggered after every step (every backpropagation)
         Args:
-            inputs: the input dict fed into the graph
-            outputs: list of output values after running this dp
-            cost: the cost value after running this dp
+            inputs: the list of input values
+            outputs: list of output values after running these inputs
+            cost: the cost value after running these inputs
         """

     def trigger_epoch(self):
@@ -85,9 +85,7 @@ class SummaryWriter(Callback):
         # check if there is any summary
         if self.summary_op is None:
             return
-        feed = {IS_TRAINING_VAR_NAME: True}
-        summary_str = self.summary_op.eval(feed_dict=feed)
+        summary_str = self.summary_op.eval()
         self.epoch_num += 1
         self.writer.add_summary(summary_str, self.epoch_num)
@@ -102,9 +100,7 @@ class CallbackTimeLogger(object):
         self.times.append((name, time))

     def log(self):
-        """
-        log the time of some heavy callbacks
-        """
+        """ log the time of some heavy callbacks """
         if self.tot < 3:
             return
         msgs = []
@@ -162,18 +158,21 @@ class TestCallbacks(Callback):
     def trigger_epoch(self):
         tm = CallbackTimeLogger()
-        with self.graph.as_default():
-            with self.sess.as_default():
-                s = time.time()
-                ckpt = tf.train.get_checkpoint_state(logger.LOG_DIR)
-                if ckpt is None:
-                    from IPython import embed; embed()
-                self.saver.restore(self.sess, ckpt.model_checkpoint_path)
-                tm.add('restore session', time.time() - s)
-                for cb in self.cbs:
-                    s = time.time()
-                    cb.trigger_epoch()
-                    tm.add(type(cb).__name__, time.time() - s)
+        with self.graph.as_default(), self.sess.as_default():
+            s = time.time()
+            ckpt = tf.train.get_checkpoint_state(logger.LOG_DIR)
+            if ckpt is None:
+                logger.error(
+                    "Cannot find a checkpoint state. Did you forget to use PeriodicSaver?")
+                return
+            logger.info(
+                "Restore checkpoint from {}".format(ckpt.model_checkpoint_path))
+            self.saver.restore(self.sess, ckpt.model_checkpoint_path)
+            tm.add('restore session', time.time() - s)
+            for cb in self.cbs:
+                s = time.time()
+                cb.trigger_epoch()
+                tm.add(type(cb).__name__, time.time() - s)
             self.writer.flush()
             tm.log()
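The rewritten trigger_epoch fails soft when no checkpoint exists yet, instead of dropping into an IPython shell. A standalone sketch of the same guard pattern (restore_latest and log_dir are hypothetical names; tf.train.get_checkpoint_state and Saver.restore are the real APIs used above):

import tensorflow as tf

def restore_latest(sess, saver, log_dir):
    # Return False instead of raising when nothing has been saved yet.
    ckpt = tf.train.get_checkpoint_state(log_dir)
    if ckpt is None:
        return False
    saver.restore(sess, ckpt.model_checkpoint_path)
    return True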
@@ -198,7 +197,7 @@ class Callbacks(Callback):
         self.test.before_train()

     def trigger_step(self, inputs, outputs, cost):
-        self.train.trigger_step()
+        self.train.trigger_step(inputs, outputs, cost)
         # test callback don't have trigger_step

     def trigger_epoch(self):
...
@@ -45,9 +45,11 @@ class EnqueueThread(threading.Thread):
             logger.exception("Exception in EnqueueThread:")

 @contextmanager
-def coordinator_context(sess, coord, thread, queue):
+def coordinator_guard(sess, coord, thread, queue):
     """
-    Context manager to make sure queue is closed and thread is joined
+    Context manager to make sure that:
+        queue is closed
+        thread is joined
     """
     thread.start()
     try:
...
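The renamed coordinator_guard's body is truncated in this diff; a hedged sketch of what such a guard plausibly does, using only standard tf.train.Coordinator and queue APIs:

from contextlib import contextmanager

@contextmanager
def coordinator_guard_sketch(sess, coord, thread, queue):
    # Start the enqueue thread; on exit, close the queue (cancelling
    # pending enqueues), stop the coordinator, and join the thread.
    thread.start()
    try:
        yield
    finally:
        sess.run(queue.close(cancel_pending_enqueues=True))
        coord.request_stop()
        coord.join([thread])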
@@ -19,7 +19,6 @@ def create_summary(name, v):
     return s

 def add_activation_summary(x, name=None):
-    # TODO dedup
     """
     Summary for an activation tensor x.
     If name is None, use x.name
...