Commit d8330092 authored by Yuxin Wu

refine multigpu code

parent 087dc382
@@ -55,7 +55,7 @@ def get_model(inputs, is_training):
     y = one_hot(label, 1000)
     cost = tf.nn.softmax_cross_entropy_with_logits(logits, y)
     cost = tf.reduce_mean(cost, name='cross_entropy_loss')
-    tf.add_to_collection(COST_VARS_KEY, cost)
+    tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost)
     # compute the number of failed samples, for ValidationError to use at test time
     wrong = tf.not_equal(
@@ -64,13 +64,13 @@ def get_model(inputs, is_training):
     nr_wrong = tf.reduce_sum(wrong, name='wrong')
     # monitor training error
     tf.add_to_collection(
-        SUMMARY_VARS_KEY, tf.reduce_mean(wrong, name='train_error'))
+        MOVING_SUMMARY_VARS_KEY, tf.reduce_mean(wrong, name='train_error'))
     # weight decay on all W of fc layers
     wd_cost = tf.mul(1e-4,
                      regularize_cost('fc.*/W', tf.nn.l2_loss),
                      name='regularize_loss')
-    tf.add_to_collection(COST_VARS_KEY, wd_cost)
+    tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, wd_cost)
     add_histogram_summary('.*/W')   # monitor histogram of all W
     return [prob, nr_wrong], tf.add_n([wd_cost, cost], name='cost')
......
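The rename above folds the two old collections (COST_VARS_KEY for individual cost terms, SUMMARY_VARS_KEY for extra scalars to summarize) into the single MOVING_SUMMARY_VARS_KEY, so a model registers everything it wants moving-averaged and summarized in one place. For readers unfamiliar with graph collections, a minimal self-contained sketch of the mechanism (illustrative values, graph-mode TF API; not code from this commit):

import tensorflow as tf

MOVING_SUMMARY_VARS_KEY = 'MOVING_SUMMARY_VARIABLES'   # same string constant defined later in this commit

# a model function registers the scalars it wants tracked:
cost = tf.constant(0.3, name='cross_entropy_loss')
train_error = tf.constant(0.1, name='train_error')
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost)
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, train_error)

# the training code later retrieves everything that was registered:
print(tf.get_collection(MOVING_SUMMARY_VARS_KEY))       # [cost tensor, train_error tensor]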
@@ -56,7 +56,7 @@ def get_model(inputs, is_training):
     y = one_hot(label, 10)
     cost = tf.nn.softmax_cross_entropy_with_logits(logits, y)
     cost = tf.reduce_mean(cost, name='cross_entropy_loss')
-    tf.add_to_collection(SUMMARY_VARS_KEY, cost)
+    tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost)
     # compute the number of failed samples, for ValidationError to use at test time
     wrong = tf.not_equal(
@@ -65,13 +65,13 @@ def get_model(inputs, is_training):
     nr_wrong = tf.reduce_sum(wrong, name='wrong')
     # monitor training error
     tf.add_to_collection(
-        SUMMARY_VARS_KEY, tf.reduce_mean(wrong, name='train_error'))
+        MOVING_SUMMARY_VARS_KEY, tf.reduce_mean(wrong, name='train_error'))
     # weight decay on all W of fc layers
     wd_cost = tf.mul(1e-4,
                      regularize_cost('fc.*/W', tf.nn.l2_loss),
                      name='regularize_loss')
-    tf.add_to_collection(SUMMARY_VARS_KEY, wd_cost)
+    tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, wd_cost)
     add_histogram_summary('.*/W')   # monitor histogram of all W
     return [prob, nr_wrong], tf.add_n([wd_cost, cost], name='cost')
@@ -154,5 +154,6 @@ if __name__ == '__main__':
     config = get_config()
     if args.load:
         config.session_init = SaverRestore(args.load)
+    if args.gpu:
+        config['nr_tower'] = len(args.gpu.split(','))
     start_train(config)
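The example's main block now derives the number of training towers from the --gpu flag: one tower per listed GPU. A small self-contained illustration of that derivation (the argparse setup here is assumed; only the split(',') logic comes from the diff):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help="comma-separated list of GPUs to run on, e.g. '0,1'")
args = parser.parse_args(['--gpu', '0,1'])      # simulate passing --gpu 0,1 on the command line

nr_tower = len(args.gpu.split(',')) if args.gpu else 1
print(nr_tower)    # 2: one tower (model replica) per listed GPU

Note that the hunk writes config['nr_tower'] while the TrainConfig change later in this commit stores the value as a plain nr_tower attribute; whether the config object also supports item-style assignment is not visible in this diff.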
@@ -65,7 +65,7 @@ def get_model(inputs, is_training):
     y = one_hot(label, 10)
     cost = tf.nn.softmax_cross_entropy_with_logits(logits, y)
     cost = tf.reduce_mean(cost, name='cross_entropy_loss')
-    tf.add_to_collection(COST_VARS_KEY, cost)
+    tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost)
     # compute the number of failed samples, for ValidationError to use at test time
     wrong = tf.not_equal(
@@ -74,13 +74,13 @@ def get_model(inputs, is_training):
     nr_wrong = tf.reduce_sum(wrong, name='wrong')
     # monitor training error
     tf.add_to_collection(
-        SUMMARY_VARS_KEY, tf.reduce_mean(wrong, name='train_error'))
+        MOVING_SUMMARY_VARS_KEY, tf.reduce_mean(wrong, name='train_error'))
     # weight decay on all W of fc layers
     wd_cost = tf.mul(1e-4,
                      regularize_cost('fc.*/W', tf.nn.l2_loss),
                      name='regularize_loss')
-    tf.add_to_collection(COST_VARS_KEY, wd_cost)
+    tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, wd_cost)
     add_histogram_summary('.*/W')   # monitor histogram of all W
     return [prob, nr_wrong], tf.add_n([wd_cost, cost], name='cost')
@@ -95,7 +95,7 @@ def get_config():
     dataset_train = BatchData(dataset.Mnist('train'), 128)
     dataset_test = BatchData(dataset.Mnist('test'), 256, remainder=True)
     step_per_epoch = dataset_train.size()
-    step_per_epoch = 30
+    #step_per_epoch = 30
     #dataset_test = FixedSizeData(dataset_test, 20)
     sess_config = get_default_sess_config()
......
@@ -46,6 +46,7 @@ class TrainConfig(object):
             step_per_epoch: the number of steps (parameter updates) to perform
                 in each epoch. default to dataset.size()
             max_epoch: maximum number of epoch to run training. default to 100
+            nr_tower: int. number of towers. default to 1.
         """
         def assert_type(v, tp):
             assert isinstance(v, tp), v.__class__
@@ -70,21 +71,10 @@ class TrainConfig(object):
         self.step_per_epoch = int(kwargs.pop('step_per_epoch', self.dataset.size()))
         self.max_epoch = int(kwargs.pop('max_epoch', 100))
         assert self.step_per_epoch > 0 and self.max_epoch > 0
+        self.nr_tower = int(kwargs.pop('nr_tower', 1))
         assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))

 def average_gradients(tower_grads):
-    """Calculate the average gradient for each shared variable across all towers.
-    Note that this function provides a synchronization point across all towers.
-    Args:
-        tower_grads: List of lists of (gradient, variable) tuples. The outer list
-            is over individual gradients. The inner list is over the gradient
-            calculation for each tower.
-    Returns:
-        List of pairs of (gradient, variable) where the gradient has been averaged
-        across all towers.
-    """
     average_grads = []
     for grad_and_vars in zip(*tower_grads):
         # Note that each grad_and_vars looks like the following:
@@ -109,6 +99,10 @@ def average_gradients(tower_grads):
         average_grads.append(grad_and_var)
     return average_grads
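The commit drops the docstring of average_gradients, so for reference: as consumed by zip(*tower_grads), the argument holds one list of (gradient, variable) pairs per tower, all in the same variable order, and the function returns a single averaged (gradient, variable) pair per variable. A framework-free sketch of that computation (illustration only, not the actual TensorFlow implementation, which averages gradient tensors):

def average_gradients_sketch(tower_grads):
    # tower_grads: one entry per tower, each a list of (gradient, variable) pairs
    averaged = []
    for grad_and_vars in zip(*tower_grads):      # group the same variable across towers
        grads = [g for g, _ in grad_and_vars]
        var = grad_and_vars[0][1]                # the variable object is shared by all towers
        averaged.append((sum(grads) / len(grads), var))
    return averaged

# two towers, two "variables" w and b:
print(average_gradients_sketch([[(1.0, 'w'), (3.0, 'b')],
                                [(2.0, 'w'), (5.0, 'b')]]))   # [(1.5, 'w'), (4.0, 'b')]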
+def summary_grads(grads):
+    for grad, var in grads:
+        if grad:
+            tf.histogram_summary(var.op.name + '/gradients', grad)

 def start_train(config):
     """
@@ -120,6 +114,10 @@ def start_train(config):
     input_queue = config.input_queue
     callbacks = config.callbacks
+    tf.add_to_collection(FORWARD_FUNC_KEY, config.get_model_func)
+    for v in input_vars:
+        tf.add_to_collection(INPUT_VARS_KEY, v)

     def get_model_inputs():
         model_inputs = input_queue.dequeue()
         for qv, v in zip(model_inputs, input_vars):
@@ -134,40 +132,38 @@ def start_train(config):
     else:
         enqueue_op = input_queue.enqueue_many(input_vars)

-    keys_to_maintain = [tf.GraphKeys.SUMMARIES, SUMMARY_VARS_KEY]
-    olds = {}
-    for k in keys_to_maintain:
-        olds[k] = copy.copy(tf.get_collection(k))
-    all_grads = []
-    n_tower = 1
-    for i in range(n_tower):
-        with tf.device('/gpu:{}'.format(i)):
-            with tf.name_scope('tower{}'.format(i)):
-                for k in keys_to_maintain:
-                    del tf.get_collection(k)[:]
-                model_inputs = get_model_inputs()
-                output_vars, cost_var = config.get_model_func(model_inputs, is_training=True)
-                tf.get_variable_scope().reuse_variables()
-                grads = config.optimizer.compute_gradients(cost_var)
-                all_grads.append(grads)
-    for k in keys_to_maintain:
-        tf.get_collection(k).extend(olds[k])
-    grads = average_gradients(all_grads)
-    for grad, var in grads:
-        if grad:
-            tf.histogram_summary(var.op.name + '/gradients', grad)
+    # get gradients to update:
+    logger.info("Training a model of {} tower".format(config.nr_tower))
+    if config.nr_tower > 1:
+        coll_keys = [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_VARS_KEY]
+        kept_summaries = {}
+        grads = []
+        for i in range(config.nr_tower):
+            with tf.device('/gpu:{}'.format(i)):
+                model_inputs = get_model_inputs()
+                output_vars, cost_var = config.get_model_func(model_inputs, is_training=True)
+                grads.append(
+                    config.optimizer.compute_gradients(cost_var))
+                if i == 0:
+                    tf.get_variable_scope().reuse_variables()
+                    for k in coll_keys:
+                        kept_summaries[k] = copy.copy(tf.get_collection(k))
+        for k in coll_keys:   # avoid repeating summary on multiple devices
+            del tf.get_collection(k)[:]
+            tf.get_collection(k).extend(kept_summaries[k])
+        grads = average_gradients(grads)
+    else:
+        model_inputs = get_model_inputs()
+        output_vars, cost_var = config.get_model_func(model_inputs, is_training=True)
+        grads = config.optimizer.compute_gradients(cost_var)
+    summary_grads(grads)
     avg_maintain_op = summary_moving_average(cost_var)

-    # build graph
-    tf.add_to_collection(FORWARD_FUNC_KEY, config.get_model_func)
-    for v in input_vars:
-        tf.add_to_collection(INPUT_VARS_KEY, v)
-    describe_model()
-    # train_op = get_train_op(config.optimizer, cost_var)
     with tf.control_dependencies([avg_maintain_op]):
         train_op = config.optimizer.apply_gradients(grads, get_global_step_var())
+    describe_model()

     sess = tf.Session(config=config.session_config)
     config.session_init.init(sess)
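Two details of the multi-tower path above are worth spelling out. Summaries created while building the first tower are saved into kept_summaries and restored after the loop, so each summary is emitted once rather than once per device (the "avoid repeating summary on multiple devices" comment). And calling tf.get_variable_scope().reuse_variables() after the first tower makes every later tower reuse tower 0's variables instead of creating new copies. A hedged sketch of that sharing pattern with a toy model (graph-mode TF API; not code from this repository):

import tensorflow as tf

def tower(x):
    # returns the same 'w' on every call once reuse is enabled on the enclosing scope
    w = tf.get_variable('w', shape=[4, 2])
    return tf.matmul(x, w)

x = tf.placeholder(tf.float32, [None, 4])
outputs = []
for i in range(2):                                      # pretend config.nr_tower == 2
    with tf.device('/gpu:{}'.format(i)):
        outputs.append(tower(x))
        if i == 0:
            tf.get_variable_scope().reuse_variables()   # towers after the first share tower 0's weights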
......
@@ -9,8 +9,7 @@ GLOBAL_STEP_VAR_NAME = 'global_step:0'
 SUMMARY_WRITER_COLLECTION_KEY = 'summary_writer'

 INPUT_VARS_KEY = 'INPUT_VARIABLES'
-COST_VARS_KEY = 'COST_VARIABLES'   # keep track of each individual cost
-SUMMARY_VARS_KEY = 'SUMMARY_VARIABLES'   # extra variables to summarize during training
+MOVING_SUMMARY_VARS_KEY = 'MOVING_SUMMARY_VARIABLES'   # extra variables to summarize during training
 FORWARD_FUNC_KEY = 'FORWARD_FUNCTION'

 # export all upper case variables
......
@@ -31,7 +31,6 @@ def add_activation_summary(x, name=None):
         name = x.name
     tf.histogram_summary(name + '/activations', x)
     tf.scalar_summary(name + '/sparsity', tf.nn.zero_fraction(x))
-    # TODO avoid repeating activations on multiple GPUs

 def add_histogram_summary(regex):
     """
@@ -46,15 +45,14 @@ def add_histogram_summary(regex):

 def summary_moving_average(cost_var):
     """ Create a MovingAverage op and summary for all variables in
-        COST_VARS_KEY, SUMMARY_VARS_KEY, as well as the argument
+        MOVING_SUMMARY_VARS_KEY, as well as the argument
         Return a op to maintain these average
     """
     global_step_var = tf.get_default_graph().get_tensor_by_name(GLOBAL_STEP_VAR_NAME)
     averager = tf.train.ExponentialMovingAverage(
         0.99, num_updates=global_step_var, name='moving_averages')
     vars_to_summary = [cost_var] + \
-        tf.get_collection(SUMMARY_VARS_KEY) + \
-        tf.get_collection(COST_VARS_KEY)
+        tf.get_collection(MOVING_SUMMARY_VARS_KEY)
     avg_maintain_op = averager.apply(vars_to_summary)
     for idx, c in enumerate(vars_to_summary):
         name = c.op.name
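With the collections merged, summary_moving_average reads a single list of scalars. The overall pattern the function implements is an exponential moving average followed by a per-variable summary; a rough self-contained sketch (the scalar_summary call is an assumption based on the docstring, since the loop body after name = c.op.name is outside this diff; tf.scalar_summary is the TF 0.x-era name used throughout this code):

import tensorflow as tf

MOVING_SUMMARY_VARS_KEY = 'MOVING_SUMMARY_VARIABLES'

cost_var = tf.Variable(1.0, name='cost')
train_error = tf.Variable(0.5, name='train_error')
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, train_error)
global_step_var = tf.Variable(0, trainable=False, name='global_step')

averager = tf.train.ExponentialMovingAverage(
    0.99, num_updates=global_step_var, name='moving_averages')
vars_to_summary = [cost_var] + tf.get_collection(MOVING_SUMMARY_VARS_KEY)
avg_maintain_op = averager.apply(vars_to_summary)        # updates one smoothed shadow value per variable
for c in vars_to_summary:
    tf.scalar_summary(c.op.name, averager.average(c))    # (assumed) log the smoothed value, not the raw one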
......