Commit 5bc36930 authored by Yuxin Wu

lr multiplier

parent 7fe010cb
@@ -1,4 +1,4 @@
 #!/usr/bin/env python2
-# -*- coding: UTF-8 -*-
+# -*- coding: utf-8 -*-
 # File: example_mnist.py
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>
@@ -38,7 +38,7 @@ class Model(ModelDesc):
         image = tf.expand_dims(image, 3)    # add a single channel
         l = Conv2D('conv0', image, out_channel=32, kernel_shape=3)
-        l = Conv2D('conv1', image, out_channel=32, kernel_shape=3)
+        l = Conv2D('conv1', l, out_channel=32, kernel_shape=3)
         l = MaxPooling('pool0', l, 2)
         l = Conv2D('conv2', l, out_channel=40, kernel_shape=3)
         l = MaxPooling('pool1', l, 2)
@@ -122,3 +122,4 @@ if __name__ == '__main__':
     if args.load:
         config.session_init = SaverRestore(args.load)
     start_train(config)
@@ -61,7 +61,7 @@ class ModelDesc(object):
             but must have the same length
         """

-    def get_lr_multipler(self):
+    def get_lr_multiplier(self):
         """
         Return a dict of {variable_regex: multiplier}
         """
......
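For context, a minimal sketch (not part of this commit) of how a ModelDesc subclass might override the renamed hook. The class name, regexes, and factors below are illustrative only; keys are regexes matched against variable names with re.search() (see scale_grads further down), and values are the factors applied to the matching gradients.

from models import ModelDesc

class MyModel(ModelDesc):
    def get_lr_multiplier(self):
        # Hypothetical multipliers: slow down a pretrained first layer,
        # speed up fully-connected weights.
        return {'conv0': 0.1,
                'fc.*/W': 2.0}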
@@ -7,6 +7,7 @@ import tensorflow as tf
 from itertools import count
 import copy
 import argparse
+import re
 import tqdm

 from models import ModelDesc
@@ -59,13 +60,13 @@ class TrainConfig(object):
         self.nr_tower = int(kwargs.pop('nr_tower', 1))
         assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))

-def average_gradients(tower_grads):
-    average_grads = []
+def average_grads(tower_grads):
+    ret = []
     for grad_and_vars in zip(*tower_grads):
         grad = tf.add_n([x[0] for x in grad_and_vars]) / float(len(tower_grads))
         v = grad_and_vars[0][1]
-        average_grads.append((grad, v))
-    return average_grads
+        ret.append((grad, v))
+    return ret

 def summary_grads(grads):
     for grad, var in grads:
@@ -74,8 +75,22 @@ def summary_grads(grads):

 def check_grads(grads):
     for grad, var in grads:
+        assert grad is not None, "Grad is None for variable {}".format(var.name)
         tf.Assert(tf.reduce_all(tf.is_finite(var)), [var])

+def scale_grads(grads, multiplier):
+    ret = []
+    for grad, var in grads:
+        varname = var.name
+        for regex, val in multiplier.iteritems():
+            if re.search(regex, varname):
+                logger.info("Apply lr multiplier {} for {}".format(val, varname))
+                ret.append((grad * val, var))
+                break
+        else:
+            ret.append((grad, var))
+    return ret
+
 def start_train(config):
     """
     Start training with the given config
@@ -120,15 +135,17 @@ def start_train(config):
         for k in coll_keys:    # avoid repeating summary on multiple devices
             del tf.get_collection(k)[:]
             tf.get_collection(k).extend(kept_summaries[k])
         grads = average_grads(grads)
     else:
         model_inputs = get_model_inputs()
         cost_var = model.get_cost(model_inputs, is_training=True)
         grads = config.optimizer.compute_gradients(cost_var)
-    summary_grads(grads)
-    check_grads(grads)
     avg_maintain_op = summary_moving_average(cost_var)
+    check_grads(grads)
+    grads = scale_grads(grads, model.get_lr_multiplier())
+    summary_grads(grads)

     train_op = tf.group(
         config.optimizer.apply_gradients(grads, get_global_step_var()),
         avg_maintain_op)
@@ -154,7 +171,7 @@ def start_train(config):
     for epoch in xrange(1, config.max_epoch):
         with timed_operation('epoch {}'.format(epoch)):
             for step in tqdm.trange(
-                    config.step_per_epoch, leave=True, mininterval=0.2):
+                    config.step_per_epoch, leave=True, mininterval=0.5, dynamic_ncols=True):
                 if coord.should_stop():
                     return
                 sess.run([train_op])    # faster since train_op return None
......
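To see what the new scale_grads does without running TensorFlow, here is a dependency-free sketch that mirrors its regex matching and for/else fallthrough. The (grad, var) pairs are stood in for by (number, name) pairs, and all names and multipliers below are made up for illustration.

import re

def scale_grads_demo(grads, multiplier):
    # Mirrors scale_grads above: the first regex that matches a variable's
    # name scales its gradient; the for/else keeps unmatched gradients as-is.
    ret = []
    for grad, varname in grads:
        for regex, val in multiplier.items():   # .iteritems() in the py2 source
            if re.search(regex, varname):
                ret.append((grad * val, varname))
                break
        else:
            ret.append((grad, varname))
    return ret

print(scale_grads_demo([(1.0, 'conv0/W:0'), (1.0, 'fc0/W:0')],
                       {'conv0': 0.1}))
# [(0.1, 'conv0/W:0'), (1.0, 'fc0/W:0')]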