Commit 88907b83 authored by Yuxin Wu's avatar Yuxin Wu

simplify FixedUnpooling with tensordot, add NCHW support (but not for the...

simplify FixedUnpooling with tensordot, add NCHW support (but not for the special 2x2 case) (fix #314)
parent bd5e0591
......@@ -98,23 +98,23 @@ def UnPooling2x2ZeroFilled(x):
@layer_register()
def FixedUnPooling(x, shape, unpool_mat=None):
def FixedUnPooling(x, shape, unpool_mat=None, data_format='NHWC'):
"""
Unpool the input with a fixed matrix to perform kronecker product with.
Args:
x (tf.Tensor): a NHWC tensor
x (tf.Tensor): a 4D image tensor
shape: int or (h, w) tuple
unpool_mat: a tf.Tensor or np.ndarray 2D matrix with size=shape.
If is None, will use a matrix with 1 at top-left corner.
Returns:
tf.Tensor: a NHWC tensor.
tf.Tensor: a 4D image tensor.
"""
shape = shape2d(shape)
# a faster implementation for this special case
if shape[0] == 2 and shape[1] == 2 and unpool_mat is None:
if shape[0] == 2 and shape[1] == 2 and unpool_mat is None and data_format == 'NHWC':
return UnPooling2x2ZeroFilled(x)
input_shape = tf.shape(x)
......@@ -126,16 +126,21 @@ def FixedUnPooling(x, shape, unpool_mat=None):
unpool_mat = tf.constant(unpool_mat, name='unpool_mat')
assert unpool_mat.get_shape().as_list() == list(shape)
if data_format == 'NHWC':
x = tf.transpose(x, [0, 3, 1, 2])
# perform a tensor-matrix kronecker product
fx = tf.reshape(tf.transpose(x, [0, 3, 1, 2]), [-1])
fx = tf.expand_dims(fx, -1) # (bchw)x1
mat = tf.expand_dims(tf.reshape(unpool_mat, [-1]), 0) # 1x(shxsw)
prod = tf.matmul(fx, mat) # (bchw) x(shxsw)
prod = tf.reshape(prod, tf.stack(
[-1, input_shape[3], input_shape[1], input_shape[2], shape[0], shape[1]]))
x = tf.expand_dims(x, -1) # bchwx1
mat = tf.expand_dims(unpool_mat, 0) # 1xshxsw
prod = tf.tensordot(x, mat, axes=1) # bxcxhxwxshxsw
if data_format == 'NHWC':
prod = tf.transpose(prod, [0, 2, 4, 3, 5, 1])
prod = tf.reshape(prod, tf.stack(
[-1, input_shape[1] * shape[0], input_shape[2] * shape[1], input_shape[3]]))
else:
prod = tf.transpose(prod, [0, 1, 2, 4, 3, 5])
prod = tf.reshape(prod, tf.stack(
[-1, input_shape[3], input_shape[1] * shape[0], input_shape[2] * shape[1]]))
# TODO static shape inference
return prod
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment