Commit c713cb75 authored by Yuxin Wu's avatar Yuxin Wu

add unpooling

parent c653458c
......@@ -2,11 +2,13 @@
# -*- coding: UTF-8 -*-
# File: pool.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf
import numpy
from ._common import *
import tensorflow as tf
from ..utils.symbolic_functions import *
__all__ = ['MaxPooling']
__all__ = ['MaxPooling', 'FixedUnPooling']
@layer_register()
def MaxPooling(x, shape, stride=None, padding='VALID'):
......@@ -24,3 +26,33 @@ def MaxPooling(x, shape, stride=None, padding='VALID'):
return tf.nn.max_pool(x, ksize=shape, strides=stride, padding=padding)
@layer_register()
def FixedUnPooling(x, shape, unpool_mat=None):
"""
Unpool the input with a fixed mat to perform kronecker product with
x: 4D tensor of (b, h, w, c)
shape: int or list/tuple of length 2
unpool_mat: a tf matrix with size=shape. if None, will use a zero-mat with a 1 as first element
"""
shape = shape2d(shape)
input_shape = x.get_shape().as_list()
assert len(input_shape) == 4
if unpool_mat is None:
mat = np.zeros(shape)
mat[0][0] = 1
unpool_mat = tf.Variable(mat, trainable=False, name='unpool_mat')
assert unpool_mat.get_shape().as_list() == list(shape)
fx = flatten(tf.transpose(x, [0, 3, 1, 2]))
fx = tf.expand_dims(fx, -1) # (bchw)x1
mat = tf.expand_dims(flatten(unpool_mat), 0) #1x(shxsw)
prod = tf.matmul(fx, mat) #(bchw) x(shxsw)
prod = tf.reshape(prod, [-1, input_shape[3],
input_shape[1], input_shape[2],
shape[0], shape[1]])
prod = tf.transpose(prod, [0, 2, 4, 3, 5, 1])
prod = tf.reshape(prod, [-1, input_shape[1] * shape[0],
input_shape[2] * shape[1],
input_shape[3]])
return prod
......@@ -5,7 +5,6 @@
import tensorflow as tf
import numpy as np
__all__ = ['one_hot', 'batch_flatten', 'logSoftmax']
def one_hot(y, num_labels):
with tf.op_scope([y, num_labels], 'one_hot'):
......@@ -18,6 +17,10 @@ def one_hot(y, num_labels):
onehot_labels.set_shape([None, num_labels])
return tf.cast(onehot_labels, tf.float32)
def flatten(x):
total_dim = np.prod(x.get_shape().as_list())
return tf.reshape(x, [total_dim])
def batch_flatten(x):
total_dim = np.prod(x.get_shape()[1:].as_list())
return tf.reshape(x, [-1, total_dim])
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment