Commit e4aca035 authored by Yuxin Wu

Try using cudnn's group conv

parent d8d35fb5
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
import tensorflow as tf import tensorflow as tf
from ..tfutils.common import get_tf_version_tuple from ..tfutils.common import get_tf_version_tuple
from ..utils.argtools import get_data_format, shape2d, shape4d from ..utils.argtools import get_data_format, shape2d, shape4d, log_once
from .common import VariableHolder, layer_register from .common import VariableHolder, layer_register
from .tflayer import convert_to_tflayer_args, rename_get_variable from .tflayer import convert_to_tflayer_args, rename_get_variable
...@@ -108,11 +108,22 @@ def Conv2D( ...@@ -108,11 +108,22 @@ def Conv2D(
if use_bias: if use_bias:
b = tf.get_variable('b', [out_channel], initializer=bias_initializer) b = tf.get_variable('b', [out_channel], initializer=bias_initializer)
conv = None
if get_tf_version_tuple() >= (1, 13):
try:
conv = tf.nn.conv2d(inputs, W, stride, padding.upper(), **kwargs)
except ValueError:
conv = None
log_once("CUDNN group convolution support is only available with "
"https://github.com/tensorflow/tensorflow/pull/25818 . "
"Will fall back to a loop-based slow implementation instead!", 'warn')
if conv is None:
inputs = tf.split(inputs, split, channel_axis) inputs = tf.split(inputs, split, channel_axis)
kernels = tf.split(W, split, 3) kernels = tf.split(W, split, 3)
outputs = [tf.nn.conv2d(i, k, stride, padding.upper(), **kwargs) outputs = [tf.nn.conv2d(i, k, stride, padding.upper(), **kwargs)
for i, k in zip(inputs, kernels)] for i, k in zip(inputs, kernels)]
conv = tf.concat(outputs, channel_axis) conv = tf.concat(outputs, channel_axis)
if activation is None: if activation is None:
activation = tf.identity activation = tf.identity
ret = activation(tf.nn.bias_add(conv, b, data_format=data_format) if use_bias else conv, name='output') ret = activation(tf.nn.bias_add(conv, b, data_format=data_format) if use_bias else conv, name='output')
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: sessinit.py # File: sessinit.py
import os
import numpy as np import numpy as np
import six import six
import tensorflow as tf import tensorflow as tf
...@@ -251,6 +251,7 @@ def get_model_loader(filename): ...@@ -251,6 +251,7 @@ def get_model_loader(filename):
:class:`SaverRestore` (otherwise). :class:`SaverRestore` (otherwise).
""" """
assert isinstance(filename, six.string_types), filename assert isinstance(filename, six.string_types), filename
filename = os.path.expanduser(filename)
if filename.endswith('.npy'): if filename.endswith('.npy'):
assert tf.gfile.Exists(filename), filename assert tf.gfile.Exists(filename), filename
return DictRestore(np.load(filename, encoding='latin1').item()) return DictRestore(np.load(filename, encoding='latin1').item())
......
Markdown is supported
0% Try again or attach a new file
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment