Commit 5fc3a2b1 authored by Yuxin Wu's avatar Yuxin Wu

use tf.layers for pooling impl (#291)

parent 377fba1f
...@@ -140,8 +140,10 @@ class Monitors(Callback): ...@@ -140,8 +140,10 @@ class Monitors(Callback):
""" """
Put a scalar. Put a scalar.
""" """
if isinstance(val, (np.float32, np.float64)): if isinstance(val, np.floating):
val = float(val) val = float(val)
if isinstance(val, np.integer):
val = int(val)
self._dispatch(lambda m: m.process_scalar(name, val)) self._dispatch(lambda m: m.process_scalar(name, val))
s = create_scalar_summary(name, val) s = create_scalar_summary(name, val)
self._dispatch(lambda m: m.process_summary(s)) self._dispatch(lambda m: m.process_summary(s))
......
...@@ -7,7 +7,7 @@ import numpy as np ...@@ -7,7 +7,7 @@ import numpy as np
from .shape_utils import StaticDynamicShape from .shape_utils import StaticDynamicShape
from .common import layer_register from .common import layer_register
from ..utils.argtools import shape2d, shape4d from ..utils.argtools import shape2d
from ._test import TestModel from ._test import TestModel
...@@ -15,20 +15,6 @@ __all__ = ['MaxPooling', 'FixedUnPooling', 'AvgPooling', 'GlobalAvgPooling', ...@@ -15,20 +15,6 @@ __all__ = ['MaxPooling', 'FixedUnPooling', 'AvgPooling', 'GlobalAvgPooling',
'BilinearUpSample'] 'BilinearUpSample']
def _Pooling(func, x, shape, stride, padding, data_format):
padding = padding.upper()
shape = shape4d(shape, data_format=data_format)
if stride is None:
stride = shape
else:
stride = shape4d(stride, data_format=data_format)
return func(x, ksize=shape,
strides=stride, padding=padding,
data_format=data_format,
name='output')
@layer_register(log_shape=True) @layer_register(log_shape=True)
def MaxPooling(x, shape, stride=None, padding='VALID', data_format='NHWC'): def MaxPooling(x, shape, stride=None, padding='VALID', data_format='NHWC'):
""" """
...@@ -43,8 +29,9 @@ def MaxPooling(x, shape, stride=None, padding='VALID', data_format='NHWC'): ...@@ -43,8 +29,9 @@ def MaxPooling(x, shape, stride=None, padding='VALID', data_format='NHWC'):
Returns: Returns:
tf.Tensor named ``output``. tf.Tensor named ``output``.
""" """
return _Pooling(tf.nn.max_pool, x, shape, stride, padding, ret = tf.layers.max_pooling2d(x, shape, stride, padding,
data_format=data_format) 'channels_last' if data_format == 'NHWC' else 'channels_first')
return tf.identity(ret, name='output')
@layer_register(log_shape=True) @layer_register(log_shape=True)
...@@ -61,8 +48,9 @@ def AvgPooling(x, shape, stride=None, padding='VALID', data_format='NHWC'): ...@@ -61,8 +48,9 @@ def AvgPooling(x, shape, stride=None, padding='VALID', data_format='NHWC'):
Returns: Returns:
tf.Tensor named ``output``. tf.Tensor named ``output``.
""" """
return _Pooling(tf.nn.avg_pool, x, shape, stride, padding, ret = tf.layers.average_pooling2d(x, shape, stride, padding,
data_format=data_format) 'channels_last' if data_format == 'NHWC' else 'channels_first')
return tf.identity(ret, name='output')
@layer_register(log_shape=True) @layer_register(log_shape=True)
......
...@@ -73,7 +73,7 @@ class SessionUpdate(object): ...@@ -73,7 +73,7 @@ class SessionUpdate(object):
name, val.shape, varshape)) name, val.shape, varshape))
val = val.reshape(varshape) val = val.reshape(varshape)
# fix some common type incompatibility problem, but is certainly not enough # fix some common type incompatibility problems, but not all
def upcast(vartype, valtype): def upcast(vartype, valtype):
# allow up-casting # allow up-casting
if vartype == tf.float64 and valtype == np.float32: if vartype == tf.float64 and valtype == np.float32:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment