Commit 16216c6a authored by Yuxin Wu's avatar Yuxin Wu

don't use jet as default.

parent 58b571e1
...@@ -11,9 +11,6 @@ import argparse ...@@ -11,9 +11,6 @@ import argparse
from tensorpack import * from tensorpack import *
from tensorpack.utils.argtools import memoized from tensorpack.utils.argtools import memoized
import matplotlib.pyplot as plt
_CM = plt.get_cmap('jet')
""" """
15 channels: 15 channels:
0-1 head, neck 0-1 head, neck
...@@ -29,7 +26,7 @@ def colorize(img, heatmap): ...@@ -29,7 +26,7 @@ def colorize(img, heatmap):
""" img: bgr, [0,255] """ img: bgr, [0,255]
heatmap: [0,1] heatmap: [0,1]
""" """
heatmap = _CM(heatmap)[:, :, [2, 1, 0]] * 255.0 heatmap = viz.intensity_to_rgb(heatmap, cmap='jet')[:, :, ::-1]
return img * 0.5 + heatmap * 0.5 return img * 0.5 + heatmap * 0.5
......
...@@ -9,7 +9,7 @@ produce the sum of them. ...@@ -9,7 +9,7 @@ produce the sum of them.
Here the two Spatial Transformer branches learn to localize the two digits Here the two Spatial Transformer branches learn to localize the two digits
and warped them separately. and warped them separately.
![demo](demo.jpg) <p align="center"> <img src="./demo.jpg" width="400"> </p>
Left: input image; Middle: output of the first STN branch (which localizes the second digit); Right: output of the second STN branch. Left: input image; Middle: output of the first STN branch (which localizes the second digit); Right: output of the second STN branch.
......
...@@ -2,13 +2,13 @@ ...@@ -2,13 +2,13 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
import cv2 import cv2
import cPickle as pickle
import sys import sys
import os import os
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
import tensorflow.contrib.slim as slim import tensorflow.contrib.slim as slim
from tensorflow.contrib.slim.nets import resnet_v1 from tensorflow.contrib.slim.nets import resnet_v1
import tensorpack as tp import tensorpack as tp
import tensorpack.utils.viz as viz import tensorpack.utils.viz as viz
......
...@@ -102,8 +102,11 @@ class Trainer(object): ...@@ -102,8 +102,11 @@ class Trainer(object):
add scalar summary to ``self.stat_holder``. add scalar summary to ``self.stat_holder``.
Args: Args:
summary (tf.Summary): a summary object. summary (tf.Summary or str): a summary object, or a str which will
be interpreted as a serialized tf.Summary protobuf.
""" """
if isinstance(summary, six.string_types):
summary = tf.Summary.FromString(summary)
for val in summary.value: for val in summary.value:
if val.WhichOneof('value') == 'simple_value': if val.WhichOneof('value') == 'simple_value':
val.tag = re.sub('tower[p0-9]+/', '', val.tag) # TODO move to subclasses val.tag = re.sub('tower[p0-9]+/', '', val.tag) # TODO move to subclasses
...@@ -116,10 +119,6 @@ class Trainer(object): ...@@ -116,10 +119,6 @@ class Trainer(object):
def add_scalar_summary(self, name, val): def add_scalar_summary(self, name, val):
""" """
Add a scalar sumary to both TF events file and StatHolder. Add a scalar sumary to both TF events file and StatHolder.
Args:
name(str)
val(float)
""" """
self.add_summary(create_scalar_summary(name, val)) self.add_summary(create_scalar_summary(name, val))
......
...@@ -29,7 +29,7 @@ class FeedfreeTrainerBase(Trainer): ...@@ -29,7 +29,7 @@ class FeedfreeTrainerBase(Trainer):
# note that summary_op will take a data from the queue # note that summary_op will take a data from the queue
if self.summary_op is not None: if self.summary_op is not None:
summary_str = self.summary_op.eval() summary_str = self.summary_op.eval()
self.add_summary(tf.Summary.FromString(summary_str)) self.add_summary(summary_str)
def _get_input_tensors(self): def _get_input_tensors(self):
return self._input_method.get_input_tensors() return self._input_method.get_input_tensors()
......
...@@ -95,7 +95,7 @@ class SimpleTrainer(Trainer): ...@@ -95,7 +95,7 @@ class SimpleTrainer(Trainer):
if self.summary_op is not None: if self.summary_op is not None:
feed = self._input_method.next_feed() feed = self._input_method.next_feed()
summary_str = self.summary_op.eval(feed_dict=feed) summary_str = self.summary_op.eval(feed_dict=feed)
self.add_summary(tf.Summary.FromString(summary_str)) self.add_summary(summary_str)
def get_predict_func(self, input_names, output_names): def get_predict_func(self, input_names, output_names):
return self._predictor_factory.get_predictor(input_names, output_names, 0) return self._predictor_factory.get_predictor(input_names, output_names, 0)
......
...@@ -243,7 +243,7 @@ def dump_dataflow_images(df, index=0, batched=True, ...@@ -243,7 +243,7 @@ def dump_dataflow_images(df, index=0, batched=True,
vizlist = vizlist[vizsize:] vizlist = vizlist[vizsize:]
def intensity_to_rgb(intensity, cmap='jet', normalize=False): def intensity_to_rgb(intensity, cmap='cubehelix', normalize=False):
""" """
Convert a 1-channel matrix of intensities to an RGB image employing a colormap. Convert a 1-channel matrix of intensities to an RGB image employing a colormap.
This function requires matplotlib. See `matplotlib colormaps This function requires matplotlib. See `matplotlib colormaps
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment