Commit b26a945e authored by Yuxin Wu

fix bn import bug

parent d8935ef3
@@ -3,6 +3,8 @@
 A script to load and run the pre-trained CPM model released by Shih-En. The original code in caffe is [here](https://github.com/shihenw/convolutional-pose-machines-release).
 Reference paper: [Convolutional Pose Machines](https://arxiv.org/abs/1602.00134), Shih-En et al., CVPR16.
+
+Also check out [Stereo Pose Machines](https://github.com/ppwwyyxx/Stereo-Pose-Machines), a __real-time__ CPM application based on tensorpack.

 ## Usage:

 Prepare the model:
@@ -24,5 +26,3 @@ Input image will get resized to 368x368. Note that this CPM comes without person
 person has to be in the center of the image (and not too small).

 ![demo](demo.jpg)
-
-For a __real-time__ CPM application in tensorpack, check out [Stereo Pose Machines](https://github.com/ppwwyyxx/Stereo-Pose-Machines).
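The README note above implies a fixed pre-processing step. A minimal sketch, assuming OpenCV is available and using a hypothetical `input.jpg`; the model itself and its input tensor are not shown:

```python
# Hedged sketch of the pre-processing the README describes: the network
# takes a fixed 368x368 input, with the person roughly centered.
import cv2

img = cv2.imread('input.jpg')       # any size; person centered, not too small
img = cv2.resize(img, (368, 368))   # the CPM input resolution named above
```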
@@ -6,12 +6,11 @@ import tensorflow as tf
 from functools import wraps
 import six
 import copy
-import os

 from ..tfutils.argscope import get_arg_scope
 from ..tfutils.modelutils import get_shape_str
 from ..tfutils.summary import add_activation_summary
-from ..utils import logger
+from ..utils import logger, building_rtfd
 from ..utils.argtools import shape2d

 # make sure each layer is only logged once
@@ -100,9 +99,7 @@ def layer_register(
         return wrapped_func

     # need some special handling for sphinx to work with the arguments
-    on_doc = os.environ.get('READTHEDOCS') == 'True' \
-        or os.environ.get('TENSORPACK_DOC_BUILDING')
-    if on_doc:
+    if building_rtfd():
         from decorator import decorator
         wrapper = decorator(wrapper)
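For context on the `decorator` branch above: wrappers built by the `decorator` package carry the wrapped function's real argument list, which is what Sphinx introspects when rendering API docs. A minimal sketch, not tensorpack code; `logged` and `conv` are made-up names:

```python
# Hedged illustration of why docs builds swap in `decorator`: the generated
# wrapper keeps the original signature instead of (*args, **kwargs).
import inspect
from decorator import decorator

@decorator
def logged(func, *args, **kwargs):
    # a "caller" in the decorator package's convention: gets the original
    # function plus the call arguments
    print('calling %s' % func.__name__)
    return func(*args, **kwargs)

@logged
def conv(x, out_channel, kernel_shape=3):
    return x

print(inspect.signature(conv))  # (x, out_channel, kernel_shape=3)
```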
@@ -9,7 +9,7 @@ from tensorflow.python.training import moving_averages
 from ..tfutils.common import get_tf_version
 from ..tfutils.tower import get_current_tower_context
-from ..utils import logger
+from ..utils import logger, building_rtfd
 from ._common import layer_register

 __all__ = ['BatchNorm', 'BatchNormV1', 'BatchNormV2']
@@ -99,8 +99,25 @@ def BatchNormV1(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
 @layer_register(log_shape=False)
 def BatchNormV2(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
     """
-    A slightly faster but equivalent version of BatchNormV1, which uses
-    ``fused_batch_norm`` in training.
+    Batch normalization layer, as described in the paper:
+    `Batch Normalization: Accelerating Deep Network Training by
+    Reducing Internal Covariate Shift <http://arxiv.org/abs/1502.03167>`_.
+
+    Args:
+        x (tf.Tensor): a NHWC or NC tensor.
+        use_local_stat (bool): whether to use mean/var of the current batch or the moving average.
+            Defaults to True in training and False in inference.
+        decay (float): decay rate of moving average.
+        epsilon (float): epsilon to avoid divide-by-zero.
+
+    Note:
+        * In multi-tower training, only the first training tower maintains a moving average.
+        * It automatically selects :meth:`BatchNormV1` or :meth:`BatchNormV2`
+          according to availability.
+        * This is a slightly faster but equivalent version of BatchNormV1. It uses
+          ``fused_batch_norm`` in training.
     """
     shape = x.get_shape().as_list()
     assert len(shape) in [2, 4]
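A hedged usage sketch matching the new docstring's Args; the name-first calling convention is tensorpack's registered-layer style, and this would normally run inside a model's graph-building function, where a tower context sets the training/inference mode:

```python
# Sketch only: assumes a TF1-era graph and tensorpack's layer registry.
import tensorflow as tf
from tensorpack.models import BatchNorm

x = tf.placeholder(tf.float32, [None, 32, 32, 64])  # NHWC, per the docstring
y = BatchNorm('bn', x, decay=0.9, epsilon=1e-5)     # moving-average stats at inference
```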
@@ -160,27 +177,8 @@ def BatchNormV2(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
     return tf.identity(xn, name='output')


-def BatchNorm(*args, **kwargs):
-    """
-    Batch normalization layer, as described in the paper:
-    `Batch Normalization: Accelerating Deep Network Training by
-    Reducing Internal Covariate Shift <http://arxiv.org/abs/1502.03167>`_.
-
-    Args:
-        x (tf.Tensor): a NHWC or NC tensor.
-        use_local_stat (bool): whether to use mean/var of the current batch or the moving average.
-            Defaults to True in training and False in inference.
-        decay (float): decay rate of moving average.
-        epsilon (float): epsilon to avoid divide-by-zero.
-
-    Note:
-        * In multi-tower training, only the first training tower maintains a moving average.
-        * It automatically selects :meth:`BatchNormV1` or :meth:`BatchNormV2`
-          according to availability.
-    """
-    if get_tf_version() >= 12:
-        return BatchNormV2(*args, **kwargs)
-    else:
-        logger.warn("BatchNorm might be faster if you update TensorFlow")
-        return BatchNormV1(*args, **kwargs)
+if building_rtfd() or get_tf_version() >= 12:
+    BatchNorm = BatchNormV2
+else:
+    logger.warn("BatchNorm might be faster if you update TensorFlow")
+    BatchNorm = BatchNormV1
@@ -15,7 +15,8 @@ __all__ = ['change_env',
 'get_dataset_path',
 'get_tqdm_kwargs',
 'get_tqdm',
-'execute_only_once'
+'execute_only_once',
+'building_rtfd'
 ]
@@ -85,3 +86,8 @@ def get_tqdm_kwargs(**kwargs):
 def get_tqdm(**kwargs):
     return tqdm(**get_tqdm_kwargs(**kwargs))
+
+
+def building_rtfd():
+    return os.environ.get('READTHEDOCS') == 'True' \
+        or os.environ.get('TENSORPACK_DOC_BUILDING')
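A small sketch of how the new helper is meant to be used; the import path is inferred from the diff's `from ..utils import logger, building_rtfd`, and setting the variable from a docs build script is an assumption:

```python
# Hedged example: TENSORPACK_DOC_BUILDING is the manual override the helper
# reads; READTHEDOCS is set automatically in Read the Docs build environments.
import os
os.environ['TENSORPACK_DOC_BUILDING'] = '1'   # e.g. exported before a docs build

from tensorpack.utils import building_rtfd    # assumed public import path

assert building_rtfd()  # TF-version-dependent branches can now take a doc-safe path
```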