Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
466b5192
Commit
466b5192
authored
Aug 29, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
simplify pre-resnet model code
parent
02f41b25
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
26 additions
and
30 deletions
+26
-30
examples/ResNet/imagenet_resnet_utils.py
examples/ResNet/imagenet_resnet_utils.py
+17
-20
examples/Saliency/CAM-resnet.py
examples/Saliency/CAM-resnet.py
+6
-9
tensorpack/dataflow/prefetch.py
tensorpack/dataflow/prefetch.py
+3
-1
No files found.
examples/ResNet/imagenet_resnet_utils.py
View file @
466b5192
...
@@ -112,25 +112,21 @@ def get_imagenet_dataflow(
...
@@ -112,25 +112,21 @@ def get_imagenet_dataflow(
return
ds
return
ds
def resnet_shortcut(l, n_out, stride, nl=tf.identity):
    """Shortcut branch of a residual block.

    Returns the input unchanged when the channel count already matches
    `n_out`; otherwise projects it with a 1x1 convolution so the residual
    addition is well-defined.

    Args:
        l: input tensor.
        n_out: desired number of output channels.
        stride: stride for the projection convolution.
        nl: nonlinearity applied by the projection conv (identity by default).
    """
    # Channel axis depends on the data layout configured in the argscope.
    data_format = get_arg_scope()['Conv2D']['data_format']
    channel_axis = 1 if data_format == 'NCHW' else 3
    n_in = l.get_shape().as_list()[channel_axis]
    if n_in == n_out:
        return l
    # change dimension when channel is not the same
    return Conv2D('convshortcut', l, n_out, 1, stride=stride, nl=nl)
def
apply_preactivation
(
l
,
preact
):
def
apply_preactivation
(
l
,
preact
):
"""
"""
'no_preact' for the first resblock only, because the input is activated already.
'no_preact' for the first resblock in each group only, because the input is activated already.
'both_preact' for the first block in each group, due to the projection shortcut.
'default' for all the non-first blocks, where identity mapping is preserved on shortcut path.
'default' for all the non-first blocks, where identity mapping is preserved on shortcut path.
"""
"""
if
preact
==
'both_preact'
:
if
preact
==
'default'
:
l
=
BNReLU
(
'preact'
,
l
)
shortcut
=
l
elif
preact
==
'default'
:
shortcut
=
l
shortcut
=
l
l
=
BNReLU
(
'preact'
,
l
)
l
=
BNReLU
(
'preact'
,
l
)
else
:
else
:
...
@@ -153,14 +149,16 @@ def resnet_bottleneck(l, ch_out, stride, preact):
...
@@ -153,14 +149,16 @@ def resnet_bottleneck(l, ch_out, stride, preact):
return
l
+
resnet_shortcut
(
shortcut
,
ch_out
*
4
,
stride
)
return
l
+
resnet_shortcut
(
shortcut
,
ch_out
*
4
,
stride
)
def preresnet_group(l, name, block_func, features, count, stride):
    """Stack `count` pre-activation residual blocks under scope `name`.

    The first block applies the given `stride` and skips pre-activation
    (its input is already activated); all later blocks use stride 1 with
    the default pre-activation. A final BNReLU closes the group.
    """
    with tf.variable_scope(name):
        for i in range(0, count):
            with tf.variable_scope('block{}'.format(i)):
                is_first = (i == 0)
                # first block doesn't need activation
                l = block_func(l, features,
                               stride if is_first else 1,
                               'no_preact' if is_first else 'default')
        # end of each group need an extra activation
        l = BNReLU('bnlast', l)
    return l
...
@@ -170,11 +168,10 @@ def resnet_backbone(image, num_blocks, block_func):
...
@@ -170,11 +168,10 @@ def resnet_backbone(image, num_blocks, block_func):
logits
=
(
LinearWrap
(
image
)
logits
=
(
LinearWrap
(
image
)
.
Conv2D
(
'conv0'
,
64
,
7
,
stride
=
2
,
nl
=
BNReLU
)
.
Conv2D
(
'conv0'
,
64
,
7
,
stride
=
2
,
nl
=
BNReLU
)
.
MaxPooling
(
'pool0'
,
shape
=
3
,
stride
=
2
,
padding
=
'SAME'
)
.
MaxPooling
(
'pool0'
,
shape
=
3
,
stride
=
2
,
padding
=
'SAME'
)
.
apply
(
resnet_group
,
'group0'
,
block_func
,
64
,
num_blocks
[
0
],
1
,
first
=
True
)
.
apply
(
preresnet_group
,
'group0'
,
block_func
,
64
,
num_blocks
[
0
],
1
)
.
apply
(
resnet_group
,
'group1'
,
block_func
,
128
,
num_blocks
[
1
],
2
)
.
apply
(
preresnet_group
,
'group1'
,
block_func
,
128
,
num_blocks
[
1
],
2
)
.
apply
(
resnet_group
,
'group2'
,
block_func
,
256
,
num_blocks
[
2
],
2
)
.
apply
(
preresnet_group
,
'group2'
,
block_func
,
256
,
num_blocks
[
2
],
2
)
.
apply
(
resnet_group
,
'group3'
,
block_func
,
512
,
num_blocks
[
3
],
2
)
.
apply
(
preresnet_group
,
'group3'
,
block_func
,
512
,
num_blocks
[
3
],
2
)
.
BNReLU
(
'bnlast'
)
.
GlobalAvgPooling
(
'gap'
)
.
GlobalAvgPooling
(
'gap'
)
.
FullyConnected
(
'linear'
,
1000
,
nl
=
tf
.
identity
)())
.
FullyConnected
(
'linear'
,
1000
,
nl
=
tf
.
identity
)())
return
logits
return
logits
...
...
examples/Saliency/CAM-resnet.py
View file @
466b5192
...
@@ -20,7 +20,7 @@ from tensorpack.utils.gpu import get_nr_gpu
...
@@ -20,7 +20,7 @@ from tensorpack.utils.gpu import get_nr_gpu
from
tensorpack.utils
import
viz
from
tensorpack.utils
import
viz
from
imagenet_resnet_utils
import
(
from
imagenet_resnet_utils
import
(
fbresnet_augmentor
,
resnet_basicblock
,
resnet_bottleneck
,
resnet_group
,
fbresnet_augmentor
,
resnet_basicblock
,
pre
resnet_group
,
image_preprocess
,
compute_loss_and_error
)
image_preprocess
,
compute_loss_and_error
)
...
@@ -42,8 +42,6 @@ class Model(ModelDesc):
...
@@ -42,8 +42,6 @@ class Model(ModelDesc):
cfg
=
{
cfg
=
{
18
:
([
2
,
2
,
2
,
2
],
resnet_basicblock
),
18
:
([
2
,
2
,
2
,
2
],
resnet_basicblock
),
34
:
([
3
,
4
,
6
,
3
],
resnet_basicblock
),
34
:
([
3
,
4
,
6
,
3
],
resnet_basicblock
),
50
:
([
3
,
4
,
6
,
3
],
resnet_bottleneck
),
101
:
([
3
,
4
,
23
,
3
],
resnet_bottleneck
)
}
}
defs
,
block_func
=
cfg
[
DEPTH
]
defs
,
block_func
=
cfg
[
DEPTH
]
...
@@ -53,11 +51,10 @@ class Model(ModelDesc):
...
@@ -53,11 +51,10 @@ class Model(ModelDesc):
convmaps
=
(
LinearWrap
(
image
)
convmaps
=
(
LinearWrap
(
image
)
.
Conv2D
(
'conv0'
,
64
,
7
,
stride
=
2
,
nl
=
BNReLU
)
.
Conv2D
(
'conv0'
,
64
,
7
,
stride
=
2
,
nl
=
BNReLU
)
.
MaxPooling
(
'pool0'
,
shape
=
3
,
stride
=
2
,
padding
=
'SAME'
)
.
MaxPooling
(
'pool0'
,
shape
=
3
,
stride
=
2
,
padding
=
'SAME'
)
.
apply
(
resnet_group
,
'group0'
,
block_func
,
64
,
defs
[
0
],
1
,
first
=
True
)
.
apply
(
preresnet_group
,
'group0'
,
block_func
,
64
,
defs
[
0
],
1
)
.
apply
(
resnet_group
,
'group1'
,
block_func
,
128
,
defs
[
1
],
2
)
.
apply
(
preresnet_group
,
'group1'
,
block_func
,
128
,
defs
[
1
],
2
)
.
apply
(
resnet_group
,
'group2'
,
block_func
,
256
,
defs
[
2
],
2
)
.
apply
(
preresnet_group
,
'group2'
,
block_func
,
256
,
defs
[
2
],
2
)
.
apply
(
resnet_group
,
'group3new'
,
block_func
,
512
,
defs
[
3
],
1
)
.
apply
(
preresnet_group
,
'group3new'
,
block_func
,
512
,
defs
[
3
],
1
)())
.
BNReLU
(
'bnlast'
)())
print
(
convmaps
)
print
(
convmaps
)
logits
=
(
LinearWrap
(
convmaps
)
logits
=
(
LinearWrap
(
convmaps
)
.
GlobalAvgPooling
(
'gap'
)
.
GlobalAvgPooling
(
'gap'
)
...
@@ -125,7 +122,7 @@ def viz_cam(model_file, data_dir):
...
@@ -125,7 +122,7 @@ def viz_cam(model_file, data_dir):
model
=
Model
(),
model
=
Model
(),
session_init
=
get_model_loader
(
model_file
),
session_init
=
get_model_loader
(
model_file
),
input_names
=
[
'input'
,
'label'
],
input_names
=
[
'input'
,
'label'
],
output_names
=
[
'wrong-top1'
,
'bnlast/Relu'
,
'linearnew/W'
],
output_names
=
[
'wrong-top1'
,
'
group3new/
bnlast/Relu'
,
'linearnew/W'
],
return_input
=
True
return_input
=
True
)
)
meta
=
dataset
.
ILSVRCMeta
()
.
get_synset_words_1000
()
meta
=
dataset
.
ILSVRCMeta
()
.
get_synset_words_1000
()
...
...
tensorpack/dataflow/prefetch.py
View file @
466b5192
...
@@ -317,6 +317,7 @@ class ThreadedMapData(ProxyDataFlow):
...
@@ -317,6 +317,7 @@ class ThreadedMapData(ProxyDataFlow):
self
.
buffer_size
=
buffer_size
self
.
buffer_size
=
buffer_size
self
.
map_func
=
map_func
self
.
map_func
=
map_func
self
.
_threads
=
[]
self
.
_threads
=
[]
self
.
_evt
=
None
def
reset_state
(
self
):
def
reset_state
(
self
):
super
(
ThreadedMapData
,
self
)
.
reset_state
()
super
(
ThreadedMapData
,
self
)
.
reset_state
()
...
@@ -372,6 +373,7 @@ class ThreadedMapData(ProxyDataFlow):
...
@@ -372,6 +373,7 @@ class ThreadedMapData(ProxyDataFlow):
yield
self
.
_out_queue
.
get
()
yield
self
.
_out_queue
.
get
()
def __del__(self):
    """Ask worker threads to stop, then wait for each one to exit.

    `_evt` may still be None if `reset_state()` was never called, so it
    is checked before signalling.
    """
    stop_event = self._evt
    if stop_event is not None:
        stop_event.set()
    for worker in self._threads:
        worker.join()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment