Commit c14d1496, authored by Aaron Gokaslan, committed by Yuxin Wu

Add dilation_rate to Conv2D (#625)

Added dilation_rate to Conv2D, since TensorFlow's conv2d supports dilations directly. Requires TF 1.5 or greater.
parent 3c62a25c
......@@ -5,6 +5,7 @@
import tensorflow as tf
from .common import layer_register, VariableHolder, rename_get_variable
from ..tfutils.common import get_tf_version_number
from ..utils.argtools import shape2d, shape4d
__all__ = ['Conv2D', 'Deconv2D']
......@@ -15,7 +16,7 @@ def Conv2D(x, out_channel, kernel_shape,
padding='SAME', stride=1,
W_init=None, b_init=None,
nl=tf.identity, split=1, use_bias=True,
data_format='NHWC'):
data_format='NHWC', dilation_rate=1):
"""
2D convolution on 4D inputs.
......@@ -31,6 +32,8 @@ def Conv2D(x, out_channel, kernel_shape,
b_init: initializer for b. Defaults to zero.
nl: a nonlinearity function.
use_bias (bool): whether to use bias.
data_format (str): 'NHWC' or 'NCHW'.
dilation_rate: (h, w) tuple or an int.
Returns:
tf.Tensor named ``output`` with attribute `variables`.
......@@ -46,12 +49,17 @@ def Conv2D(x, out_channel, kernel_shape,
assert in_channel is not None, "[Conv2D] Input cannot have unknown channel!"
assert in_channel % split == 0
assert out_channel % split == 0
assert dilation_rate == 1 or get_tf_version_number() >= 1.5, 'TF ver. 1.5 or greater required for dilations'
kernel_shape = shape2d(kernel_shape)
padding = padding.upper()
filter_shape = kernel_shape + [in_channel / split, out_channel]
stride = shape4d(stride, data_format=data_format)
kw_args = dict(data_format=data_format)
if get_tf_version_number() >= 1.5:
kw_args['dilations'] = shape4d(dilation_rate, data_format=data_format)
if W_init is None:
W_init = tf.variance_scaling_initializer(scale=2.0)
if b_init is None:
......@@ -63,11 +71,11 @@ def Conv2D(x, out_channel, kernel_shape,
b = tf.get_variable('b', [out_channel], initializer=b_init)
if split == 1:
conv = tf.nn.conv2d(x, W, stride, padding, data_format=data_format)
conv = tf.nn.conv2d(x, W, stride, padding, **kw_args)
else:
inputs = tf.split(x, split, channel_axis)
kernels = tf.split(W, split, 3)
outputs = [tf.nn.conv2d(i, k, stride, padding, data_format=data_format)
outputs = [tf.nn.conv2d(i, k, stride, padding, **kw_args)
for i, k in zip(inputs, kernels)]
conv = tf.concat(outputs, channel_axis)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment