Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
e4aca035
Commit
e4aca035
authored
Feb 20, 2019
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Try using cudnn's group conv
parent
d8d35fb5
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
19 additions
and
7 deletions
+19
-7
tensorpack/models/conv2d.py
tensorpack/models/conv2d.py
+17
-6
tensorpack/tfutils/sessinit.py
tensorpack/tfutils/sessinit.py
+2
-1
No files found.
tensorpack/models/conv2d.py
View file @
e4aca035
...
@@ -5,7 +5,7 @@
...
@@ -5,7 +5,7 @@
import
tensorflow
as
tf
import
tensorflow
as
tf
from
..tfutils.common
import
get_tf_version_tuple
from
..tfutils.common
import
get_tf_version_tuple
from
..utils.argtools
import
get_data_format
,
shape2d
,
shape4d
from
..utils.argtools
import
get_data_format
,
shape2d
,
shape4d
,
log_once
from
.common
import
VariableHolder
,
layer_register
from
.common
import
VariableHolder
,
layer_register
from
.tflayer
import
convert_to_tflayer_args
,
rename_get_variable
from
.tflayer
import
convert_to_tflayer_args
,
rename_get_variable
...
@@ -108,11 +108,22 @@ def Conv2D(
...
@@ -108,11 +108,22 @@ def Conv2D(
if
use_bias
:
if
use_bias
:
b
=
tf
.
get_variable
(
'b'
,
[
out_channel
],
initializer
=
bias_initializer
)
b
=
tf
.
get_variable
(
'b'
,
[
out_channel
],
initializer
=
bias_initializer
)
conv
=
None
if
get_tf_version_tuple
()
>=
(
1
,
13
):
try
:
conv
=
tf
.
nn
.
conv2d
(
inputs
,
W
,
stride
,
padding
.
upper
(),
**
kwargs
)
except
ValueError
:
conv
=
None
log_once
(
"CUDNN group convolution support is only available with "
"https://github.com/tensorflow/tensorflow/pull/25818 . "
"Will fall back to a loop-based slow implementation instead!"
,
'warn'
)
if
conv
is
None
:
inputs
=
tf
.
split
(
inputs
,
split
,
channel_axis
)
inputs
=
tf
.
split
(
inputs
,
split
,
channel_axis
)
kernels
=
tf
.
split
(
W
,
split
,
3
)
kernels
=
tf
.
split
(
W
,
split
,
3
)
outputs
=
[
tf
.
nn
.
conv2d
(
i
,
k
,
stride
,
padding
.
upper
(),
**
kwargs
)
outputs
=
[
tf
.
nn
.
conv2d
(
i
,
k
,
stride
,
padding
.
upper
(),
**
kwargs
)
for
i
,
k
in
zip
(
inputs
,
kernels
)]
for
i
,
k
in
zip
(
inputs
,
kernels
)]
conv
=
tf
.
concat
(
outputs
,
channel_axis
)
conv
=
tf
.
concat
(
outputs
,
channel_axis
)
if
activation
is
None
:
if
activation
is
None
:
activation
=
tf
.
identity
activation
=
tf
.
identity
ret
=
activation
(
tf
.
nn
.
bias_add
(
conv
,
b
,
data_format
=
data_format
)
if
use_bias
else
conv
,
name
=
'output'
)
ret
=
activation
(
tf
.
nn
.
bias_add
(
conv
,
b
,
data_format
=
data_format
)
if
use_bias
else
conv
,
name
=
'output'
)
...
...
tensorpack/tfutils/sessinit.py
View file @
e4aca035
# -*- coding: utf-8 -*-
# -*- coding: utf-8 -*-
# File: sessinit.py
# File: sessinit.py
import
os
import
numpy
as
np
import
numpy
as
np
import
six
import
six
import
tensorflow
as
tf
import
tensorflow
as
tf
...
@@ -251,6 +251,7 @@ def get_model_loader(filename):
...
@@ -251,6 +251,7 @@ def get_model_loader(filename):
:class:`SaverRestore` (otherwise).
:class:`SaverRestore` (otherwise).
"""
"""
assert
isinstance
(
filename
,
six
.
string_types
),
filename
assert
isinstance
(
filename
,
six
.
string_types
),
filename
filename
=
os
.
path
.
expanduser
(
filename
)
if
filename
.
endswith
(
'.npy'
):
if
filename
.
endswith
(
'.npy'
):
assert
tf
.
gfile
.
Exists
(
filename
),
filename
assert
tf
.
gfile
.
Exists
(
filename
),
filename
return
DictRestore
(
np
.
load
(
filename
,
encoding
=
'latin1'
)
.
item
())
return
DictRestore
(
np
.
load
(
filename
,
encoding
=
'latin1'
)
.
item
())
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment