Commit 087dc382 authored by ppwwyyxx

[WIP] tower training

parent 95037482
@@ -56,7 +56,7 @@ def get_model(inputs, is_training):
     y = one_hot(label, 10)
     cost = tf.nn.softmax_cross_entropy_with_logits(logits, y)
     cost = tf.reduce_mean(cost, name='cross_entropy_loss')
-    tf.add_to_collection(COST_VARS_KEY, cost)
+    tf.add_to_collection(SUMMARY_VARS_KEY, cost)
     # compute the number of failed samples, for ValidationError to use at test time
     wrong = tf.not_equal(
@@ -71,7 +71,7 @@ def get_model(inputs, is_training):
     wd_cost = tf.mul(1e-4,
                      regularize_cost('fc.*/W', tf.nn.l2_loss),
                      name='regularize_loss')
-    tf.add_to_collection(COST_VARS_KEY, wd_cost)
+    tf.add_to_collection(SUMMARY_VARS_KEY, wd_cost)
     add_histogram_summary('.*/W')   # monitor histogram of all W
     return [prob, nr_wrong], tf.add_n([wd_cost, cost], name='cost')
@@ -105,6 +105,7 @@ def get_config():
     sess_config = get_default_sess_config()
     sess_config.gpu_options.per_process_gpu_memory_fraction = 0.5
+    sess_config.device_count['GPU'] = 2
     # prepare model
     input_vars = [
@@ -149,7 +150,8 @@ if __name__ == '__main__':
         os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
     with tf.Graph().as_default():
-        config = get_config()
+        with tf.device('/cpu:0'):
+            config = get_config()
         if args.load:
             config.session_init = SaverRestore(args.load)
......
+pip @ https://storage.googleapis.com/tensorflow/linux/gpu/tensorflow-0.6.0-cp27-none-linux_x86_64.whl
 termcolor
 numpy
 protobuf~=3.0.0a1
......
@@ -5,6 +5,7 @@
 import tensorflow as tf
 import os
+import re
 from .base import Callback, PeriodicCallback
 from ..utils import *
@@ -47,6 +48,8 @@ class SummaryWriter(Callback):
         summary_str = self.summary_op.eval()
         summary = tf.Summary.FromString(summary_str)
         for val in summary.value:
+            #print val.tag
+            val.tag = re.sub('tower[0-9]*/', '', val.tag)
             if val.tag in self.print_tag:
                 assert val.WhichOneof('value') == 'simple_value', \
                     'Cannot print summary {}: not a simple_value summary!'.format(val.tag)
......
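Note (not part of the commit): the `re.sub` added above strips the per-tower name scope, so summaries emitted inside a `towerN/` scope are matched against the tag names used in the model definition. A minimal illustration, with two tag names taken from the example script above and one made-up tag ('train_error'):

    import re

    for tag in ['tower0/cross_entropy_loss', 'tower1/regularize_loss', 'train_error']:
        print(re.sub('tower[0-9]*/', '', tag))
    # -> cross_entropy_loss
    # -> regularize_loss
    # -> train_error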
@@ -32,8 +32,6 @@ def create_test_graph():
         for v in input_vars:
             Gtest.add_to_collection(INPUT_VARS_KEY, v)
         output_vars, cost = forward_func(input_vars, is_training=False)
-        for v in output_vars:
-            Gtest.add_to_collection(OUTPUT_VARS_KEY, v)
         yield Gtest
 @contextmanager
......
@@ -5,6 +5,7 @@
 import tensorflow as tf
 from itertools import count
+import copy
 import argparse
 import tqdm
@@ -71,20 +72,43 @@ class TrainConfig(object):
         assert self.step_per_epoch > 0 and self.max_epoch > 0
         assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))
-def get_train_op(optimizer, cost_var):
-    global_step_var = tf.get_default_graph().get_tensor_by_name(GLOBAL_STEP_VAR_NAME)
-    avg_maintain_op = summary_moving_average(cost_var)
-    # maintain average in each step
-    with tf.control_dependencies([avg_maintain_op]):
-        grads = optimizer.compute_gradients(cost_var)
-        for grad, var in grads:
-            if grad:
-                tf.histogram_summary(var.op.name + '/gradients', grad)
-        return optimizer.apply_gradients(grads, global_step_var)
+def average_gradients(tower_grads):
+    """Calculate the average gradient for each shared variable across all towers.
+    Note that this function provides a synchronization point across all towers.
+    Args:
+        tower_grads: List of lists of (gradient, variable) tuples. The outer list
+            is over individual gradients. The inner list is over the gradient
+            calculation for each tower.
+    Returns:
+        List of pairs of (gradient, variable) where the gradient has been averaged
+        across all towers.
+    """
+    average_grads = []
+    for grad_and_vars in zip(*tower_grads):
+        # Note that each grad_and_vars looks like the following:
+        #   ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN))
+        grads = []
+        for g, _ in grad_and_vars:
+            # Add 0 dimension to the gradients to represent the tower.
+            expanded_g = tf.expand_dims(g, 0)
+            # Append on a 'tower' dimension which we will average over below.
+            grads.append(expanded_g)
+        # Average over the 'tower' dimension.
+        grad = tf.concat(0, grads)
+        grad = tf.reduce_mean(grad, 0)
+        # Keep in mind that the Variables are redundant because they are shared
+        # across towers. So .. we will just return the first tower's pointer to
+        # the Variable.
+        v = grad_and_vars[0][1]
+        grad_and_var = (grad, v)
+        average_grads.append(grad_and_var)
+    return average_grads
 def start_train(config):
     """
@@ -96,27 +120,53 @@ def start_train(config):
     input_queue = config.input_queue
     callbacks = config.callbacks
-    if config.batched_model_input:
-        enqueue_op = input_queue.enqueue(input_vars)
+    def get_model_inputs():
         model_inputs = input_queue.dequeue()
         for qv, v in zip(model_inputs, input_vars):
-            qv.set_shape(v.get_shape())
+            if config.batched_model_input:
+                qv.set_shape(v.get_shape())
+            else:
+                qv.set_shape(v.get_shape().as_list()[1:])
+        return model_inputs
+    if config.batched_model_input:
+        enqueue_op = input_queue.enqueue(input_vars)
     else:
         enqueue_op = input_queue.enqueue_many(input_vars)
-        model_inputs = input_queue.dequeue()
-        for qv, v in zip(model_inputs, input_vars):
-            qv.set_shape(v.get_shape().as_list()[1:])
-    output_vars, cost_var = config.get_model_func(model_inputs, is_training=True)
+    keys_to_maintain = [tf.GraphKeys.SUMMARIES, SUMMARY_VARS_KEY]
+    olds = {}
+    for k in keys_to_maintain:
+        olds[k] = copy.copy(tf.get_collection(k))
+    all_grads = []
+    n_tower = 1
+    for i in range(n_tower):
+        with tf.device('/gpu:{}'.format(i)):
+            with tf.name_scope('tower{}'.format(i)):
+                for k in keys_to_maintain:
+                    del tf.get_collection(k)[:]
+                model_inputs = get_model_inputs()
+                output_vars, cost_var = config.get_model_func(model_inputs, is_training=True)
+                tf.get_variable_scope().reuse_variables()
+                grads = config.optimizer.compute_gradients(cost_var)
+                all_grads.append(grads)
+    for k in keys_to_maintain:
+        tf.get_collection(k).extend(olds[k])
+    grads = average_gradients(all_grads)
+    for grad, var in grads:
+        if grad:
+            tf.histogram_summary(var.op.name + '/gradients', grad)
+    avg_maintain_op = summary_moving_average(cost_var)
     # build graph
     tf.add_to_collection(FORWARD_FUNC_KEY, config.get_model_func)
     for v in input_vars:
         tf.add_to_collection(INPUT_VARS_KEY, v)
-    for v in output_vars:
-        tf.add_to_collection(OUTPUT_VARS_KEY, v)
     describe_model()
-    train_op = get_train_op(config.optimizer, cost_var)
+    # train_op = get_train_op(config.optimizer, cost_var)
+    with tf.control_dependencies([avg_maintain_op]):
+        train_op = config.optimizer.apply_gradients(grads, get_global_step_var())
     sess = tf.Session(config=config.session_config)
     config.session_init.init(sess)
......
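Note (not part of the commit): average_gradients takes one (gradient, variable) list per tower and returns a single list whose gradients are the element-wise mean over towers, keyed by the first tower's variables. A rough stand-alone sketch of that behaviour, using the same TF 0.6-era API as the diff and made-up gradient values; it assumes average_gradients from the train.py diff above is in scope:

    import tensorflow as tf

    w = tf.Variable([0.0, 0.0], name='w')
    tower_grads = [
        [(tf.constant([1.0, 2.0]), w)],   # gradients computed on tower0
        [(tf.constant([3.0, 4.0]), w)],   # gradients computed on tower1
    ]
    averaged = average_gradients(tower_grads)   # [(mean_grad, w)]

    with tf.Session() as sess:
        print(sess.run(averaged[0][0]))         # [ 2.  3.]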
@@ -44,6 +44,7 @@ class EnqueueThread(threading.Thread):
                     return
                 feed = dict(izip(self.input_vars, dp))
                 self.sess.run([self.op], feed_dict=feed)
+            #print '\nExauhsted!!!'
         except tf.errors.CancelledError as e:
             pass
         except Exception:
......
@@ -9,7 +9,6 @@ GLOBAL_STEP_VAR_NAME = 'global_step:0'
 SUMMARY_WRITER_COLLECTION_KEY = 'summary_writer'
 INPUT_VARS_KEY = 'INPUT_VARIABLES'
-OUTPUT_VARS_KEY = 'OUTPUT_VARIABLES'
 COST_VARS_KEY = 'COST_VARIABLES'            # keep track of each individual cost
 SUMMARY_VARS_KEY = 'SUMMARY_VARIABLES'      # extra variables to summarize during training
 FORWARD_FUNC_KEY = 'FORWARD_FUNCTION'
......
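Note (not part of the commit): the keys_to_maintain bookkeeping in start_train() above snapshots the summary collections, clears them before each tower builds its copy of the model (so summaries end up registered once rather than once per tower), then merges the snapshot back. A rough illustration of that pattern with placeholder entries; like the commit itself, it relies on the TF 0.6-era behaviour where tf.get_collection() returns the graph's underlying list rather than a copy:

    import copy
    import tensorflow as tf

    KEY = 'SUMMARY_VARIABLES'                   # same string as SUMMARY_VARS_KEY above
    tf.add_to_collection(KEY, 'pre_existing_entry')

    saved = copy.copy(tf.get_collection(KEY))   # snapshot the current contents
    del tf.get_collection(KEY)[:]               # clear the collection in place
    tf.add_to_collection(KEY, 'tower0_entry')   # only the first tower's additions remain
    tf.get_collection(KEY).extend(saved)        # merge the snapshot back

    print(tf.get_collection(KEY))               # ['tower0_entry', 'pre_existing_entry']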