Commit 02ef551b authored by Yuxin Wu's avatar Yuxin Wu

remove use_local_stat warnings from BN

parent fd9edc3b
...@@ -9,7 +9,7 @@ For certain tasks, you do need a new trainer. ...@@ -9,7 +9,7 @@ For certain tasks, you do need a new trainer.
Trainers just run __some__ iterations, so there is no limit in where the data come from or what to do in an iteration. Trainers just run __some__ iterations, so there is no limit in where the data come from or what to do in an iteration.
The existing common trainers all do two things: The existing common trainers all do two things:
1. Setup the graph and input pipeline, from `TrainConfig`. 1. Setup the graph and input pipeline, using the given `TrainConfig`.
2. Minimize `model.cost` in each iteration. 2. Minimize `model.cost` in each iteration.
But you can customize it by using the base `Trainer` class. But you can customize it by using the base `Trainer` class.
......
...@@ -10,7 +10,6 @@ from tensorflow.python.layers.normalization import BatchNorm as TF_BatchNorm ...@@ -10,7 +10,6 @@ from tensorflow.python.layers.normalization import BatchNorm as TF_BatchNorm
from ..tfutils.tower import get_current_tower_context from ..tfutils.tower import get_current_tower_context
from ..tfutils.collection import backup_collection, restore_collection from ..tfutils.collection import backup_collection, restore_collection
from ..utils import logger
from .common import layer_register, VariableHolder from .common import layer_register, VariableHolder
__all__ = ['BatchNorm', 'BatchRenorm'] __all__ = ['BatchNorm', 'BatchRenorm']
...@@ -116,10 +115,7 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5, ...@@ -116,10 +115,7 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5,
ctx = get_current_tower_context() ctx = get_current_tower_context()
if use_local_stat is None: if use_local_stat is None:
use_local_stat = ctx.is_training use_local_stat = ctx.is_training
elif use_local_stat != ctx.is_training: use_local_stat = bool(use_local_stat)
# we allow the use of local_stat in testing (only print warnings)
# because it is useful to certain applications.
logger.warn("[BatchNorm] use_local_stat != is_training")
if use_local_stat: if use_local_stat:
if ndims == 2: if ndims == 2:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment