Commit 61305080 authored by Yuxin Wu's avatar Yuxin Wu

allow use of EMA in training BN

parent bfbc3a83
......@@ -8,6 +8,7 @@ from tensorflow.contrib.framework import add_model_variable
from tensorflow.python.training import moving_averages
from tensorflow.python.layers.normalization import BatchNorm as TF_BatchNorm
from ..utils import logger
from ..tfutils.tower import get_current_tower_context
from ..tfutils.collection import backup_collection, restore_collection
from .common import layer_register, VariableHolder
......@@ -129,7 +130,8 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
if ndims == 2:
xn = tf.squeeze(xn, [1, 2])
else:
assert not ctx.is_training, "In training, local statistics has to be used!"
if ctx.is_training:
logger.warn("[BatchNorm] Using moving_mean/moving_variance in training.")
# non-fused op is faster for inference
if ndims == 4 and data_format == 'NCHW':
[g, b, mm, mv] = [reshape_for_bn(_, ndims, n_out, data_format)
......@@ -142,7 +144,7 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
# maintain EMA only on one GPU is OK, even in replicated mode.
# because training time doesn't use EMA
if ctx.is_main_training_tower:
if ctx.is_main_training_tower and use_local_stat:
ret = update_bn_ema(xn, batch_mean, batch_var, moving_mean, moving_var, decay)
else:
ret = tf.identity(xn, name='output')
......
......@@ -33,7 +33,7 @@ class BoxBase(object):
return self.w * self.h
def is_box(self):
return self.area() > 0
return self.w > 0 and self.h > 0
class IntBox(BoxBase):
......@@ -69,6 +69,18 @@ class IntBox(BoxBase):
return False
return True
def clip_by_shape(self, shape):
"""
Clip xs and ys to be valid coordinates inside shape
Args:
shape: int [h, w] or None.
"""
self.x1 = np.clip(self.x1, 0, shape[1] - 1)
self.x2 = np.clip(self.x2, 0, shape[1] - 1)
self.y1 = np.clip(self.y1, 0, shape[0] - 1)
self.y2 = np.clip(self.y2, 0, shape[0] - 1)
def roi(self, img):
assert self.validate(img.shape[:2]), "{} vs {}".format(self, img.shape[:2])
return img[self.y1:self.y2 + 1, self.x1:self.x2 + 1]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment