Commit 58b571e1 authored by Patrick Wieschollek's avatar Patrick Wieschollek Committed by Yuxin Wu

Add example for extracting saliency using ResNet-50 with tfSlim (#82)

* Add example for extracting saliency using ResNet-50 with tfSlim

* use existing mean. document the symbolic functions

* move function out of utils.py

* further simplify code. move saliency back to symbolic_functions.

* rename the script. add stacked demo image.

* remove cat.jpg
parent 2bec1d62
## Visualize Saliency Maps
Implement the Guided-ReLU visualization used in the paper:
* [Striving for Simplicity: The All Convolutional Net](https://arxiv.org/abs/1412.6806)
`saliency-maps.py` takes an image and produces its saliency map by running a ResNet-50 and back-propagating
its maximum activation to the input image space.
Similar techniques can be used to visualize the concepts learned by each filter in the network.
Usage:
````bash
wget http://download.tensorflow.org/models/resnet_v1_50_2016_08_28.tar.gz
tar -xzvf resnet_v1_50_2016_08_28.tar.gz
./saliency-maps.py cat.jpg
````
<p align="center"> <img src="./guided-relu-demo.jpg" width="800"> </p>
Left to right:
+ the original cat image
+ the magnitude in the saliency map
+ the magnitude blended with the original image
+ positively correlated pixels (keep original color)
+ negatively correlated pixels (keep original color)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import cv2
import cPickle as pickle
import sys
import os
import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.contrib.slim.nets import resnet_v1
import tensorpack as tp
import tensorpack.utils.viz as viz
# Height and width (in pixels) of the square image fed to the network;
# used both for the declared input shape and for resizing the loaded image.
IMAGE_SIZE = 224
class Model(tp.ModelDesc):
    """ResNet-50 (tf-slim) graph exposing a tensor named ``saliency``.

    The network is built inside ``guided_relu`` so that ReLU gradients use
    guided back-propagation when the saliency map is computed.
    """

    def _get_input_vars(self):
        """Declare the single float32 input ``image`` of shape
        (IMAGE_SIZE, IMAGE_SIZE, 3)."""
        image_var = tp.InputVar(
            tf.float32, (IMAGE_SIZE, IMAGE_SIZE, 3), 'image')
        return [image_var]

    def _build_graph(self, input_vars):
        """Build ResNet-50 on the mean-subtracted image and create the
        ``saliency`` tensor (gradient of the max logit w.r.t. the image)."""
        orig_image = input_vars[0]
        # Per-channel mean, looked up under the ResNet variable scope name.
        mean = tf.get_variable('resnet_v1_50/mean_rgb', shape=[3])

        with tp.symbolic_functions.guided_relu():
            with slim.arg_scope(resnet_v1.resnet_arg_scope(is_training=False)):
                centered = orig_image - mean
                # Add a leading batch dimension of size 1.
                image = tf.expand_dims(centered, 0)
                net_outputs = resnet_v1.resnet_v1_50(image, 1000)
                logits = net_outputs[0]

            tp.symbolic_functions.saliency_map(
                logits, orig_image, name="saliency")
def _scale_to_unit(arr):
    """Shift and scale ``arr`` in place so it spans [0, 1].

    An array with no spread (e.g. all zeros) is left as all zeros instead of
    being divided by zero, which would fill it with NaN.
    """
    arr -= arr.min()
    peak = arr.max()
    if peak > 0:
        arr /= peak
    return arr


def run(model_path, image_path):
    """Compute saliency maps for one image and write them to disk.

    Writes ``pos.jpg``, ``neg.jpg``, ``abs-saliency.jpg`` and ``blended.jpg``
    into the current working directory.

    Args:
        model_path (str): checkpoint passed to ``tp.get_model_loader``.
        image_path (str): path of the image to load with ``cv2.imread``.
    """
    predict_func = tp.OfflinePredictor(tp.PredictConfig(
        model=Model(),
        session_init=tp.get_model_loader(model_path),
        input_names=['image'],
        output_names=['saliency']))

    im = cv2.imread(image_path)
    assert im is not None and im.ndim == 3, image_path

    # cv2 loads images as BGR; keep that copy for writing results and feed
    # the network a channel-reversed (RGB) 224x224x3 float32 view.
    im_bgr = cv2.resize(im, (IMAGE_SIZE, IMAGE_SIZE)).astype(np.float32)
    im_rgb = im_bgr[:, :, ::-1]

    saliency_images = predict_func([im_rgb])[0]

    abs_saliency = np.abs(saliency_images).max(axis=-1)
    pos_saliency = _scale_to_unit(np.maximum(0, saliency_images))
    neg_saliency = _scale_to_unit(np.maximum(0, -saliency_images))

    # The saliency channels follow the RGB input; cv2.imwrite expects BGR,
    # so flip them back before writing to keep the original colors.
    cv2.imwrite('pos.jpg',
                np.ascontiguousarray(pos_saliency[:, :, ::-1]) * 255)
    cv2.imwrite('neg.jpg',
                np.ascontiguousarray(neg_saliency[:, :, ::-1]) * 255)

    heatmap_rgb = viz.intensity_to_rgb(abs_saliency, normalize=True)
    heatmap_bgr = np.ascontiguousarray(heatmap_rgb[:, :, ::-1])  # bgr
    cv2.imwrite("abs-saliency.jpg", heatmap_bgr)

    # Blend in BGR space: both operands must share the same channel order.
    rsl = im_bgr * 0.2 + heatmap_bgr * 0.8
    cv2.imwrite("blended.jpg", rsl)
if __name__ == '__main__':
    # Exactly one positional argument is expected: the image to process.
    cli_args = sys.argv[1:]
    if len(cli_args) != 1:
        tp.logger.error("Usage: {} image.jpg".format(sys.argv[0]))
        sys.exit(1)
    # The checkpoint file name is fixed and resolved relative to the cwd.
    run("resnet_v1_50.ckpt", cli_args[0])
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf import tensorflow as tf
from contextlib import contextmanager
import numpy as np import numpy as np
...@@ -146,3 +147,68 @@ def get_scalar_var(name, init_value, summary=False, trainable=False): ...@@ -146,3 +147,68 @@ def get_scalar_var(name, init_value, summary=False, trainable=False):
# this is recognized in callbacks.StatHolder # this is recognized in callbacks.StatHolder
tf.summary.scalar(name + '-summary', ret) tf.summary.scalar(name + '-summary', ret)
return ret return ret
def psnr_loss(prediction, ground_truth, name='psnr_loss'):
    r"""Negative `Peak Signal to Noise Ratio <https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio>`_.

    .. math::

        PSNR = 20 \cdot log_{10}(MAX_p) - 10 \cdot log_{10}(MSE)

    This function assumes the maximum possible value of the signal is 1,
    therefore the PSNR is simply ``- 10 * log10(MSE)``.

    Args:
        prediction: a :class:`tf.Tensor` representing the prediction signal.
        ground_truth: another :class:`tf.Tensor` with the same shape.
        name (str): name given to the returned op.

    Returns:
        A scalar tensor. The negative PSNR (for minimization).
    """
    # The docstring is a raw string so that ``\cdot`` is not parsed as an
    # (invalid) escape sequence.

    def log10(x):
        # Base-10 logarithm via change of base from the natural logarithm.
        numerator = tf.log(x)
        denominator = tf.log(tf.constant(10, dtype=numerator.dtype))
        return numerator / denominator

    mse = tf.reduce_mean(tf.square(prediction - ground_truth))
    # 10 * log10(MSE) is the negated PSNR when the peak signal value is 1.
    return tf.multiply(log10(mse), 10., name=name)
@contextmanager
def guided_relu():
    """
    Returns:
        A context where the gradient of :meth:`tf.nn.relu` is replaced by
        guided back-propagation, as described in the paper:
        `Striving for Simplicity: The All Convolutional Net
        <https://arxiv.org/abs/1412.6806>`_

    The context may be entered more than once in the same process; the
    custom gradient is only registered the first time.
    """
    from tensorflow.python.ops import gen_nn_ops   # noqa

    def _GuidedReluGrad(op, grad):
        # Keep the regular ReLU gradient only where the incoming gradient is
        # positive; everything else is zeroed. ``zeros_like`` is used so the
        # gradient's static shape does not have to be fully defined.
        return tf.where(0. < grad,
                        gen_nn_ops._relu_grad(grad, op.outputs[0]),
                        tf.zeros_like(grad))

    try:
        tf.RegisterGradient("GuidedReLU")(_GuidedReluGrad)
    except KeyError:
        # The gradient registry rejects duplicate names; a previous call
        # already registered "GuidedReLU", so reuse that registration.
        pass

    g = tf.get_default_graph()
    with g.gradient_override_map({'Relu': 'GuidedReLU'}):
        yield
def saliency_map(output, input, name="saliency_map"):
    """
    Produce a saliency map as described in the paper:
    `Deep Inside Convolutional Networks: Visualising Image Classification Models and Saliency Maps
    <https://arxiv.org/abs/1312.6034>`_.
    The saliency map is the gradient of the max element in output w.r.t input.

    Args:
        output (tf.Tensor): tensor whose maximum along axis 1 is differentiated.
        input (tf.Tensor): tensor the gradient is taken with respect to.
        name (str): name given to the returned tensor.

    Returns:
        tf.Tensor: the saliency map. Has the same shape as input.
    """
    strongest = tf.reduce_max(output, 1)
    # tf.gradients returns one gradient per requested input; only one is requested.
    grads = tf.gradients(strongest, input)
    return tf.identity(grads[0], name=name)
...@@ -16,7 +16,7 @@ __all__ = ['change_env', ...@@ -16,7 +16,7 @@ __all__ = ['change_env',
'get_tqdm_kwargs', 'get_tqdm_kwargs',
'get_tqdm', 'get_tqdm',
'execute_only_once', 'execute_only_once',
'building_rtfd' 'building_rtfd',
] ]
......
...@@ -16,8 +16,9 @@ try: ...@@ -16,8 +16,9 @@ try:
except ImportError: except ImportError:
pass pass
__all__ = ['pyplot2img', 'interactive_imshow', 'build_patch_list', __all__ = ['pyplot2img', 'interactive_imshow', 'build_patch_list',
'pyplot_viz', 'dump_dataflow_images'] 'pyplot_viz', 'dump_dataflow_images', 'intensity_to_rgb']
def pyplot2img(plt): def pyplot2img(plt):
...@@ -242,6 +243,34 @@ def dump_dataflow_images(df, index=0, batched=True, ...@@ -242,6 +243,34 @@ def dump_dataflow_images(df, index=0, batched=True,
vizlist = vizlist[vizsize:] vizlist = vizlist[vizsize:]
def intensity_to_rgb(intensity, cmap='jet', normalize=False):
    """
    Convert a 1-channel matrix of intensities to an RGB image employing a colormap.
    This function requires matplotlib. See `matplotlib colormaps
    <http://matplotlib.org/examples/color/colormaps_reference.html>`_ for a
    list of available colormap.

    Args:
        intensity (np.ndarray): array of intensities such as saliency.
        cmap (str): name of the colormap to use.
        normalize (bool): if True, will normalize the intensity so that it has
            minimum 0 and maximum 1. A constant input is mapped to all zeros.

    Returns:
        np.ndarray: an RGB float32 image in range [0, 255], a colored heatmap.
    """
    assert intensity.ndim == 2, intensity.shape
    # astype returns a copy, so the in-place normalization below never
    # modifies the caller's array.
    intensity = intensity.astype("float")

    if normalize:
        intensity -= intensity.min()
        peak = intensity.max()
        # Skip the division for a constant input: dividing by zero would
        # fill the array with NaN and yield a meaningless heatmap.
        if peak > 0:
            intensity /= peak

    colormap = plt.get_cmap(cmap)
    # The colormap returns RGBA; drop the alpha channel.
    rgb = colormap(intensity)[..., :3]
    return rgb.astype('float32') * 255.0
if __name__ == '__main__': if __name__ == '__main__':
imglist = [] imglist = []
for i in range(100): for i in range(100):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment