Commit 5fc3a2b1 authored by Yuxin Wu's avatar Yuxin Wu

use tf.layers for pooling impl (#291)

parent 377fba1f
......@@ -140,8 +140,10 @@ class Monitors(Callback):
"""
Put a scalar.
"""
if isinstance(val, (np.float32, np.float64)):
if isinstance(val, np.floating):
val = float(val)
if isinstance(val, np.integer):
val = int(val)
self._dispatch(lambda m: m.process_scalar(name, val))
s = create_scalar_summary(name, val)
self._dispatch(lambda m: m.process_summary(s))
......
......@@ -7,7 +7,7 @@ import numpy as np
from .shape_utils import StaticDynamicShape
from .common import layer_register
from ..utils.argtools import shape2d, shape4d
from ..utils.argtools import shape2d
from ._test import TestModel
......@@ -15,20 +15,6 @@ __all__ = ['MaxPooling', 'FixedUnPooling', 'AvgPooling', 'GlobalAvgPooling',
'BilinearUpSample']
def _Pooling(func, x, shape, stride, padding, data_format):
    """Invoke a raw TF pooling op ``func`` with normalized 2D arguments.

    The window ``shape`` (and ``stride``, when given) is expanded to a
    4-element spec via ``shape4d`` for the given ``data_format``; when
    ``stride`` is None it defaults to the window shape. ``padding`` is
    upper-cased before being passed through. The result tensor is named
    ``output``.
    """
    kernel = shape4d(shape, data_format=data_format)
    strides = kernel if stride is None else shape4d(stride, data_format=data_format)
    return func(
        x,
        ksize=kernel,
        strides=strides,
        padding=padding.upper(),
        data_format=data_format,
        name='output')
@layer_register(log_shape=True)
def MaxPooling(x, shape, stride=None, padding='VALID', data_format='NHWC'):
"""
......@@ -43,8 +29,9 @@ def MaxPooling(x, shape, stride=None, padding='VALID', data_format='NHWC'):
Returns:
tf.Tensor named ``output``.
"""
return _Pooling(tf.nn.max_pool, x, shape, stride, padding,
data_format=data_format)
ret = tf.layers.max_pooling2d(x, shape, stride, padding,
'channels_last' if data_format == 'NHWC' else 'channels_first')
return tf.identity(ret, name='output')
@layer_register(log_shape=True)
......@@ -61,8 +48,9 @@ def AvgPooling(x, shape, stride=None, padding='VALID', data_format='NHWC'):
Returns:
tf.Tensor named ``output``.
"""
return _Pooling(tf.nn.avg_pool, x, shape, stride, padding,
data_format=data_format)
ret = tf.layers.average_pooling2d(x, shape, stride, padding,
'channels_last' if data_format == 'NHWC' else 'channels_first')
return tf.identity(ret, name='output')
@layer_register(log_shape=True)
......
......@@ -73,7 +73,7 @@ class SessionUpdate(object):
name, val.shape, varshape))
val = val.reshape(varshape)
# fix some common type incompatibility problem, but is certainly not enough
# fix some common type incompatibility problems, but not all
def upcast(vartype, valtype):
# allow up-casting
if vartype == tf.float64 and valtype == np.float32:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment