Commit 90a14aa4 authored by Yuxin Wu's avatar Yuxin Wu

faster unpooling for 2x2 zero filled

parent 9aa390c7
......@@ -29,6 +29,8 @@ def run_test_case(case):
if __name__ == '__main__':
import tensorpack
from tensorpack.utils import logger
logger.disable_logger()
subs = tensorpack.models._test.TestModel.__subclasses__()
for cls in subs:
run_test_case(cls)
......
......@@ -63,6 +63,19 @@ def GlobalAvgPooling(x):
assert x.get_shape().ndims == 4
return tf.reduce_mean(x, [1, 2])
# https://github.com/tensorflow/tensorflow/issues/2169
def UnPooling2x2ZeroFilled(x):
out = tf.concat(3, [x, tf.zeros_like(x)])
out = tf.concat(2, [out, tf.zeros_like(out)])
sh = x.get_shape().as_list()
if None not in sh[1:]:
out_size = [-1, sh[1] * 2, sh[2] * 2, sh[3]]
return tf.reshape(out, out_size)
else:
sh = tf.shape(x)
return tf.reshape(out, [-1, sh[1] * 2, sh[2] * 2, sh[3]])
@layer_register()
def FixedUnPooling(x, shape, unpool_mat=None):
"""
......@@ -75,6 +88,11 @@ def FixedUnPooling(x, shape, unpool_mat=None):
:returns: NHWC tensor
"""
shape = shape2d(shape)
# a faster implementation for this special case
if shape[0] == 2 and shape[1] == 2 and unpool_mat is None:
return UnPooling2x2ZeroFilled(x)
input_shape = tf.shape(x)
if unpool_mat is None:
mat = np.zeros(shape, dtype='float32')
......@@ -136,18 +154,18 @@ from ._test import TestModel
class TestPool(TestModel):
def test_fixed_unpooling(self):
h, w = 3, 4
mat = np.random.rand(h, w).astype('float32')
mat = np.random.rand(h, w, 3).astype('float32')
inp = self.make_variable(mat)
inp = tf.reshape(inp, [1, h, w, 1])
inp = tf.reshape(inp, [1, h, w, 3])
output = FixedUnPooling('unpool', inp, 2)
res = self.run_variable(output)
self.assertEqual(res.shape, (1, 2*h, 2*w, 1))
self.assertEqual(res.shape, (1, 2*h, 2*w, 3))
# mat is on cornser
ele = res[0,::2,::2,0]
self.assertTrue((ele == mat).all())
self.assertTrue((ele == mat[:,:,0]).all())
# the rest are zeros
res[0,::2,::2,0] = 0
res[0,::2,::2,:] = 0
self.assertTrue((res == 0).all())
def test_upsample(self):
......@@ -166,6 +184,10 @@ class TestPool(TestModel):
diff = np.abs(res2 - res[0,:,:,0])
# not equivalent at corner
diff[0,0] = diff[-1,-1] = 0
# not equivalent to rescale on edge
diff[0,:] = 0
diff[:,0] = 0
if not diff.max() < 1e-4:
import IPython;
IPython.embed(config=IPython.terminal.ipapp.load_default_config())
self.assertTrue(diff.max() < 1e-4)
......@@ -95,6 +95,9 @@ unless you're resuming from a previous task.""".format(dirname))
LOG_FILE = os.path.join(dirname, 'log.log')
_set_file(LOG_FILE)
def disable_logger():
for func in ['info', 'warning', 'error', 'critical', 'warn', 'exception', 'debug']:
globals()[func] = lambda x: None
# export logger functions
for func in ['info', 'warning', 'error', 'critical', 'warn', 'exception', 'debug']:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment