Commit 16216c6a authored by Yuxin Wu's avatar Yuxin Wu

don't use jet as default.

parent 58b571e1
......@@ -11,9 +11,6 @@ import argparse
from tensorpack import *
from tensorpack.utils.argtools import memoized
import matplotlib.pyplot as plt
_CM = plt.get_cmap('jet')
"""
15 channels:
0-1 head, neck
......@@ -29,7 +26,7 @@ def colorize(img, heatmap):
""" img: bgr, [0,255]
heatmap: [0,1]
"""
heatmap = _CM(heatmap)[:, :, [2, 1, 0]] * 255.0
heatmap = viz.intensity_to_rgb(heatmap, cmap='jet')[:, :, ::-1]
return img * 0.5 + heatmap * 0.5
......
......@@ -9,7 +9,7 @@ produce the sum of them.
Here the two Spatial Transformer branches learn to localize the two digits
and warped them separately.
![demo](demo.jpg)
<p align="center"> <img src="./demo.jpg" width="400"> </p>
Left: input image; Middle: output of the first STN branch (which localizes the second digit); Right: output of the second STN branch.
......
......@@ -2,13 +2,13 @@
# -*- coding: utf-8 -*-
import cv2
import cPickle as pickle
import sys
import os
import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.contrib.slim.nets import resnet_v1
import tensorpack as tp
import tensorpack.utils.viz as viz
......
......@@ -102,8 +102,11 @@ class Trainer(object):
add scalar summary to ``self.stat_holder``.
Args:
summary (tf.Summary): a summary object.
summary (tf.Summary or str): a summary object, or a str which will
be interpreted as a serialized tf.Summary protobuf.
"""
if isinstance(summary, six.string_types):
summary = tf.Summary.FromString(summary)
for val in summary.value:
if val.WhichOneof('value') == 'simple_value':
val.tag = re.sub('tower[p0-9]+/', '', val.tag) # TODO move to subclasses
......@@ -116,10 +119,6 @@ class Trainer(object):
def add_scalar_summary(self, name, val):
"""
Add a scalar sumary to both TF events file and StatHolder.
Args:
name(str)
val(float)
"""
self.add_summary(create_scalar_summary(name, val))
......
......@@ -29,7 +29,7 @@ class FeedfreeTrainerBase(Trainer):
# note that summary_op will take a data from the queue
if self.summary_op is not None:
summary_str = self.summary_op.eval()
self.add_summary(tf.Summary.FromString(summary_str))
self.add_summary(summary_str)
def _get_input_tensors(self):
return self._input_method.get_input_tensors()
......
......@@ -95,7 +95,7 @@ class SimpleTrainer(Trainer):
if self.summary_op is not None:
feed = self._input_method.next_feed()
summary_str = self.summary_op.eval(feed_dict=feed)
self.add_summary(tf.Summary.FromString(summary_str))
self.add_summary(summary_str)
def get_predict_func(self, input_names, output_names):
return self._predictor_factory.get_predictor(input_names, output_names, 0)
......
......@@ -243,7 +243,7 @@ def dump_dataflow_images(df, index=0, batched=True,
vizlist = vizlist[vizsize:]
def intensity_to_rgb(intensity, cmap='jet', normalize=False):
def intensity_to_rgb(intensity, cmap='cubehelix', normalize=False):
"""
Convert a 1-channel matrix of intensities to an RGB image employing a colormap.
This function requires matplotlib. See `matplotlib colormaps
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment