Commit 90a14aa4 authored by Yuxin Wu's avatar Yuxin Wu

faster unpooling for 2x2 zero filled

parent 9aa390c7
...@@ -29,6 +29,8 @@ def run_test_case(case): ...@@ -29,6 +29,8 @@ def run_test_case(case):
if __name__ == '__main__': if __name__ == '__main__':
import tensorpack import tensorpack
from tensorpack.utils import logger
logger.disable_logger()
subs = tensorpack.models._test.TestModel.__subclasses__() subs = tensorpack.models._test.TestModel.__subclasses__()
for cls in subs: for cls in subs:
run_test_case(cls) run_test_case(cls)
......
...@@ -63,6 +63,19 @@ def GlobalAvgPooling(x): ...@@ -63,6 +63,19 @@ def GlobalAvgPooling(x):
assert x.get_shape().ndims == 4 assert x.get_shape().ndims == 4
return tf.reduce_mean(x, [1, 2]) return tf.reduce_mean(x, [1, 2])
# https://github.com/tensorflow/tensorflow/issues/2169
def UnPooling2x2ZeroFilled(x):
out = tf.concat(3, [x, tf.zeros_like(x)])
out = tf.concat(2, [out, tf.zeros_like(out)])
sh = x.get_shape().as_list()
if None not in sh[1:]:
out_size = [-1, sh[1] * 2, sh[2] * 2, sh[3]]
return tf.reshape(out, out_size)
else:
sh = tf.shape(x)
return tf.reshape(out, [-1, sh[1] * 2, sh[2] * 2, sh[3]])
@layer_register() @layer_register()
def FixedUnPooling(x, shape, unpool_mat=None): def FixedUnPooling(x, shape, unpool_mat=None):
""" """
...@@ -75,6 +88,11 @@ def FixedUnPooling(x, shape, unpool_mat=None): ...@@ -75,6 +88,11 @@ def FixedUnPooling(x, shape, unpool_mat=None):
:returns: NHWC tensor :returns: NHWC tensor
""" """
shape = shape2d(shape) shape = shape2d(shape)
# a faster implementation for this special case
if shape[0] == 2 and shape[1] == 2 and unpool_mat is None:
return UnPooling2x2ZeroFilled(x)
input_shape = tf.shape(x) input_shape = tf.shape(x)
if unpool_mat is None: if unpool_mat is None:
mat = np.zeros(shape, dtype='float32') mat = np.zeros(shape, dtype='float32')
...@@ -136,18 +154,18 @@ from ._test import TestModel ...@@ -136,18 +154,18 @@ from ._test import TestModel
class TestPool(TestModel): class TestPool(TestModel):
def test_fixed_unpooling(self): def test_fixed_unpooling(self):
h, w = 3, 4 h, w = 3, 4
mat = np.random.rand(h, w).astype('float32') mat = np.random.rand(h, w, 3).astype('float32')
inp = self.make_variable(mat) inp = self.make_variable(mat)
inp = tf.reshape(inp, [1, h, w, 1]) inp = tf.reshape(inp, [1, h, w, 3])
output = FixedUnPooling('unpool', inp, 2) output = FixedUnPooling('unpool', inp, 2)
res = self.run_variable(output) res = self.run_variable(output)
self.assertEqual(res.shape, (1, 2*h, 2*w, 1)) self.assertEqual(res.shape, (1, 2*h, 2*w, 3))
# mat is on cornser # mat is on cornser
ele = res[0,::2,::2,0] ele = res[0,::2,::2,0]
self.assertTrue((ele == mat).all()) self.assertTrue((ele == mat[:,:,0]).all())
# the rest are zeros # the rest are zeros
res[0,::2,::2,0] = 0 res[0,::2,::2,:] = 0
self.assertTrue((res == 0).all()) self.assertTrue((res == 0).all())
def test_upsample(self): def test_upsample(self):
...@@ -166,6 +184,10 @@ class TestPool(TestModel): ...@@ -166,6 +184,10 @@ class TestPool(TestModel):
diff = np.abs(res2 - res[0,:,:,0]) diff = np.abs(res2 - res[0,:,:,0])
# not equivalent at corner # not equivalent to rescale on edge
diff[0,0] = diff[-1,-1] = 0 diff[0,:] = 0
diff[:,0] = 0
if not diff.max() < 1e-4:
import IPython;
IPython.embed(config=IPython.terminal.ipapp.load_default_config())
self.assertTrue(diff.max() < 1e-4) self.assertTrue(diff.max() < 1e-4)
...@@ -95,6 +95,9 @@ unless you're resuming from a previous task.""".format(dirname)) ...@@ -95,6 +95,9 @@ unless you're resuming from a previous task.""".format(dirname))
LOG_FILE = os.path.join(dirname, 'log.log') LOG_FILE = os.path.join(dirname, 'log.log')
_set_file(LOG_FILE) _set_file(LOG_FILE)
def disable_logger():
for func in ['info', 'warning', 'error', 'critical', 'warn', 'exception', 'debug']:
globals()[func] = lambda x: None
# export logger functions # export logger functions
for func in ['info', 'warning', 'error', 'critical', 'warn', 'exception', 'debug']: for func in ['info', 'warning', 'error', 'critical', 'warn', 'exception', 'debug']:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment