Commit 8e167804 authored by Yuxin Wu's avatar Yuxin Wu

Better static shape inference in FixedUnPooling

parent 88907b83
......@@ -6,6 +6,7 @@
import tensorflow as tf
from .common import layer_register, VariableHolder
from ..utils.argtools import shape2d, shape4d
from .shape_utils import StaticDynamicAxis
__all__ = ['Conv2D', 'Deconv2D']
......@@ -78,19 +79,6 @@ def Conv2D(x, out_channel, kernel_shape,
return ret
class StaticDynamicShape(object):
def __init__(self, static, dynamic):
self.static = static
self.dynamic = dynamic
def apply(self, f):
try:
st = f(self.static)
return StaticDynamicShape(st, st)
except:
return StaticDynamicShape(None, f(self.dynamic))
@layer_register()
def Deconv2D(x, out_shape, kernel_shape,
stride, padding='SAME',
......@@ -134,13 +122,13 @@ def Deconv2D(x, out_shape, kernel_shape,
if isinstance(out_shape, int):
out_channel = out_shape
if data_format == 'NHWC':
shp3_0 = StaticDynamicShape(in_shape[1], in_shape_dyn[1]).apply(lambda x: stride2d[0] * x)
shp3_1 = StaticDynamicShape(in_shape[2], in_shape_dyn[2]).apply(lambda x: stride2d[1] * x)
shp3_0 = StaticDynamicAxis(in_shape[1], in_shape_dyn[1]).apply(lambda x: stride2d[0] * x)
shp3_1 = StaticDynamicAxis(in_shape[2], in_shape_dyn[2]).apply(lambda x: stride2d[1] * x)
shp3_dyn = [shp3_0.dynamic, shp3_1.dynamic, out_channel]
shp3_static = [shp3_0.static, shp3_1.static, out_channel]
else:
shp3_0 = StaticDynamicShape(in_shape[2], in_shape_dyn[2]).apply(lambda x: stride2d[0] * x)
shp3_1 = StaticDynamicShape(in_shape[3], in_shape_dyn[3]).apply(lambda x: stride2d[1] * x)
shp3_0 = StaticDynamicAxis(in_shape[2], in_shape_dyn[2]).apply(lambda x: stride2d[0] * x)
shp3_1 = StaticDynamicAxis(in_shape[3], in_shape_dyn[3]).apply(lambda x: stride2d[1] * x)
shp3_dyn = [out_channel, shp3_0.dynamic, shp3_1.dynamic]
shp3_static = [out_channel, shp3_0.static, shp3_1.static]
else:
......
......@@ -5,6 +5,7 @@
import tensorflow as tf
import numpy as np
from .shape_utils import StaticDynamicShape
from .common import layer_register
from ..utils.argtools import shape2d, shape4d
from ._test import TestModel
......@@ -75,7 +76,7 @@ def GlobalAvgPooling(x, data_format='NHWC'):
Returns:
tf.Tensor: a NC tensor named ``output``.
"""
assert x.get_shape().ndims == 4
assert x.shape.ndims == 4
assert data_format in ['NHWC', 'NCHW']
axis = [1, 2] if data_format == 'NHWC' else [2, 3]
return tf.reduce_mean(x, axis, name='output')
......@@ -93,7 +94,6 @@ def UnPooling2x2ZeroFilled(x):
else:
shv = tf.shape(x)
ret = tf.reshape(out, tf.stack([-1, shv[1] * 2, shv[2] * 2, sh[3]]))
ret.set_shape([None, None, None, sh[3]])
return ret
......@@ -113,35 +113,40 @@ def FixedUnPooling(x, shape, unpool_mat=None, data_format='NHWC'):
"""
shape = shape2d(shape)
output_shape = StaticDynamicShape(x)
output_shape.apply(1 if data_format == 'NHWC' else 2, lambda x: x * shape[0])
output_shape.apply(2 if data_format == 'NHWC' else 3, lambda x: x * shape[1])
# a faster implementation for this special case
if shape[0] == 2 and shape[1] == 2 and unpool_mat is None and data_format == 'NHWC':
return UnPooling2x2ZeroFilled(x)
input_shape = tf.shape(x)
if unpool_mat is None:
mat = np.zeros(shape, dtype='float32')
mat[0][0] = 1
unpool_mat = tf.constant(mat, name='unpool_mat')
elif isinstance(unpool_mat, np.ndarray):
unpool_mat = tf.constant(unpool_mat, name='unpool_mat')
assert unpool_mat.get_shape().as_list() == list(shape)
if data_format == 'NHWC':
x = tf.transpose(x, [0, 3, 1, 2])
# perform a tensor-matrix kronecker product
x = tf.expand_dims(x, -1) # bchwx1
mat = tf.expand_dims(unpool_mat, 0) # 1xshxsw
prod = tf.tensordot(x, mat, axes=1) # bxcxhxwxshxsw
if data_format == 'NHWC':
prod = tf.transpose(prod, [0, 2, 4, 3, 5, 1])
prod = tf.reshape(prod, tf.stack(
[-1, input_shape[1] * shape[0], input_shape[2] * shape[1], input_shape[3]]))
ret = UnPooling2x2ZeroFilled(x)
else:
prod = tf.transpose(prod, [0, 1, 2, 4, 3, 5])
prod = tf.reshape(prod, tf.stack(
[-1, input_shape[3], input_shape[1] * shape[0], input_shape[2] * shape[1]]))
# TODO static shape inference
return prod
# check unpool_mat
if unpool_mat is None:
mat = np.zeros(shape, dtype='float32')
mat[0][0] = 1
unpool_mat = tf.constant(mat, name='unpool_mat')
elif isinstance(unpool_mat, np.ndarray):
unpool_mat = tf.constant(unpool_mat, name='unpool_mat')
assert unpool_mat.shape.as_list() == list(shape)
if data_format == 'NHWC':
x = tf.transpose(x, [0, 3, 1, 2])
# perform a tensor-matrix kronecker product
x = tf.expand_dims(x, -1) # bchwx1
mat = tf.expand_dims(unpool_mat, 0) # 1xshxsw
ret = tf.tensordot(x, mat, axes=1) # bxcxhxwxshxsw
if data_format == 'NHWC':
ret = tf.transpose(ret, [0, 2, 4, 3, 5, 1])
else:
ret = tf.transpose(ret, [0, 1, 2, 4, 3, 5])
shape3_dyn = [output_shape.get_dynamic(k) for k in range(1, 4)]
ret = tf.reshape(ret, tf.stack([-1] + shape3_dyn))
ret.set_shape(tf.TensorShape(output_shape.get_static()))
return ret
@layer_register()
......@@ -156,7 +161,7 @@ def BilinearUpSample(x, shape):
Returns:
tf.Tensor: a NHWC tensor.
"""
inp_shape = x.get_shape().as_list()
inp_shape = x.shape.as_list()
ch = inp_shape[3]
assert ch is not None
......@@ -199,18 +204,19 @@ def BilinearUpSample(x, shape):
class TestPool(TestModel):
def test_FixedUnPooling(self):
h, w = 3, 4
scale = 2
mat = np.random.rand(h, w, 3).astype('float32')
inp = self.make_variable(mat)
inp = tf.reshape(inp, [1, h, w, 3])
output = FixedUnPooling('unpool', inp, 2)
output = FixedUnPooling('unpool', inp, scale)
res = self.run_variable(output)
self.assertEqual(res.shape, (1, 2 * h, 2 * w, 3))
self.assertEqual(res.shape, (1, scale * h, scale * w, 3))
# mat is on cornser
ele = res[0, ::2, ::2, 0]
ele = res[0, ::scale, ::scale, 0]
self.assertTrue((ele == mat[:, :, 0]).all())
# the rest are zeros
res[0, ::2, ::2, :] = 0
res[0, ::scale, ::scale, :] = 0
self.assertTrue((res == 0).all())
def test_BilinearUpSample(self):
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: shape_utils.py
import tensorflow as tf
__all__ = ['StaticDynamicAxis', 'StaticDynamicShape']
class StaticDynamicAxis(object):
def __init__(self, static, dynamic):
self.static = static
self.dynamic = dynamic
def apply(self, f):
try:
st = f(self.static)
return StaticDynamicAxis(st, st)
except:
return StaticDynamicAxis(None, f(self.dynamic))
def __str__(self):
return "S={}, D={}".format(str(self.static), str(self.dynamic))
def DynamicLazyAxis(shape, idx):
return lambda: shape[idx]
def StaticLazyAxis(dim):
return lambda: dim
class StaticDynamicShape(object):
def __init__(self, tensor):
assert isinstance(tensor, tf.Tensor), tensor
ndims = tensor.shape.ndims
self.static = tensor.shape.as_list()
if tensor.shape.is_fully_defined():
self.dynamic = self.static[:]
else:
dynamic = tf.shape(tensor)
self.dynamic = [DynamicLazyAxis(dynamic, k) for k in range(ndims)]
for k in range(ndims):
if self.static[k] is not None:
self.dynamic[k] = StaticLazyAxis(self.static[k])
def apply(self, axis, f):
if self.static[axis] is not None:
try:
st = f(self.static[axis])
self.static[axis] = st
self.dynamic[axis] = StaticLazyAxis(st)
return
except:
pass
self.static[axis] = None
dyn = self.dynamic[axis]
self.dynamic[axis] = lambda: f(dyn())
def get_static(self):
return self.static
@property
def ndims(self):
return len(self.static)
def get_dynamic(self, axis=None):
if axis is None:
return [self.dynamic[k]() for k in range(self.ndims)]
return self.dynamic[axis]()
if __name__ == '__main__':
x = tf.placeholder(tf.float32, shape=[None, 3, None, 10])
shape = StaticDynamicShape(x)
shape.apply(1, lambda x: x * 3)
shape.apply(2, lambda x: x + 5)
print(shape.get_static())
print(shape.get_dynamic())
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment