Commit 5bc36930 authored by Yuxin Wu

lr multiplier

parent 7fe010cb
 #!/usr/bin/env python2
-# -*- coding: UTF-8 -*-
+# -*- coding: utf-8 -*-
 # File: example_mnist.py
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>
@@ -38,7 +38,7 @@ class Model(ModelDesc):
         image = tf.expand_dims(image, 3)    # add a single channel
         l = Conv2D('conv0', image, out_channel=32, kernel_shape=3)
-        l = Conv2D('conv1', image, out_channel=32, kernel_shape=3)
+        l = Conv2D('conv1', l, out_channel=32, kernel_shape=3)
         l = MaxPooling('pool0', l, 2)
         l = Conv2D('conv2', l, out_channel=40, kernel_shape=3)
         l = MaxPooling('pool1', l, 2)
@@ -122,3 +122,4 @@ if __name__ == '__main__':
     if args.load:
         config.session_init = SaverRestore(args.load)
     start_train(config)
@@ -61,7 +61,7 @@ class ModelDesc(object):
             but must have the same length
         """

-    def get_lr_multipler(self):
+    def get_lr_multiplier(self):
         """
         Return a dict of {variable_regex: multiplier}
         """
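The new hook is easiest to read from the caller's side. A minimal sketch of
overriding it in a model, assuming the base ModelDesc implementation returns
an empty dict (the subclass and both regexes are invented for illustration;
only get_lr_multiplier itself comes from this commit):

    from models import ModelDesc

    class MyModel(ModelDesc):
        def get_lr_multiplier(self):
            # Gradients of variables whose names match a key are multiplied
            # by the value (see scale_grads below) before the optimizer
            # applies them; everything else keeps the base learning rate.
            return {'conv0/.*': 2.0,   # invented: 2x lr for the first conv layer
                    'fc.*/W': 0.1}     # invented: 0.1x lr for fc weights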
@@ -7,6 +7,7 @@ import tensorflow as tf
 from itertools import count
 import copy
 import argparse
+import re
 import tqdm

 from models import ModelDesc
@@ -59,13 +60,13 @@ class TrainConfig(object):
         self.nr_tower = int(kwargs.pop('nr_tower', 1))
         assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))

-def average_gradients(tower_grads):
-    average_grads = []
+def average_grads(tower_grads):
+    ret = []
     for grad_and_vars in zip(*tower_grads):
         grad = tf.add_n([x[0] for x in grad_and_vars]) / float(len(tower_grads))
         v = grad_and_vars[0][1]
-        average_grads.append((grad, v))
-    return average_grads
+        ret.append((grad, v))
+    return ret
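For concreteness, a pure-Python analogue of what average_grads computes, with
floats standing in for gradient tensors (numbers invented; the real function
sums with tf.add_n across towers):

    tower_grads = [
        [(0.25, 'W'), (1.0, 'b')],   # (gradient, variable) pairs from tower 0
        [(0.75, 'W'), (3.0, 'b')],   # (gradient, variable) pairs from tower 1
    ]
    # zip(*tower_grads) groups the entries that belong to the same variable,
    # and each variable gets the mean of its per-tower gradients:
    averaged = [(sum(g for g, v in pairs) / float(len(tower_grads)), pairs[0][1])
                for pairs in zip(*tower_grads)]
    assert averaged == [(0.5, 'W'), (2.0, 'b')]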
 def summary_grads(grads):
     for grad, var in grads:
@@ -74,8 +75,22 @@ def summary_grads(grads):
 def check_grads(grads):
     for grad, var in grads:
         assert grad is not None, "Grad is None for variable {}".format(var.name)
         tf.Assert(tf.reduce_all(tf.is_finite(var)), [var])
+
+def scale_grads(grads, multiplier):
+    ret = []
+    for grad, var in grads:
+        varname = var.name
+        for regex, val in multiplier.iteritems():
+            if re.search(regex, varname):
+                logger.info("Apply lr multiplier {} for {}".format(val, varname))
+                ret.append((grad * val, var))
+                break
+        else:
+            ret.append((grad, var))
+    return ret
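A quick sanity check of the matching logic in plain Python, with a stub object
standing in for a TF variable (all names and values invented; note the
for/else: the else branch runs only when no regex matched):

    import re

    class FakeVar(object):   # stub exposing just the .name attribute scale_grads reads
        def __init__(self, name):
            self.name = name

    grads = [(1.0, FakeVar('conv0/W:0')), (1.0, FakeVar('fc0/W:0'))]
    multiplier = {'conv0': 0.1}   # {variable_regex: multiplier}, as get_lr_multiplier returns

    ret = []
    for grad, var in grads:
        for regex, val in multiplier.items():   # .iteritems() in this Python 2 codebase
            if re.search(regex, var.name):
                ret.append((grad * val, var))   # matched: gradient scaled by 0.1
                break
        else:
            ret.append((grad, var))             # unmatched: passes through unchanged

    assert [g for g, v in ret] == [0.1, 1.0]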
 def start_train(config):
     """
     Start training with the given config
@@ -120,15 +135,17 @@ def start_train(config):
                 for k in coll_keys:    # avoid repeating summary on multiple devices
                     del tf.get_collection(k)[:]
                     tf.get_collection(k).extend(kept_summaries[k])
-        grads = average_gradients(grads)
+        grads = average_grads(grads)
     else:
         model_inputs = get_model_inputs()
         cost_var = model.get_cost(model_inputs, is_training=True)
         grads = config.optimizer.compute_gradients(cost_var)
-    summary_grads(grads)
-    check_grads(grads)
     avg_maintain_op = summary_moving_average(cost_var)
+    check_grads(grads)
+    grads = scale_grads(grads, model.get_lr_multiplier())
+    summary_grads(grads)

     train_op = tf.group(
         config.optimizer.apply_gradients(grads, get_global_step_var()),
         avg_maintain_op)
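Because the removed and added lines interleave above, the post-commit gradient
path is easier to read flattened out. Reconstructed from the hunk (single-tower
branch, surrounding context elided):

    grads = config.optimizer.compute_gradients(cost_var)
    avg_maintain_op = summary_moving_average(cost_var)       # moving average of the cost
    check_grads(grads)                                       # fail fast on missing gradients
    grads = scale_grads(grads, model.get_lr_multiplier())    # per-variable lr multipliers
    summary_grads(grads)                                     # summarize the scaled gradients
    train_op = tf.group(
        config.optimizer.apply_gradients(grads, get_global_step_var()),
        avg_maintain_op)

Note check_grads runs before scaling while summary_grads runs after, so the
summaries reflect the multipliers.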
@@ -154,7 +171,7 @@ def start_train(config):
     for epoch in xrange(1, config.max_epoch):
         with timed_operation('epoch {}'.format(epoch)):
             for step in tqdm.trange(
-                    config.step_per_epoch, leave=True, mininterval=0.2):
+                    config.step_per_epoch, leave=True, mininterval=0.5, dynamic_ncols=True):
                 if coord.should_stop():
                     return
                 sess.run([train_op])    # faster since train_op return None