Commit 085190b6 authored by Yuxin Wu's avatar Yuxin Wu

put image to tensorboard (fix #239)

parent f0017ad5
......@@ -20,7 +20,7 @@ See some [examples](examples) to learn about the framework:
+ [Asynchronous Advantage Actor-Critic(A3C) with demos on OpenAI Gym](examples/A3C-Gym)
### Unsupervised Learning:
+ [Generative Adversarial Network(GAN) variants](examples/GAN), including DCGAN, InfoGAN, Conditional GAN, WGAN, DiscoGAN, Image to Image.
+ [Generative Adversarial Network(GAN) variants](examples/GAN), including DCGAN, InfoGAN, Conditional GAN, WGAN, BEGAN, DiscoGAN, Image to Image.
### Speech / NLP:
+ [LSTM-CTC for speech recognition](examples/CTC-TIMIT)
......
......@@ -41,7 +41,7 @@ for filename, label in filelist:
And `ds1` batch the datapoints from `ds0`, so that we can measure the speed of this DataFlow in terms of "batch per second".
By default, `BatchData`
will stack the datapoints into an `numpy.ndarray`, but since images are original of different shapes, we use
will stack the datapoints into an `numpy.ndarray`, but since original images are of different shapes, we use
`use_list=True` so that it just produces lists.
On an SSD you probably can already observe good speed here (e.g. 5 it/s, that is 1280 samples/s), but on HDD the speed may be just 1 it/s,
......
......@@ -23,7 +23,7 @@ class DumpParamAsImage(Callback):
the input pipeline).
"""
def __init__(self, tensor_name, prefix=None, map_func=None, scale=255, clip=False):
def __init__(self, tensor_name, prefix=None, map_func=None, scale=255):
"""
Args:
tensor_name (str): the name of the tensor.
......@@ -31,7 +31,6 @@ class DumpParamAsImage(Callback):
map_func: map the value of the tensor to an image or list of
images of shape [h, w] or [h, w, c]. If None, will use identity.
scale (float): a multiplier on pixel values, applied after map_func.
clip (bool): whether to clip the result to [0, 255].
"""
op_name, self.tensor_name = get_op_tensor_name(tensor_name)
self.func = map_func
......@@ -41,7 +40,6 @@ class DumpParamAsImage(Callback):
self.prefix = prefix
self.log_dir = logger.LOG_DIR
self.scale = scale
self.clip = clip
def _before_train(self):
# TODO might not work for multiGPU?
......@@ -56,6 +54,7 @@ class DumpParamAsImage(Callback):
self._dump_image(im, idx)
else:
self._dump_image(val)
self.trainer.monitors.put_image(self.prefix, val)
def _dump_image(self, im, idx=None):
assert im.ndim in [2, 3], str(im.ndim)
......@@ -64,6 +63,5 @@ class DumpParamAsImage(Callback):
self.prefix + '-ep{:03d}{}.png'.format(
self.epoch_num, '-' + str(idx) if idx else ''))
res = im * self.scale
if self.clip:
res = np.clip(res, 0, 255)
res = np.clip(res, 0, 255)
cv2.imwrite(fname, res.astype('uint8'))
......@@ -47,10 +47,10 @@ def summary_inferencer(trainer, infs):
for k, v in six.iteritems(ret):
try:
v = float(v)
trainer.monitors.put_scalar(k, v)
except:
logger.warn("{} returns a non-scalar statistics!".format(type(inf).__name__))
continue
trainer.monitors.put(k, v)
@six.add_metaclass(ABCMeta)
......
......@@ -4,6 +4,7 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import os
import numpy as np
import shutil
import operator
from collections import defaultdict
......@@ -13,12 +14,28 @@ import re
import tensorflow as tf
from ..utils import logger
from ..tfutils.summary import create_scalar_summary, create_image_summary
from .base import Callback
__all__ = ['TrainingMonitor', 'Monitors',
'TFSummaryWriter', 'JSONWriter', 'ScalarPrinter', 'SendMonitorData']
def image_to_nhwc(arr):
if arr.ndim == 4:
pass
elif arr.ndim == 3:
if arr.shape[-1] in [1, 3, 4]:
arr = arr[np.newaxis, :]
else:
arr = arr[:, :, :, np.newaxis]
elif arr.ndim == 2:
arr = arr[np.newaxis, :, :, np.newaxis]
else:
raise ValueError("Array of shape {} is not an image!".format(arr.shape))
return arr
class TrainingMonitor(Callback):
"""
Monitor a training progress, by processing different types of
......@@ -48,7 +65,15 @@ class TrainingMonitor(Callback):
pass
def put_scalar(self, name, val):
self.put(name, val)
pass
def put_image(self, name, val):
"""
Args:
val (np.ndarray): 4D (NHWC) numpy array of images.
If channel is 3, assumed to be RGB.
"""
pass
# TODO put other types
......@@ -76,11 +101,16 @@ class Monitors(TrainingMonitor):
for m in self._monitors:
m.put_scalar(name, val)
def _dispatch_put_image(self, name, val):
for m in self._monitors:
m.put_image(name, val)
def put_summary(self, summary):
if isinstance(summary, six.binary_type):
summary = tf.Summary.FromString(summary)
assert isinstance(summary, tf.Summary), type(summary)
# TODO remove -summary suffix for summary
self._dispatch_put_summary(summary)
# TODO other types
......@@ -98,8 +128,20 @@ class Monitors(TrainingMonitor):
def put_scalar(self, name, val):
self._dispatch_put_scalar(name, val)
s = tf.Summary()
s.value.add(tag=name, simple_value=val)
s = create_scalar_summary(name, val)
self._dispatch_put_summary(s)
def put_image(self, name, val):
"""
Args:
name (str):
val (np.ndarray): 2D, 3D (HWC) or 4D (NHWC) numpy array of images.
If channel is 3, assumed to be RGB.
"""
assert isinstance(val, np.ndarray)
arr = image_to_nhwc(val)
self._dispatch_put_image(name, arr)
s = create_image_summary(name, arr)
self._dispatch_put_summary(s)
def get_latest(self, name):
......
......@@ -4,7 +4,9 @@
import six
import tensorflow as tf
import cv2
import re
from six.moves import range
from ..utils import logger
from ..utils.develop import log_deprecated
......@@ -29,6 +31,33 @@ def create_scalar_summary(name, v):
return s
def create_image_summary(name, val):
"""
Args:
name(str):
val(np.ndarray): 4D tensor of NHWC
Returns:
tf.Summary:
"""
assert isinstance(name, six.string_types), type(name)
n, h, w, c = val.shape
s = tf.Summary()
for k in range(n):
tag = name if n == 1 else '{}/{}'.format(name, k)
ret, buf = cv2.imencode('.png', val[k])
assert ret, "imencode failed!"
img = tf.Summary.Image()
img.height = h
img.width = w
# 1 - grayscale 3 - RGB 4 - RGBA
img.colorspace = c
img.encoded_image_string = buf.tostring()
s.value.add(tag=tag, image=img)
return s
def add_activation_summary(x, name=None):
"""
Add summary for an activation tensor x. If name is None, use x.name.
......
......@@ -87,7 +87,7 @@ def get_dataset_path(*args):
old_d = os.path.abspath(os.path.join(
os.path.dirname(__file__), '..', 'dataflow', 'dataset'))
old_d_ret = os.path.join(old_d, *args)
new_d = os.path.expanduser('~/tensorpack_data')
new_d = os.path.join(os.path.expanduser('~'), 'tensorpack_data')
if os.path.isdir(old_d_ret):
# there is an old dir containing data, use it for back-compat
logger.warn("You seem to have old data at {}. This is no longer \
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment