Commit 9b1d0907 authored by Yuxin Wu

fix bn performance

parent 3080e91e
@@ -31,6 +31,8 @@ def BatchNormV1(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
        Default to True in training and False in inference.
    :param decay: decay rate. default to 0.9.
    :param epsilon: default to 1e-5.
+   Note that only the first training tower maintains a moving average.
    """
    shape = x.get_shape().as_list()
@@ -122,6 +124,8 @@ def BatchNormV2(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
        Default to True in training and False in inference.
    :param decay: decay rate. default to 0.9.
    :param epsilon: default to 1e-5.
+   Note that only the first training tower maintains a moving average.
    """
    shape = x.get_shape().as_list()
    assert len(shape) in [2, 4]
@@ -150,29 +154,31 @@ def BatchNormV2(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
    if use_local_stat:
        xn, batch_mean, batch_var = tf.nn.fused_batch_norm(x, gamma, beta,
                epsilon=epsilon, is_training=True)
-       if ctx.is_training:
-           # maintain EMA if training
+       # maintain EMA only in the main training tower
+       if ctx.is_main_training_tower:
            update_op1 = moving_averages.assign_moving_average(
                moving_mean, batch_mean, decay, zero_debias=False,
                name='mean_ema_op')
            update_op2 = moving_averages.assign_moving_average(
                moving_var, batch_var, decay, zero_debias=False,
                name='var_ema_op')
-           if ctx.is_main_training_tower:
-               add_model_variable(moving_mean)
-               add_model_variable(moving_var)
+           add_model_variable(moving_mean)
+           add_model_variable(moving_var)
    else:
        assert not ctx.is_training, "In training, local statistics has to be used!"
        # TODO do I need to add_model_variable.
-       # assume some fixed-param tasks, such as load model and fine tune one layer
-       # fused is slower in inference
+       # consider some fixed-param tasks, such as load model and fine tune one layer
+       # fused seems slower in inference
        #xn, _, _ = tf.nn.fused_batch_norm(x, gamma, beta,
            #moving_mean, moving_var,
            #epsilon=epsilon, is_training=False, name='output')
        xn = tf.nn.batch_normalization(
            x, moving_mean, moving_var, beta, gamma, epsilon)
-   if ctx.is_training:
+   # TODO for other towers, maybe can make it depend some op later
+   if ctx.is_main_training_tower:
        with tf.control_dependencies([update_op1, update_op2]):
            return tf.identity(xn, name='output')
    else:
......
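Net effect of this hunk: a replica tower now builds no EMA update ops at all, instead of building them under `ctx.is_training` and only gating the `add_model_variable` calls. Since `moving_mean`/`moving_var` are shared variables, dropping the redundant per-tower assignments removes contention on them every step, which is presumably the performance fix in the commit title. A minimal standalone sketch of the resulting pattern (TF 1.x; `is_main_tower` is a placeholder for tensorpack's `ctx.is_main_training_tower`):

```python
import tensorflow as tf
from tensorflow.python.training import moving_averages

def batch_norm_train(x, gamma, beta, moving_mean, moving_var,
                     is_main_tower, decay=0.9, epsilon=1e-5):
    # Normalize with this batch's statistics; the fused kernel is
    # the fast path in training.
    xn, batch_mean, batch_var = tf.nn.fused_batch_norm(
        x, gamma, beta, epsilon=epsilon, is_training=True)
    if not is_main_tower:
        # Replica towers skip the EMA machinery entirely: no update
        # ops, no writes to the shared moving_mean / moving_var.
        return xn
    # Only the main tower folds batch statistics into the moving averages.
    update_mean = moving_averages.assign_moving_average(
        moving_mean, batch_mean, decay, zero_debias=False, name='mean_ema_op')
    update_var = moving_averages.assign_moving_average(
        moving_var, batch_var, decay, zero_debias=False, name='var_ema_op')
    # Tie the updates to the output so they run whenever it is consumed.
    with tf.control_dependencies([update_mean, update_var]):
        return tf.identity(xn, name='output')
```

The `tf.identity` under `tf.control_dependencies` is what forces the EMA updates to execute on every step that uses the layer's output.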
@@ -10,7 +10,7 @@ from ..tfutils import get_global_step_var
from ..tfutils.tower import TowerContext
from ..tfutils.gradproc import apply_grad_processors
from ..tfutils.summary import summary_moving_average, add_moving_summary
-from .input_data import QueueInput, FeedfreeInput
+from .input_data import QueueInput, FeedfreeInput, DummyConstantInput
from .base import Trainer
from .trainer import MultiPredictorTowerTrainer
@@ -41,7 +41,9 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainer):
        cost_var = self.model.get_cost()
        # GATE_NONE faster?
        grads = self.config.optimizer.compute_gradients(
-           cost_var, gate_gradients=0)
+           cost_var,
+           gate_gradients=tf.train.Optimizer.GATE_NONE,
+           colocate_gradients_with_ops=False)
        add_moving_summary(cost_var)
        return cost_var, grads
......
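The rewritten call spells out the former magic number `0` as `tf.train.Optimizer.GATE_NONE` and makes `colocate_gradients_with_ops` explicit (False is also the TF default). A self-contained illustration of the knob (TF 1.x; the variable and cost below are toy values):

```python
import tensorflow as tf

x = tf.Variable(3.0)
cost = tf.square(x)
opt = tf.train.GradientDescentOptimizer(0.1)

# GATE_NONE (== 0): each gradient op runs as soon as its inputs are
# ready -- maximum parallelism, at the cost of possibly non-reproducible
# results. GATE_OP (the default) and GATE_GRAPH add increasing amounts
# of synchronization.
grads_and_vars = opt.compute_gradients(
    cost,
    gate_gradients=tf.train.Optimizer.GATE_NONE,
    colocate_gradients_with_ops=False)
```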
@@ -13,7 +13,8 @@ from ..tfutils.summary import add_moving_summary
from ..utils import logger
from ..callbacks.concurrency import StartProcOrThread

-__all__ = ['QueueInput', 'FeedfreeInput', 'TensorInput']
+__all__ = ['QueueInput', 'FeedfreeInput', 'TensorInput',
+           'DummyConstantInput']

@six.add_metaclass(ABCMeta)
class InputData(object):
@@ -131,6 +132,24 @@ class QueueInput(FeedfreeInput):
            #tf.Variable(tf.ones([128], dtype=tf.int32), trainable=False)]
        return ret

+class DummyConstantInput(QueueInput):
+    """ only for debugging performance issues """
+    def __init__(self, ds, shapes):
+        super(DummyConstantInput, self).__init__(ds)
+        self.shapes = shapes
+        logger.warn("Using dummy input for debug!")
+
+    def _get_input_tensors(self):
+        placehdrs = self.input_placehdrs
+        assert len(self.shapes) == len(placehdrs)
+        ret = []
+        for idx, p in enumerate(placehdrs):
+            with tf.device('/gpu:0'):
+                ret.append(tf.get_variable('dummy-' + p.op.name,
+                    shape=self.shapes[idx], dtype=p.dtype, trainable=False,
+                    initializer=tf.constant_initializer()))
+        return ret

class TensorInput(FeedfreeInput):
    def __init__(self, get_tensor_fn, size=None):
        self.get_tensor_fn = get_tensor_fn
......
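Intended use of the new class, as a hedged sketch (the import path and shapes below are illustrative; `df` is the `DataFlow` the config would normally feed to `QueueInput`): swap it in for `QueueInput` and re-measure. Because `_get_input_tensors` returns zero-filled, non-trainable, GPU-resident variables, the queue dequeue and host-to-device copy vanish from each step; if throughput barely improves, the input pipeline was not the bottleneck.

```python
from tensorpack.train.input_data import DummyConstantInput  # path assumed

# Shapes must match the model's input placeholders, batch dimension
# included (hypothetical ImageNet-style values: images and labels).
dummy = DummyConstantInput(df, shapes=[[64, 224, 224, 3], [64]])
```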
@@ -26,6 +26,7 @@ class MultiGPUTrainer(Trainer):
    """ Base class for multi-gpu training"""
    @staticmethod
    def _multi_tower_grads(towers, get_tower_grad_func):
+       """ ret[i] is the list of (grad, var) tuples for tower i """
        logger.info("Training a model of {} tower".format(len(towers)))
        grad_list = []
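For reference, the replicated-towers pattern this helper implements, stripped of tensorpack's `TowerContext` bookkeeping (the context is what supplies the `is_main_training_tower` flag used by the batch-norm hunk above); a sketch, not the exact library code:

```python
import tensorflow as tf

def multi_tower_grads(towers, get_tower_grad_func):
    """ret[i] is the list of (grad, var) tuples computed on tower i."""
    grad_list = []
    for idx, gpu_id in enumerate(towers):
        with tf.device('/gpu:{}'.format(gpu_id)), \
                tf.name_scope('tower{}'.format(idx)):
            grad_list.append(get_tower_grad_func())
            if idx == 0:
                # All later towers reuse (share) the variables
                # created by the first one.
                tf.get_variable_scope().reuse_variables()
    return grad_list
```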
@@ -66,6 +67,8 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
    @staticmethod
    def _average_grads(tower_grads):
+       if len(tower_grads) == 1:
+           return tower_grads[0]
        ret = []
        with tf.name_scope('AvgGrad'):
            for grad_and_vars in zip(*tower_grads):
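The loop body is collapsed here; the standard completion of this pattern looks roughly like the following sketch (not necessarily the exact elided code). The new early return skips the `AvgGrad` scope and the add/divide ops entirely when there is a single tower:

```python
import tensorflow as tf

def average_grads(tower_grads):
    """Average (grad, var) lists position-wise across towers."""
    if len(tower_grads) == 1:
        return tower_grads[0]   # single tower: nothing to average
    ret = []
    with tf.name_scope('AvgGrad'):
        # Variables are shared across towers, so position k of every
        # tower's list refers to the same variable.
        for grad_and_vars in zip(*tower_grads):
            var = grad_and_vars[0][1]
            grad = tf.add_n([g for g, _ in grad_and_vars]) / float(len(tower_grads))
            ret.append((grad, var))
    return ret
```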
@@ -90,6 +93,12 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
        super(SyncMultiGPUTrainer, self)._setup()
        grad_list = MultiGPUTrainer._multi_tower_grads(
            self.config.tower, lambda: self._get_cost_and_grad()[1])
+
+       # debug tower performance:
+       #ops = [k[0] for k in grad_list[1]] + [k[0] for k in grad_list[0]]
+       #self.train_op = tf.group(*ops)
+       #return
+
        grads = SyncMultiGPUTrainer._average_grads(grad_list)
        grads = apply_grad_processors(grads, self.model.get_gradient_processor())
......
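The commented-out block above is the measurement hook behind the commit title: grouping only the per-tower gradient tensors yields a train op with no averaging and no `apply_gradients`, isolating raw forward/backward throughput. The same idea spelled out (assumes `grad_list` from `_multi_tower_grads` and an active session):

```python
import tensorflow as tf

# Run both towers' forward/backward passes and nothing else.
ops = [g for g, _ in grad_list[0]] + [g for g, _ in grad_list[1]]
bench_op = tf.group(*ops)
# e.g. time repeated sess.run(bench_op) calls against the full train_op
```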