Commit 9b1d0907 authored by Yuxin Wu

fix bn performance

parent 3080e91e
@@ -31,6 +31,8 @@ def BatchNormV1(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
         Default to True in training and False in inference.
     :param decay: decay rate. default to 0.9.
     :param epsilon: default to 1e-5.
+
+    Note that only the first training tower maintains a moving average.
     """
     shape = x.get_shape().as_list()
@@ -122,6 +124,8 @@ def BatchNormV2(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
         Default to True in training and False in inference.
     :param decay: decay rate. default to 0.9.
     :param epsilon: default to 1e-5.
+
+    Note that only the first training tower maintains a moving average.
     """
     shape = x.get_shape().as_list()
     assert len(shape) in [2, 4]
@@ -150,29 +154,31 @@ def BatchNormV2(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
     if use_local_stat:
         xn, batch_mean, batch_var = tf.nn.fused_batch_norm(x, gamma, beta,
             epsilon=epsilon, is_training=True)
-        if ctx.is_training:
-            # maintain EMA if training
+        # maintain EMA only in the main training tower
+        if ctx.is_main_training_tower:
             update_op1 = moving_averages.assign_moving_average(
                 moving_mean, batch_mean, decay, zero_debias=False,
                 name='mean_ema_op')
             update_op2 = moving_averages.assign_moving_average(
                 moving_var, batch_var, decay, zero_debias=False,
                 name='var_ema_op')
-            if ctx.is_main_training_tower:
-                add_model_variable(moving_mean)
-                add_model_variable(moving_var)
+            add_model_variable(moving_mean)
+            add_model_variable(moving_var)
     else:
         assert not ctx.is_training, "In training, local statistics has to be used!"
         # TODO do I need to add_model_variable.
-        # assume some fixed-param tasks, such as load model and fine tune one layer
+        # consider some fixed-param tasks, such as load model and fine tune one layer
 
-        # fused is slower in inference
+        # fused seems slower in inference
         #xn, _, _ = tf.nn.fused_batch_norm(x, gamma, beta,
             #moving_mean, moving_var,
             #epsilon=epsilon, is_training=False, name='output')
         xn = tf.nn.batch_normalization(
             x, moving_mean, moving_var, beta, gamma, epsilon)
 
-    if ctx.is_training:
+    # TODO for other towers, maybe can make it depend some op later
+    if ctx.is_main_training_tower:
         with tf.control_dependencies([update_op1, update_op2]):
             return tf.identity(xn, name='output')
     else:
...
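The hunk above moves the EMA update ops under `ctx.is_main_training_tower`, so only one tower creates and runs the moving-average assignments while the other towers just normalize with their batch statistics; presumably having every tower update the shared moving averages was the performance problem the commit message refers to. A minimal sketch of the resulting pattern, outside tensorpack (TF 1.x API; `is_main_tower` is an illustrative stand-in for `ctx.is_main_training_tower`):

```python
import tensorflow as tf
from tensorflow.python.training import moving_averages

def batchnorm_with_ema(x, gamma, beta, moving_mean, moving_var,
                       is_main_tower, decay=0.9, epsilon=1e-5):
    # normalize with the batch statistics of this tower
    xn, batch_mean, batch_var = tf.nn.fused_batch_norm(
        x, gamma, beta, epsilon=epsilon, is_training=True)
    if not is_main_tower:
        # non-main towers: no EMA update at all
        return tf.identity(xn, name='output')
    # main tower: update the moving statistics, then return the output,
    # forcing the updates to run via a control dependency
    update_mean = moving_averages.assign_moving_average(
        moving_mean, batch_mean, decay, zero_debias=False, name='mean_ema_op')
    update_var = moving_averages.assign_moving_average(
        moving_var, batch_var, decay, zero_debias=False, name='var_ema_op')
    with tf.control_dependencies([update_mean, update_var]):
        return tf.identity(xn, name='output')
```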
@@ -10,7 +10,7 @@ from ..tfutils import get_global_step_var
 from ..tfutils.tower import TowerContext
 from ..tfutils.gradproc import apply_grad_processors
 from ..tfutils.summary import summary_moving_average, add_moving_summary
-from .input_data import QueueInput, FeedfreeInput
+from .input_data import QueueInput, FeedfreeInput, DummyConstantInput
 from .base import Trainer
 from .trainer import MultiPredictorTowerTrainer
@@ -41,7 +41,9 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainer):
         cost_var = self.model.get_cost()
         # GATE_NONE faster?
         grads = self.config.optimizer.compute_gradients(
-            cost_var, gate_gradients=0)
+            cost_var,
+            gate_gradients=tf.train.Optimizer.GATE_NONE,
+            colocate_gradients_with_ops=False)
         add_moving_summary(cost_var)
         return cost_var, grads
...
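For reference, `gate_gradients=0` and `tf.train.Optimizer.GATE_NONE` are the same value, so this hunk only spells out the named constant and makes the default `colocate_gradients_with_ops=False` explicit. A small self-contained TF 1.x example of the same call (the toy variable and cost are illustrative only):

```python
import tensorflow as tf

w = tf.get_variable('w', shape=[10], initializer=tf.zeros_initializer())
cost = tf.reduce_sum(tf.square(w))          # toy cost for illustration

opt = tf.train.GradientDescentOptimizer(0.1)
grads = opt.compute_gradients(
    cost,
    gate_gradients=tf.train.Optimizer.GATE_NONE,   # same as gate_gradients=0
    colocate_gradients_with_ops=False)             # explicit default
train_op = opt.apply_gradients(grads)
```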
@@ -13,7 +13,8 @@ from ..tfutils.summary import add_moving_summary
 from ..utils import logger
 from ..callbacks.concurrency import StartProcOrThread
 
-__all__ = ['QueueInput', 'FeedfreeInput', 'TensorInput']
+__all__ = ['QueueInput', 'FeedfreeInput', 'TensorInput',
+           'DummyConstantInput']
 
 @six.add_metaclass(ABCMeta)
 class InputData(object):
@@ -131,6 +132,24 @@ class QueueInput(FeedfreeInput):
             #tf.Variable(tf.ones([128], dtype=tf.int32), trainable=False)]
         return ret
 
+class DummyConstantInput(QueueInput):
+    """ only for debugging performance issues """
+    def __init__(self, ds, shapes):
+        super(DummyConstantInput, self).__init__(ds)
+        self.shapes = shapes
+        logger.warn("Using dummy input for debug!")
+
+    def _get_input_tensors(self):
+        placehdrs = self.input_placehdrs
+        assert len(self.shapes) == len(placehdrs)
+        ret = []
+        for idx, p in enumerate(placehdrs):
+            with tf.device('/gpu:0'):
+                ret.append(tf.get_variable('dummy-' + p.op.name,
+                    shape=self.shapes[idx], dtype=p.dtype, trainable=False,
+                    initializer=tf.constant_initializer()))
+        return ret
+
 class TensorInput(FeedfreeInput):
     def __init__(self, get_tensor_fn, size=None):
         self.get_tensor_fn = get_tensor_fn
...
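`DummyConstantInput` feeds each model input from a constant, untrainable variable pinned to the GPU instead of the real queue, which takes data loading and host-to-device copies out of the picture when profiling a training step. A standalone sketch of the same trick (the placeholder names and the 224x224 shapes below are made up for illustration):

```python
import tensorflow as tf

def dummy_input_tensors(placeholders, shapes):
    """Return one constant, untrainable GPU variable per placeholder,
    matching its dtype and the requested concrete shape."""
    assert len(placeholders) == len(shapes)
    ret = []
    for p, shape in zip(placeholders, shapes):
        with tf.device('/gpu:0'):
            ret.append(tf.get_variable(
                'dummy-' + p.op.name, shape=shape, dtype=p.dtype,
                trainable=False, initializer=tf.constant_initializer()))
    return ret

# Example: a fake image/label pair for a batch of 64.
image_ph = tf.placeholder(tf.float32, [None, 224, 224, 3], name='input')
label_ph = tf.placeholder(tf.int32, [None], name='label')
image, label = dummy_input_tensors(
    [image_ph, label_ph], shapes=[[64, 224, 224, 3], [64]])
```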
@@ -26,6 +26,7 @@ class MultiGPUTrainer(Trainer):
     """ Base class for multi-gpu training"""
     @staticmethod
     def _multi_tower_grads(towers, get_tower_grad_func):
+        """ ret[i] is a lists of (grad,var) tuple for tower i"""
         logger.info("Training a model of {} tower".format(len(towers)))
 
         grad_list = []
...@@ -66,6 +67,8 @@ class SyncMultiGPUTrainer(MultiGPUTrainer, ...@@ -66,6 +67,8 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
@staticmethod @staticmethod
def _average_grads(tower_grads): def _average_grads(tower_grads):
if len(tower_grads) == 1:
return tower_grads[0]
ret = [] ret = []
with tf.name_scope('AvgGrad'): with tf.name_scope('AvgGrad'):
for grad_and_vars in zip(*tower_grads): for grad_and_vars in zip(*tower_grads):
@@ -90,6 +93,12 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
         super(SyncMultiGPUTrainer, self)._setup()
         grad_list = MultiGPUTrainer._multi_tower_grads(
             self.config.tower, lambda: self._get_cost_and_grad()[1])
+
+        # debug tower performance:
+        #ops = [k[0] for k in grad_list[1]] + [k[0] for k in grad_list[0]]
+        #self.train_op = tf.group(*ops)
+        #return
+
         grads = SyncMultiGPUTrainer._average_grads(grad_list)
         grads = apply_grad_processors(grads, self.model.get_gradient_processor())
...
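The early return added to `_average_grads` skips the `AvgGrad` subgraph entirely when there is only one tower, since there is nothing to average. The body of the averaging loop is not part of this diff; the sketch below reconstructs a typical per-variable average around the new shortcut and is illustrative only:

```python
import tensorflow as tf

def average_grads(tower_grads):
    # tower_grads[i] is the list of (grad, var) pairs from tower i
    if len(tower_grads) == 1:
        return tower_grads[0]          # single tower: nothing to average
    ret = []
    with tf.name_scope('AvgGrad'):
        for grad_and_vars in zip(*tower_grads):
            # one (grad, var) pair per tower, all referring to the same variable
            var = grad_and_vars[0][1]
            grad = tf.multiply(
                tf.add_n([g for g, _ in grad_and_vars]),
                1.0 / len(grad_and_vars))
            ret.append((grad, var))
    return ret
```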