Commit b26a945e authored by Yuxin Wu's avatar Yuxin Wu

fix bn import bug

parent d8935ef3
......@@ -3,6 +3,8 @@
A script to load and run pre-trained CPM model released by Shih-En. The original code in caffe is [here](https://github.com/shihenw/convolutional-pose-machines-release).
Reference paper: [Convolutional Pose Machines](https://arxiv.org/abs/1602.00134), Shih-En et al., CVPR16.
Also check out [Stereo Pose Machines](https://github.com/ppwwyyxx/Stereo-Pose-Machines), a __real-time__ CPM application based on tensorpack.
## Usage:
Prepare the model:
......@@ -24,5 +26,3 @@ Input image will get resized to 368x368. Note that this CPM comes without person
person has to be in the center of the image (and not too small).
![demo](demo.jpg)
For a __real-time__ CPM application in tensorpack, check out [Stereo Pose Machines](https://github.com/ppwwyyxx/Stereo-Pose-Machines).
......@@ -6,12 +6,11 @@ import tensorflow as tf
from functools import wraps
import six
import copy
import os
from ..tfutils.argscope import get_arg_scope
from ..tfutils.modelutils import get_shape_str
from ..tfutils.summary import add_activation_summary
from ..utils import logger
from ..utils import logger, building_rtfd
from ..utils.argtools import shape2d
# make sure each layer is only logged once
......@@ -100,9 +99,7 @@ def layer_register(
return wrapped_func
# need some special handling for sphinx to work with the arguments
on_doc = os.environ.get('READTHEDOCS') == 'True' \
or os.environ.get('TENSORPACK_DOC_BUILDING')
if on_doc:
if building_rtfd():
from decorator import decorator
wrapper = decorator(wrapper)
......
......@@ -9,7 +9,7 @@ from tensorflow.python.training import moving_averages
from ..tfutils.common import get_tf_version
from ..tfutils.tower import get_current_tower_context
from ..utils import logger
from ..utils import logger, building_rtfd
from ._common import layer_register
__all__ = ['BatchNorm', 'BatchNormV1', 'BatchNormV2']
......@@ -99,8 +99,25 @@ def BatchNormV1(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
@layer_register(log_shape=False)
def BatchNormV2(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
"""
A slightly faster but equivalent version of BatchNormV1, which uses
``fused_batch_norm`` in training.
Batch normalization layer, as described in the paper:
`Batch Normalization: Accelerating Deep Network Training by
Reducing Internal Covariance Shift <http://arxiv.org/abs/1502.03167>`_.
Args:
x (tf.Tensor): a NHWC or NC tensor.
use_local_stat (bool): whether to use mean/var of the current batch or the moving average.
Defaults to True in training and False in inference.
decay (float): decay rate of moving average.
epsilon (float): epsilon to avoid divide-by-zero.
Note:
* In multi-tower training, only the first training tower maintains a moving average.
* It automatically selects :meth:`BatchNormV1` or :meth:`BatchNormV2`
according to availability.
* This is a slightly faster but equivalent version of BatchNormV1. It uses
``fused_batch_norm`` in training.
"""
shape = x.get_shape().as_list()
assert len(shape) in [2, 4]
......@@ -160,27 +177,8 @@ def BatchNormV2(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
return tf.identity(xn, name='output')
def BatchNorm(*args, **kwargs):
"""
Batch normalization layer, as described in the paper:
`Batch Normalization: Accelerating Deep Network Training by
Reducing Internal Covariance Shift <http://arxiv.org/abs/1502.03167>`_.
Args:
x (tf.Tensor): a NHWC or NC tensor.
use_local_stat (bool): whether to use mean/var of the current batch or the moving average.
Defaults to True in training and False in inference.
decay (float): decay rate of moving average.
epsilon (float): epsilon to avoid divide-by-zero.
Note:
* In multi-tower training, only the first training tower maintains a moving average.
* It automatically selects :meth:`BatchNormV1` or :meth:`BatchNormV2`
according to availability.
"""
if get_tf_version() >= 12:
return BatchNormV2(*args, **kwargs)
else:
logger.warn("BatchNorm might be faster if you update TensorFlow")
return BatchNormV1(*args, **kwargs)
if building_rtfd() or get_tf_version() >= 12:
BatchNorm = BatchNormV2
else:
logger.warn("BatchNorm might be faster if you update TensorFlow")
BatchNorm = BatchNormV1
......@@ -15,7 +15,8 @@ __all__ = ['change_env',
'get_dataset_path',
'get_tqdm_kwargs',
'get_tqdm',
'execute_only_once'
'execute_only_once',
'building_rtfd'
]
......@@ -85,3 +86,8 @@ def get_tqdm_kwargs(**kwargs):
def get_tqdm(**kwargs):
return tqdm(**get_tqdm_kwargs(**kwargs))
def building_rtfd():
return os.environ.get('READTHEDOCS') == 'True' \
or os.environ.get('TENSORPACK_DOC_BUILDING')
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment