Commit 4888d1ea authored by Yuxin Wu

reorganize some conv2d code; bump version

parent 3d1a30ff
@@ -47,10 +47,12 @@ variables in the current graph and variables in the `session_init` initializer.
 Variables that appear in only one side will be printed as warning.
 
 ## Transfer Learning
 
 Therefore, transfer learning is trivial.
 If you want to load a pre-trained model, just use the same variable names.
-If you want to re-train some layer, just rename either the variables in the
-graph or the variables in your loader.
+If you want to re-train some layer, either rename the variables in the
+graph or rename/remove the variables in your loader.
 
 ## Resume Training
...
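A minimal sketch of the "rename/remove the variables in your loader" option this doc change describes, assuming an npz checkpoint and tensorpack's dict-based `DictRestore` initializer; the checkpoint path and layer names are hypothetical:

# Minimal sketch (not part of this commit): re-train layer 'fc' by removing
# its weights from the loader, so the layer falls back to fresh initialization.
# 'pretrained.npz' and the variable names 'fc/W', 'fc/b' are hypothetical.
import numpy as np
from tensorpack import DictRestore

weights = dict(np.load('pretrained.npz'))
for name in ('fc/W', 'fc/b'):         # variables of the layer to re-train
    weights.pop(name, None)           # absent variables keep their initializer
session_init = DictRestore(weights)   # pass as TrainConfig(session_init=...)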
@@ -61,5 +61,5 @@ except ImportError:
 
 # These lines will be programatically read/write by setup.py
 # Don't touch them.
-__version__ = '0.9.1'
+__version__ = '0.9.2'
 __git_version__ = __version__
@@ -91,7 +91,7 @@ def Conv2D(
         assert in_channel % split == 0
 
         assert kernel_regularizer is None and bias_regularizer is None and activity_regularizer is None, \
-            "Not supported by group conv now!"
+            "Not supported by group conv or dilated conv!"
 
         out_channel = filters
         assert out_channel % split == 0
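What the two divisibility asserts guard, shown as a hypothetical call to the registered Conv2D layer (shapes invented): with split=4, both the input's channel count and `filters` must be multiples of 4, and no regularizer may be passed on this code path.

# Hypothetical usage: a 4-group convolution on a 64-channel input.
# 64 and filters=128 are both divisible by split=4, so the asserts pass.
l = Conv2D('conv2', l, filters=128, kernel_size=3, split=4)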
@@ -111,25 +111,28 @@ def Conv2D(
         if use_bias:
             b = tf.get_variable('b', [out_channel], initializer=bias_initializer)
 
-        conv = None
-        if get_tf_version_tuple() >= (1, 13):
-            try:
-                conv = tf.nn.conv2d(inputs, W, stride, padding.upper(), **kwargs)
-            except ValueError:
-                conv = None
-                log_once("CUDNN group convolution support is only available with "
-                         "https://github.com/tensorflow/tensorflow/pull/25818 . "
-                         "Will fall back to a loop-based slow implementation instead!", 'warn')
-        if conv is None:
-            inputs = tf.split(inputs, split, channel_axis)
-            kernels = tf.split(W, split, 3)
-            outputs = [tf.nn.conv2d(i, k, stride, padding.upper(), **kwargs)
-                       for i, k in zip(inputs, kernels)]
-            conv = tf.concat(outputs, channel_axis)
-
-        if activation is None:
-            activation = tf.identity
-        ret = activation(tf.nn.bias_add(conv, b, data_format=data_format) if use_bias else conv, name='output')
+        if split == 1:
+            conv = tf.nn.conv2d(inputs, W, stride, padding.upper(), **kwargs)
+        else:
+            conv = None
+            if get_tf_version_tuple() >= (1, 13):
+                try:
+                    conv = tf.nn.conv2d(inputs, W, stride, padding.upper(), **kwargs)
+                except ValueError:
+                    log_once("CUDNN group convolution support is only available with "
+                             "https://github.com/tensorflow/tensorflow/pull/25818 . "
+                             "Will fall back to a loop-based slow implementation instead!", 'warn')
+            if conv is None:
+                inputs = tf.split(inputs, split, channel_axis)
+                kernels = tf.split(W, split, 3)
+                outputs = [tf.nn.conv2d(i, k, stride, padding.upper(), **kwargs)
+                           for i, k in zip(inputs, kernels)]
+                conv = tf.concat(outputs, channel_axis)
+
+        ret = tf.nn.bias_add(conv, b, data_format=data_format) if use_bias else conv
+        if activation is not None:
+            ret = activation(ret)
+        ret = tf.identity(ret, name='output')
 
     ret.variables = VariableHolder(W=W)
     if use_bias:
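A self-contained sketch of the loop-based fallback the reorganized branch keeps for grouped convolution, assuming TF 1.x graph mode and NHWC with invented shapes. The native path it falls back from is plain tf.nn.conv2d with a per-group kernel, which generally needs a GPU with cuDNN 7 (per the PR linked in the warning), hence the ValueError handler:

# Standalone sketch (assumptions: TF 1.x graph mode, NHWC, invented shapes).
# Grouped conv via the fallback: split input and kernel along channels,
# convolve each group, then concatenate the per-group outputs.
import numpy as np
import tensorflow as tf

split = 2
x = tf.constant(np.random.rand(1, 8, 8, 4).astype('float32'))            # 4 in-channels
W = tf.constant(np.random.rand(3, 3, 4 // split, 6).astype('float32'))   # 6 out-channels

xs = tf.split(x, split, axis=3)        # two [1, 8, 8, 2] inputs
ks = tf.split(W, split, axis=3)        # two [3, 3, 2, 3] kernels
fallback = tf.concat(
    [tf.nn.conv2d(i, k, [1, 1, 1, 1], 'SAME') for i, k in zip(xs, ks)],
    axis=3)                            # [1, 8, 8, 6]

# The native path is simply tf.nn.conv2d(x, W, ...) with the full-size x and
# the per-group W; with TF >= 1.13 and cuDNN it computes the same result.
with tf.Session() as sess:
    print(sess.run(fallback).shape)    # (1, 8, 8, 6)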
@@ -236,7 +239,11 @@ def Conv2DTranspose(
             padding=padding.upper(),
             data_format=data_format)
         conv.set_shape(tf.TensorShape([None] + out_shape3_sta))
-        ret = activation(tf.nn.bias_add(conv, b, data_format=data_format) if use_bias else conv, name='output')
+
+        ret = tf.nn.bias_add(conv, b, data_format=data_format) if use_bias else conv
+        if activation is not None:
+            ret = activation(ret)
+        ret = tf.identity(ret, name='output')
 
     ret.variables = VariableHolder(W=W)
     if use_bias:
...
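One consequence of the new ordering in both Conv2D and Conv2DTranspose (my reading, not stated in the commit message): `activation` no longer has to accept a `name` keyword, because the output tensor is now named by a separate `tf.identity`. A toy sketch with an invented helper name:

# Toy sketch (the helper name _finalize is invented) of the new output
# pattern used above: bias first, optional activation, then tf.identity
# so the result is always named 'output'.
import tensorflow as tf

def _finalize(conv, activation=None):
    ret = conv
    if activation is not None:      # any callable works now, even a lambda
        ret = activation(ret)
    return tf.identity(ret, name='output')

x = tf.zeros([1, 4, 4, 3])
y = _finalize(x, activation=lambda t: tf.nn.leaky_relu(t, alpha=0.1))
# The old pattern activation(x, name='output') would crash here, since the
# lambda takes no `name` argument; y.name still ends in 'output:0'.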