Project: Shashank Suhas / seminar-breakout

Commit d6c2d6b3, authored May 30, 2017 by Yuxin Wu
add CAM

Parent: 38d26977
Showing 5 changed files with 294 additions and 5 deletions (+294, -5)
- README.md (+2, -2)
- examples/README.md (+2, -2)
- examples/Saliency/CAM-demo.jpg (+0, -0)
- examples/Saliency/CAM-resnet.py (+265, -0)
- examples/Saliency/README.md (+25, -1)
README.md (view file @ d6c2d6b3)

```diff
@@ -12,8 +12,8 @@ See some [examples](examples) to learn about the framework:
 + [Generative Adversarial Network(GAN) variants](examples/GAN), including DCGAN, InfoGAN, Conditional GAN, WGAN, BEGAN, DiscoGAN, Image to Image, CycleGAN.
 + [Fully-convolutional Network for Holistically-Nested Edge Detection(HED)](examples/HED)
 + [Spatial Transformer Networks on MNIST addition](examples/SpatialTransformer)
-+ [Visualize Saliency Maps by Guided ReLU](examples/Saliency)
-+ [Similarity Learning on MNIST](examples/SimilarityLearning)
++ [Visualize CNN saliency maps](examples/Saliency)
++ [Similarity learning on MNIST](examples/SimilarityLearning)
 ### Reinforcement Learning:
 + [Deep Q-Network(DQN) variants on Atari games](examples/DeepQNetwork), including DQN, DoubleDQN, DuelingDQN.
```
examples/README.md (view file @ d6c2d6b3)

```diff
@@ -17,8 +17,8 @@ Training examples with __reproducible__ and meaningful performance.
 + [InceptionV3 with 74% accuracy (similar to the official code)](Inception/inceptionv3.py)
 + [Fully-convolutional Network for Holistically-Nested Edge Detection(HED)](HED)
 + [Spatial Transformer Networks on MNIST addition](SpatialTransformer)
-+ [Visualize Saliency Maps by Guided ReLU](Saliency)
-+ [Similarity Learning on MNIST](SimilarityLearning)
++ [Visualize CNN saliency maps](Saliency)
++ [Similarity learning on MNIST](SimilarityLearning)
 + Load a pre-trained [AlexNet](load-alexnet.py) or [VGG16](load-vgg16.py) model.
 + Load a pre-trained [Convolutional Pose Machines](ConvolutionalPoseMachines/).
```
examples/Saliency/CAM-demo.jpg (new file, mode 100644, 42.7 KB, view file @ d6c2d6b3)
examples/Saliency/CAM-resnet.py (new file, mode 100755, view file @ d6c2d6b3)

```python
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# File: CAM-resnet.py

import cv2
import sys
import argparse
import numpy as np
import os
import multiprocessing
import tensorflow as tf
from tensorflow.contrib.layers import variance_scaling_initializer

from tensorpack import *
from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import *

TOTAL_BATCH_SIZE = 256
INPUT_SHAPE = 224
DEPTH = None


class Model(ModelDesc):
    def _get_inputs(self):
        return [InputDesc(tf.uint8, [None, INPUT_SHAPE, INPUT_SHAPE, 3], 'input'),
                InputDesc(tf.int32, [None], 'label')]

    def _build_graph(self, inputs):
        image, label = inputs
        image = tf.cast(image, tf.float32) * (1.0 / 255)

        image_mean = tf.constant([0.485, 0.456, 0.406], dtype=tf.float32)
        image_std = tf.constant([0.229, 0.224, 0.225], dtype=tf.float32)
        image = (image - image_mean) / image_std
        image = tf.transpose(image, [0, 3, 1, 2])

        def shortcut(l, n_in, n_out, stride):
            if n_in != n_out:
                return Conv2D('convshortcut', l, n_out, 1, stride=stride)
            else:
                return l

        def basicblock(l, ch_out, stride, preact):
            ch_in = l.get_shape().as_list()[1]
            if preact == 'both_preact':
                l = BNReLU('preact', l)
                input = l
            elif preact != 'no_preact':
                input = l
                l = BNReLU('preact', l)
            else:
                input = l
            l = Conv2D('conv1', l, ch_out, 3, stride=stride, nl=BNReLU)
            l = Conv2D('conv2', l, ch_out, 3)
            return l + shortcut(input, ch_in, ch_out, stride)

        def bottleneck(l, ch_out, stride, preact):
            ch_in = l.get_shape().as_list()[1]
            if preact == 'both_preact':
                l = BNReLU('preact', l)
                input = l
            elif preact != 'no_preact':
                input = l
                l = BNReLU('preact', l)
            else:
                input = l
            l = Conv2D('conv1', l, ch_out, 1, nl=BNReLU)
            l = Conv2D('conv2', l, ch_out, 3, stride=stride, nl=BNReLU)
            l = Conv2D('conv3', l, ch_out * 4, 1)
            return l + shortcut(input, ch_in, ch_out * 4, stride)

        def layer(l, layername, block_func, features, count, stride, first=False):
            with tf.variable_scope(layername):
                with tf.variable_scope('block0'):
                    l = block_func(l, features, stride,
                                   'no_preact' if first else 'both_preact')
                for i in range(1, count):
                    with tf.variable_scope('block{}'.format(i)):
                        l = block_func(l, features, 1, 'default')
                return l

        cfg = {
            18: ([2, 2, 2, 2], basicblock),
            34: ([3, 4, 6, 3], basicblock),
            50: ([3, 4, 6, 3], bottleneck),
            101: ([3, 4, 23, 3], bottleneck)
        }
        defs, block_func = cfg[DEPTH]

        with argscope(Conv2D, nl=tf.identity, use_bias=False,
                      W_init=variance_scaling_initializer(mode='FAN_OUT')), \
                argscope([Conv2D, MaxPooling, GlobalAvgPooling, BatchNorm],
                         data_format='NCHW'):
            convmaps = (LinearWrap(image)
                        .Conv2D('conv0', 64, 7, stride=2, nl=BNReLU)
                        .MaxPooling('pool0', shape=3, stride=2, padding='SAME')
                        .apply(layer, 'group0', block_func, 64, defs[0], 1, first=True)
                        .apply(layer, 'group1', block_func, 128, defs[1], 2)
                        .apply(layer, 'group2', block_func, 256, defs[2], 2)
                        .apply(layer, 'group3new', block_func, 512, defs[3], 1)
                        .BNReLU('bnlast')())
            print(convmaps)
            logits = (LinearWrap(convmaps)
                      .GlobalAvgPooling('gap')
                      .FullyConnected('linearnew', 1000, nl=tf.identity)())

        loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label)
        loss = tf.reduce_mean(loss, name='xentropy-loss')

        wrong = prediction_incorrect(logits, label, 1, name='wrong-top1')
        add_moving_summary(tf.reduce_mean(wrong, name='train-error-top1'))

        wrong = prediction_incorrect(logits, label, 5, name='wrong-top5')
        add_moving_summary(tf.reduce_mean(wrong, name='train-error-top5'))

        wd_cost = regularize_cost('.*/W', l2_regularizer(1e-4), name='l2_regularize_loss')
        add_moving_summary(loss, wd_cost)
        self.cost = tf.add_n([loss, wd_cost], name='cost')

    def _get_optimizer(self):
        lr = get_scalar_var('learning_rate', 0.1, summary=True)
        opt = tf.train.MomentumOptimizer(lr, 0.9, use_nesterov=True)
        gradprocs = [gradproc.ScaleGradient(
            [('conv0.*', 0.1), ('group[0-2].*', 0.1)])]
        return optimizer.apply_grad_processors(opt, gradprocs)


# completely copied from imagenet-resnet.py example
def get_data(train_or_test):
    isTrain = train_or_test == 'train'

    datadir = args.data
    ds = dataset.ILSVRC12(datadir, train_or_test,
                          shuffle=True if isTrain else False, dir_structure='train')
    if isTrain:
        class Resize(imgaug.ImageAugmentor):
            def _augment(self, img, _):
                h, w = img.shape[:2]
                area = h * w
                for _ in range(10):
                    targetArea = self.rng.uniform(0.08, 1.0) * area
                    aspectR = self.rng.uniform(0.75, 1.333)
                    ww = int(np.sqrt(targetArea * aspectR))
                    hh = int(np.sqrt(targetArea / aspectR))
                    if self.rng.uniform() < 0.5:
                        ww, hh = hh, ww
                    if hh <= h and ww <= w:
                        x1 = 0 if w == ww else self.rng.randint(0, w - ww)
                        y1 = 0 if h == hh else self.rng.randint(0, h - hh)
                        out = img[y1:y1 + hh, x1:x1 + ww]
                        out = cv2.resize(out, (224, 224), interpolation=cv2.INTER_CUBIC)
                        return out
                out = cv2.resize(img, (224, 224), interpolation=cv2.INTER_CUBIC)
                return out

        augmentors = [
            Resize(),
            imgaug.RandomOrderAug(
                [imgaug.Brightness(30, clip=False),
                 imgaug.Contrast((0.8, 1.2), clip=False),
                 imgaug.Saturation(0.4),
                 imgaug.Lighting(0.1,
                                 eigval=[0.2175, 0.0188, 0.0045][::-1],
                                 eigvec=np.array(
                                     [[-0.5675, 0.7192, 0.4009],
                                      [-0.5808, -0.0045, -0.8140],
                                      [-0.5836, -0.6948, 0.4203]],
                                     dtype='float32')[::-1, ::-1]
                                 )]),
            imgaug.Clip(),
            imgaug.Flip(horiz=True),
            imgaug.ToUint8()
        ]
    else:
        augmentors = [
            imgaug.ResizeShortestEdge(256),
            imgaug.CenterCrop((224, 224)),
            imgaug.ToUint8()
        ]
    ds = AugmentImageComponent(ds, augmentors, copy=False)
    if isTrain:
        ds = PrefetchDataZMQ(ds, min(20, multiprocessing.cpu_count()))
    ds = BatchData(ds, BATCH_SIZE, remainder=not isTrain)
    return ds


def get_config():
    dataset_train = get_data('train')
    dataset_val = get_data('val')

    return TrainConfig(
        model=Model(),
        dataflow=dataset_train,
        callbacks=[
            ModelSaver(),
            InferenceRunner(dataset_val, [
                ClassificationError('wrong-top1', 'val-error-top1'),
                ClassificationError('wrong-top5', 'val-error-top5')]),
            ScheduledHyperParamSetter('learning_rate',
                                      [(30, 1e-2), (60, 1e-3), (85, 1e-4), (95, 1e-5)]),
        ],
        steps_per_epoch=5000,
        max_epoch=110,
    )


def viz_cam(model_file, data_dir):
    ds = get_data('val')
    pred_config = PredictConfig(
        model=Model(),
        session_init=get_model_loader(model_file),
        input_names=['input', 'label'],
        output_names=['wrong-top1', 'bnlast/Relu', 'linearnew/W'],
        return_input=True
    )
    meta = dataset.ILSVRCMeta().get_synset_words_1000()

    pred = SimpleDatasetPredictor(pred_config, ds)
    cnt = 0
    for inp, outp in pred.get_result():
        images, labels = inp
        wrongs, convmaps, W = outp
        batch = wrongs.shape[0]
        for i in range(batch):
            if wrongs[i]:
                continue
            weight = W[:, [labels[i]]].T    # 512x1
            convmap = convmaps[i, :, :, :]  # 512xhxw
            mergedmap = np.matmul(weight, convmap.reshape((512, -1))).reshape(14, 14)
            mergedmap = cv2.resize(mergedmap, (224, 224))
            heatmap = viz.intensity_to_rgb(mergedmap, normalize=True)
            blend = images[i] * 0.5 + heatmap * 0.5
            concat = np.concatenate((images[i], heatmap, blend), axis=1)

            classname = meta[labels[i]].split(',')[0]
            cv2.imwrite('cam{}-{}.jpg'.format(cnt, classname), concat)
            cnt += 1
            if cnt == 500:
                return


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.', required=True)
    parser.add_argument('--data', help='ILSVRC dataset dir')
    parser.add_argument('--depth', type=int, default=18)
    parser.add_argument('--load', help='load model')
    parser.add_argument('--cam', action='store_true')
    args = parser.parse_args()

    DEPTH = args.depth
    os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    if args.cam:
        BATCH_SIZE = 128    # something that can run on one gpu
        viz_cam(args.load, args.data)
        sys.exit()

    NR_GPU = len(args.gpu.split(','))
    BATCH_SIZE = TOTAL_BATCH_SIZE // NR_GPU

    logger.auto_set_dir()
    config = get_config()
    if args.load:
        config.session_init = get_model_loader(args.load)
    config.nr_tower = NR_GPU
    SyncMultiGPUTrainer(config).train()
```
examples/Saliency/README.md (view file @ d6c2d6b3)

## Visualize Saliency Maps & Class Activation Maps

Implement the Guided-ReLU visualization used in the paper:

* [Striving for Simplicity: The All Convolutional Net](https://arxiv.org/abs/1412.6806)

And the class activation mapping (CAM) visualization proposed in the paper:

* [Learning Deep Features for Discriminative Localization](http://cnnlocalization.csail.mit.edu/)

## Saliency Maps

`saliency-maps.py` takes an image and produces its saliency map by running a ResNet-50 and backpropagating its maximum activations back to the input image space.
Similar techniques can be used to visualize the concept learned by each filter in the network.

@@ -23,3 +29,21 @@ Left to right:

+ the magnitude blended with the original image
+ positive correlated pixels (keep original color)
+ negative correlated pixels (keep original color)
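For reference (this sketch is not part of the commit): the guided-backprop saliency described above can be written roughly as follows in TF 1.x. `build_resnet50` is a stand-in for the actual model definition, not a tensorpack function, and the exact tensor names are assumptions.

```python
import tensorflow as tf

# Guided backprop (Springenberg et al., 2015): when backpropagating through a
# ReLU, keep only positive gradients, and only where the ReLU was active.
@tf.RegisterGradient("GuidedReLU")
def _guided_relu_grad(op, grad):
    forward_gate = tf.cast(op.outputs[0] > 0., grad.dtype)   # ReLU output was positive
    backward_gate = tf.cast(grad > 0., grad.dtype)           # incoming gradient is positive
    return grad * forward_gate * backward_gate

graph = tf.get_default_graph()
image = tf.placeholder(tf.float32, [1, 224, 224, 3], name='input')

# Every ReLU op created under this override uses the guided gradient above.
with graph.gradient_override_map({'Relu': 'GuidedReLU'}):
    logits = build_resnet50(image)   # placeholder: actual model-building code goes here

# Saliency map: gradient of the strongest class activation w.r.t. the input pixels.
saliency = tf.gradients(tf.reduce_max(logits, axis=1), image)[0]
```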
## CAM

`CAM-resnet.py` fine-tunes a variant of ResNet to have 2x larger last-layer feature maps, then produces CAM visualizations.

Usage:

1. Fine-tune or retrain the ResNet:
```bash
./CAM-resnet.py --data /path/to/imagenet [--load ImageNet-ResNet18.npy] [--gpu 0,1,2,3]
```
Pretrained and fine-tuned ResNet can be downloaded [here](https://drive.google.com/open?id=0B9IPQTvr2BBkTXBlZmh1cmlnQ0k) and [here](https://drive.google.com/open?id=0B9IPQTvr2BBkQk9qcmtGSERlNUk).

2. Generate CAM on ImageNet validation set:
```bash
./CAM-resnet.py --data /path/to/imagenet --load ImageNet-ResNet18-2xGAP.npy --cam
```

<p align="center"> <img src="./CAM-demo.jpg" width="900"> </p>
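For reference (again not part of the commit): the core of the CAM computation in `viz_cam` from `CAM-resnet.py` above is a class-weighted sum over the last conv-layer feature maps. A minimal NumPy sketch with made-up inputs follows; it uses an OpenCV colormap where the script itself uses tensorpack's `viz.intensity_to_rgb`, and the array shapes match the 512x14x14 feature maps the script produces.

```python
import numpy as np
import cv2

# Hypothetical inputs: convmap is the 512x14x14 last-conv-stage output for one
# image; w_class is the 512-dim column of the final FC weight matrix for the
# target class (random here, just to make the sketch runnable).
convmap = np.random.rand(512, 14, 14).astype('float32')
w_class = np.random.rand(512).astype('float32')

# CAM: M(x, y) = sum_k w_k * f_k(x, y), a weighted sum over channels.
cam = np.tensordot(w_class, convmap, axes=([0], [0]))     # 14x14
cam = cv2.resize(cam, (224, 224))                         # upsample to input size
cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)  # normalize to [0, 1]
heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
```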