Commit d8330092 authored by Yuxin Wu

refine multigpu code

parent 087dc382
@@ -55,7 +55,7 @@ def get_model(inputs, is_training):
y = one_hot(label, 1000)
cost = tf.nn.softmax_cross_entropy_with_logits(logits, y)
cost = tf.reduce_mean(cost, name='cross_entropy_loss')
tf.add_to_collection(COST_VARS_KEY, cost)
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost)
# compute the number of failed samples, for ValidationError to use at test time
wrong = tf.not_equal(
@@ -64,13 +64,13 @@ def get_model(inputs, is_training):
nr_wrong = tf.reduce_sum(wrong, name='wrong')
# monitor training error
tf.add_to_collection(
SUMMARY_VARS_KEY, tf.reduce_mean(wrong, name='train_error'))
MOVING_SUMMARY_VARS_KEY, tf.reduce_mean(wrong, name='train_error'))
# weight decay on all W of fc layers
wd_cost = tf.mul(1e-4,
regularize_cost('fc.*/W', tf.nn.l2_loss),
name='regularize_loss')
tf.add_to_collection(COST_VARS_KEY, wd_cost)
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, wd_cost)
add_histogram_summary('.*/W') # monitor histogram of all W
return [prob, nr_wrong], tf.add_n([wd_cost, cost], name='cost')
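The hunk above assembles the total training cost as the batch-averaged cross-entropy plus an L2 weight-decay term over the fully-connected W matrices, combined with tf.add_n. A minimal NumPy sketch of that arithmetic (the weights and the cross-entropy value are made up; only the definition of tf.nn.l2_loss, sum(t ** 2) / 2, is taken from TensorFlow):

import numpy as np

def l2_loss(w):
    # same definition as tf.nn.l2_loss
    return 0.5 * np.sum(w ** 2)

cross_entropy_loss = 2.31                                  # batch-averaged cross-entropy (made up)
fc_weights = [np.full((4, 3), 0.1), np.full((3, 2), 0.1)]  # stand-ins for fc.*/W
wd_cost = 1e-4 * sum(l2_loss(w) for w in fc_weights)       # 'regularize_loss'
cost = cross_entropy_loss + wd_cost                        # mirrors tf.add_n([wd_cost, cost])
print(cost)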
@@ -56,7 +56,7 @@ def get_model(inputs, is_training):
y = one_hot(label, 10)
cost = tf.nn.softmax_cross_entropy_with_logits(logits, y)
cost = tf.reduce_mean(cost, name='cross_entropy_loss')
tf.add_to_collection(SUMMARY_VARS_KEY, cost)
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost)
# compute the number of failed samples, for ValidationError to use at test time
wrong = tf.not_equal(
@@ -65,13 +65,13 @@ def get_model(inputs, is_training):
nr_wrong = tf.reduce_sum(wrong, name='wrong')
# monitor training error
tf.add_to_collection(
SUMMARY_VARS_KEY, tf.reduce_mean(wrong, name='train_error'))
MOVING_SUMMARY_VARS_KEY, tf.reduce_mean(wrong, name='train_error'))
# weight decay on all W of fc layers
wd_cost = tf.mul(1e-4,
regularize_cost('fc.*/W', tf.nn.l2_loss),
name='regularize_loss')
tf.add_to_collection(SUMMARY_VARS_KEY, wd_cost)
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, wd_cost)
add_histogram_summary('.*/W') # monitor histogram of all W
return [prob, nr_wrong], tf.add_n([wd_cost, cost], name='cost')
@@ -154,5 +154,6 @@ if __name__ == '__main__':
config = get_config()
if args.load:
config.session_init = SaverRestore(args.load)
if args.gpu:
config['nr_tower'] = len(args.gpu.split(','))
start_train(config)
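For context, a small hypothetical sketch of how the --gpu flag above turns into the tower count; the script's actual argparse setup is outside this hunk, so the parser below is illustrative only:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help="comma-separated GPUs to use, e.g. '0,1'")
args = parser.parse_args(['--gpu', '0,1'])

# one training tower per listed GPU, as in the hunk above
nr_tower = len(args.gpu.split(',')) if args.gpu else 1
print(nr_tower)   # 2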
@@ -65,7 +65,7 @@ def get_model(inputs, is_training):
y = one_hot(label, 10)
cost = tf.nn.softmax_cross_entropy_with_logits(logits, y)
cost = tf.reduce_mean(cost, name='cross_entropy_loss')
tf.add_to_collection(COST_VARS_KEY, cost)
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost)
# compute the number of failed samples, for ValidationError to use at test time
wrong = tf.not_equal(
@@ -74,13 +74,13 @@ def get_model(inputs, is_training):
nr_wrong = tf.reduce_sum(wrong, name='wrong')
# monitor training error
tf.add_to_collection(
SUMMARY_VARS_KEY, tf.reduce_mean(wrong, name='train_error'))
MOVING_SUMMARY_VARS_KEY, tf.reduce_mean(wrong, name='train_error'))
# weight decay on all W of fc layers
wd_cost = tf.mul(1e-4,
regularize_cost('fc.*/W', tf.nn.l2_loss),
name='regularize_loss')
tf.add_to_collection(COST_VARS_KEY, wd_cost)
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, wd_cost)
add_histogram_summary('.*/W') # monitor histogram of all W
return [prob, nr_wrong], tf.add_n([wd_cost, cost], name='cost')
@@ -95,7 +95,7 @@ def get_config():
dataset_train = BatchData(dataset.Mnist('train'), 128)
dataset_test = BatchData(dataset.Mnist('test'), 256, remainder=True)
step_per_epoch = dataset_train.size()
step_per_epoch = 30
#step_per_epoch = 30
#dataset_test = FixedSizeData(dataset_test, 20)
sess_config = get_default_sess_config()
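For reference, with MNIST's 60000 training images and the batch size of 128 above, the default step_per_epoch = dataset_train.size() works out to 468 updates per epoch, assuming BatchData drops the last partial batch by default (the test set asks for remainder=True explicitly); the now-commented-out override capped it at 30 for quick debugging runs:

print(60000 // 128)   # 468 steps per epoch at batch size 128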
@@ -46,6 +46,7 @@ class TrainConfig(object):
step_per_epoch: the number of steps (parameter updates) to perform
in each epoch. default to dataset.size()
max_epoch: maximum number of epoch to run training. default to 100
nr_tower: int. number of towers. default to 1.
"""
def assert_type(v, tp):
assert isinstance(v, tp), v.__class__
@@ -70,21 +71,10 @@ class TrainConfig(object):
self.step_per_epoch = int(kwargs.pop('step_per_epoch', self.dataset.size()))
self.max_epoch = int(kwargs.pop('max_epoch', 100))
assert self.step_per_epoch > 0 and self.max_epoch > 0
self.nr_tower = int(kwargs.pop('nr_tower', 1))
assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))
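A minimal, self-contained sketch of the kwargs-popping pattern TrainConfig uses above: each optional argument gets a default via dict.pop, and anything left over in kwargs is rejected. Only nr_tower and max_epoch are copied from the hunk; the class name is hypothetical:

class ConfigSketch(object):
    def __init__(self, **kwargs):
        self.max_epoch = int(kwargs.pop('max_epoch', 100))
        self.nr_tower = int(kwargs.pop('nr_tower', 1))   # new in this commit, defaults to 1
        assert self.max_epoch > 0 and self.nr_tower > 0
        assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))

c = ConfigSketch(nr_tower=2)     # train with 2 towers
# ConfigSketch(nr_towers=2)      # would raise: Unknown arguments: ['nr_towers']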
def average_gradients(tower_grads):
"""Calculate the average gradient for each shared variable across all towers.
Note that this function provides a synchronization point across all towers.
Args:
tower_grads: List of lists of (gradient, variable) tuples. The outer list
is over individual gradients. The inner list is over the gradient
calculation for each tower.
Returns:
List of pairs of (gradient, variable) where the gradient has been averaged
across all towers.
"""
average_grads = []
for grad_and_vars in zip(*tower_grads):
# Note that each grad_and_vars looks like the following:
@@ -109,6 +99,10 @@ def average_gradients(tower_grads):
average_grads.append(grad_and_var)
return average_grads
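A plain-NumPy sketch (not the TF code above) of what average_gradients computes: zip(*tower_grads) regroups the per-tower lists so that each grad_and_vars holds one variable's gradient from every tower, and the averaged gradient is paired with that shared variable taken from the first tower:

import numpy as np

def average_gradients_sketch(tower_grads):
    average_grads = []
    for grad_and_vars in zip(*tower_grads):
        avg = np.stack([g for g, _ in grad_and_vars]).mean(axis=0)
        var = grad_and_vars[0][1]          # variables are shared, take tower 0's
        average_grads.append((avg, var))
    return average_grads

# two towers, two shared variables 'W' and 'b'
tower0 = [(np.array([1.0, 3.0]), 'W'), (np.array([0.0]), 'b')]
tower1 = [(np.array([3.0, 5.0]), 'W'), (np.array([2.0]), 'b')]
print(average_gradients_sketch([tower0, tower1]))
# [(array([2., 4.]), 'W'), (array([1.]), 'b')]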
def summary_grads(grads):
for grad, var in grads:
if grad:
tf.histogram_summary(var.op.name + '/gradients', grad)
def start_train(config):
"""
@@ -120,6 +114,10 @@ def start_train(config):
input_queue = config.input_queue
callbacks = config.callbacks
tf.add_to_collection(FORWARD_FUNC_KEY, config.get_model_func)
for v in input_vars:
tf.add_to_collection(INPUT_VARS_KEY, v)
def get_model_inputs():
model_inputs = input_queue.dequeue()
for qv, v in zip(model_inputs, input_vars):
@@ -134,40 +132,38 @@ def start_train(config):
else:
enqueue_op = input_queue.enqueue_many(input_vars)
keys_to_maintain = [tf.GraphKeys.SUMMARIES, SUMMARY_VARS_KEY]
olds = {}
for k in keys_to_maintain:
olds[k] = copy.copy(tf.get_collection(k))
all_grads = []
n_tower = 1
for i in range(n_tower):
# get gradients to update:
logger.info("Training a model of {} tower".format(config.nr_tower))
if config.nr_tower > 1:
coll_keys = [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_VARS_KEY]
kept_summaries = {}
grads = []
for i in range(config.nr_tower):
with tf.device('/gpu:{}'.format(i)):
with tf.name_scope('tower{}'.format(i)):
for k in keys_to_maintain:
del tf.get_collection(k)[:]
model_inputs = get_model_inputs()
output_vars, cost_var = config.get_model_func(model_inputs, is_training=True)
grads.append(
config.optimizer.compute_gradients(cost_var))
if i == 0:
tf.get_variable_scope().reuse_variables()
for k in coll_keys:
kept_summaries[k] = copy.copy(tf.get_collection(k))
for k in coll_keys: # avoid repeating summary on multiple devices
del tf.get_collection(k)[:]
tf.get_collection(k).extend(kept_summaries[k])
grads = average_gradients(grads)
else:
model_inputs = get_model_inputs()
output_vars, cost_var = config.get_model_func(model_inputs, is_training=True)
grads = config.optimizer.compute_gradients(cost_var)
all_grads.append(grads)
for k in keys_to_maintain:
tf.get_collection(k).extend(olds[k])
grads = average_gradients(all_grads)
for grad, var in grads:
if grad:
tf.histogram_summary(var.op.name + '/gradients', grad)
summary_grads(grads)
avg_maintain_op = summary_moving_average(cost_var)
# build graph
tf.add_to_collection(FORWARD_FUNC_KEY, config.get_model_func)
for v in input_vars:
tf.add_to_collection(INPUT_VARS_KEY, v)
describe_model()
# train_op = get_train_op(config.optimizer, cost_var)
with tf.control_dependencies([avg_maintain_op]):
train_op = config.optimizer.apply_gradients(grads, get_global_step_var())
describe_model()
sess = tf.Session(config=config.session_config)
config.session_init.init(sess)
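A minimal plain-Python sketch of the summary-deduplication trick in the tower loop above: every tower re-registers the same summaries under its own name scope, so the collections are snapshotted after tower 0 and restored once the remaining towers are built. Plain lists stand in for tf.get_collection, and the entries are made up:

import copy

collections = {'SUMMARIES': [], 'MOVING_SUMMARY_VARIABLES': []}

def build_tower(i):
    collections['SUMMARIES'].append('tower{}/train_error'.format(i))

kept_summaries = {}
nr_tower = 3
for i in range(nr_tower):
    build_tower(i)
    if i == 0:
        for k in collections:
            kept_summaries[k] = copy.copy(collections[k])
for k in collections:                 # avoid repeating summaries on every device
    del collections[k][:]
    collections[k].extend(kept_summaries[k])

print(collections['SUMMARIES'])       # ['tower0/train_error']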
@@ -9,8 +9,7 @@ GLOBAL_STEP_VAR_NAME = 'global_step:0'
SUMMARY_WRITER_COLLECTION_KEY = 'summary_writer'
INPUT_VARS_KEY = 'INPUT_VARIABLES'
COST_VARS_KEY = 'COST_VARIABLES' # keep track of each individual cost
SUMMARY_VARS_KEY = 'SUMMARY_VARIABLES' # extra variables to summarize during training
MOVING_SUMMARY_VARS_KEY = 'MOVING_SUMMARY_VARIABLES' # extra variables to summarize during training
FORWARD_FUNC_KEY = 'FORWARD_FUNCTION'
# export all upper case variables
@@ -31,7 +31,6 @@ def add_activation_summary(x, name=None):
name = x.name
tf.histogram_summary(name + '/activations', x)
tf.scalar_summary(name + '/sparsity', tf.nn.zero_fraction(x))
# TODO avoid repeating activations on multiple GPUs
def add_histogram_summary(regex):
"""
@@ -46,15 +45,14 @@ def add_histogram_summary(regex):
def summary_moving_average(cost_var):
""" Create a MovingAverage op and summary for all variables in
COST_VARS_KEY, SUMMARY_VARS_KEY, as well as the argument
MOVING_SUMMARY_VARS_KEY, as well as the argument
Return an op to maintain these averages
"""
global_step_var = tf.get_default_graph().get_tensor_by_name(GLOBAL_STEP_VAR_NAME)
averager = tf.train.ExponentialMovingAverage(
0.99, num_updates=global_step_var, name='moving_averages')
vars_to_summary = [cost_var] + \
tf.get_collection(SUMMARY_VARS_KEY) + \
tf.get_collection(COST_VARS_KEY)
tf.get_collection(MOVING_SUMMARY_VARS_KEY)
avg_maintain_op = averager.apply(vars_to_summary)
for idx, c in enumerate(vars_to_summary):
name = c.op.name
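The smoothing that summary_moving_average applies to each collected variable follows the usual shadow update of tf.train.ExponentialMovingAverage, shadow = decay * shadow + (1 - decay) * value, with num_updates additionally capping the effective decay at (1 + step) / (10 + step) early in training. A minimal sketch of the update rule alone (decay lowered to 0.5 so the smoothing is visible):

def moving_average(values, decay=0.5):
    smoothed, out = None, []
    for v in values:
        smoothed = v if smoothed is None else decay * smoothed + (1 - decay) * v
        out.append(smoothed)
    return out

print(moving_average([1.0, 0.5, 0.8, 0.3]))   # [1.0, 0.75, 0.775, 0.5375]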