Commit d2f95645 authored by Yuxin Wu's avatar Yuxin Wu

correct TF version in saliency example

parent 758ae94a
...@@ -11,6 +11,7 @@ from tensorflow.contrib.slim.nets import resnet_v1 ...@@ -11,6 +11,7 @@ from tensorflow.contrib.slim.nets import resnet_v1
import tensorpack as tp import tensorpack as tp
import tensorpack.utils.viz as viz import tensorpack.utils.viz as viz
from tensorpack.tfutils import get_tf_version_tuple
IMAGE_SIZE = 224 IMAGE_SIZE = 224
...@@ -29,7 +30,7 @@ def guided_relu(): ...@@ -29,7 +30,7 @@ def guided_relu():
@tf.RegisterGradient("GuidedReLU") @tf.RegisterGradient("GuidedReLU")
def GuidedReluGrad(op, grad): def GuidedReluGrad(op, grad):
return tf.where(0. < grad, return tf.where(0. < grad,
gen_nn_ops._relu_grad(grad, op.outputs[0]), gen_nn_ops.relu_grad(grad, op.outputs[0]),
tf.zeros(grad.get_shape())) tf.zeros(grad.get_shape()))
g = tf.get_default_graph() g = tf.get_default_graph()
...@@ -59,9 +60,9 @@ class Model(tp.ModelDescBase): ...@@ -59,9 +60,9 @@ class Model(tp.ModelDescBase):
def build_graph(self, orig_image): def build_graph(self, orig_image):
mean = tf.get_variable('resnet_v1_50/mean_rgb', shape=[3]) mean = tf.get_variable('resnet_v1_50/mean_rgb', shape=[3])
with guided_relu(): with guided_relu():
with slim.arg_scope(resnet_v1.resnet_arg_scope(is_training=False)): with slim.arg_scope(resnet_v1.resnet_arg_scope()):
image = tf.expand_dims(orig_image - mean, 0) image = tf.expand_dims(orig_image - mean, 0)
logits, _ = resnet_v1.resnet_v1_50(image, 1000) logits, _ = resnet_v1.resnet_v1_50(image, 1000, is_training=False)
saliency_map(logits, orig_image, name="saliency") saliency_map(logits, orig_image, name="saliency")
...@@ -103,4 +104,5 @@ if __name__ == '__main__': ...@@ -103,4 +104,5 @@ if __name__ == '__main__':
if len(sys.argv) != 2: if len(sys.argv) != 2:
tp.logger.error("Usage: {} image.jpg".format(sys.argv[0])) tp.logger.error("Usage: {} image.jpg".format(sys.argv[0]))
sys.exit(1) sys.exit(1)
assert get_tf_version_tuple() >= (1, 7), "requires TF >= 1.7"
run("resnet_v1_50.ckpt", sys.argv[1]) run("resnet_v1_50.ckpt", sys.argv[1])
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment