Commit 085190b6 authored by Yuxin Wu

put image to tensorboard (fix #239)

parent f0017ad5
@@ -20,7 +20,7 @@ See some [examples](examples) to learn about the framework:
 + [Asynchronous Advantage Actor-Critic(A3C) with demos on OpenAI Gym](examples/A3C-Gym)
 ### Unsupervised Learning:
-+ [Generative Adversarial Network(GAN) variants](examples/GAN), including DCGAN, InfoGAN, Conditional GAN, WGAN, DiscoGAN, Image to Image.
++ [Generative Adversarial Network(GAN) variants](examples/GAN), including DCGAN, InfoGAN, Conditional GAN, WGAN, BEGAN, DiscoGAN, Image to Image.
 ### Speech / NLP:
 + [LSTM-CTC for speech recognition](examples/CTC-TIMIT)
@@ -41,7 +41,7 @@ for filename, label in filelist:
 And `ds1` batch the datapoints from `ds0`, so that we can measure the speed of this DataFlow in terms of "batch per second".
 By default, `BatchData`
-will stack the datapoints into an `numpy.ndarray`, but since images are original of different shapes, we use
+will stack the datapoints into an `numpy.ndarray`, but since original images are of different shapes, we use
 `use_list=True` so that it just produces lists.
 On an SSD you probably can already observe good speed here (e.g. 5 it/s, that is 1280 samples/s), but on HDD the speed may be just 1 it/s,
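For readers following the tutorial text above, here is a minimal sketch of the pipeline it describes, using the DataFlow API of this era (`reset_state()` / `get_data()`); the file list, the decoding step, and the batch size are placeholders, not part of this commit:

```python
import time
import cv2
from tensorpack.dataflow import DataFromList, MapData, BatchData

# Placeholder (filename, label) pairs standing in for the real file list.
filelist = [('cat.jpg', 0), ('dog.jpg', 1)] * 1000

ds0 = DataFromList(filelist, shuffle=False)
# Read each image from disk; decoded images generally have different shapes.
ds0 = MapData(ds0, lambda dp: [cv2.imread(dp[0]), dp[1]])
# use_list=True keeps each batch component as a python list instead of trying
# to stack differently-shaped images into one numpy.ndarray.
ds1 = BatchData(ds0, 256, use_list=True)

ds1.reset_state()
start = time.time()
for i, batch in enumerate(ds1.get_data()):
    if i == 4:      # time the first 5 batches
        break
print('~{:.2f} batch/s'.format(5.0 / (time.time() - start)))
```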
@@ -23,7 +23,7 @@ class DumpParamAsImage(Callback):
         the input pipeline).
     """

-    def __init__(self, tensor_name, prefix=None, map_func=None, scale=255, clip=False):
+    def __init__(self, tensor_name, prefix=None, map_func=None, scale=255):
         """
         Args:
             tensor_name (str): the name of the tensor.
@@ -31,7 +31,6 @@ class DumpParamAsImage(Callback):
             map_func: map the value of the tensor to an image or list of
                 images of shape [h, w] or [h, w, c]. If None, will use identity.
             scale (float): a multiplier on pixel values, applied after map_func.
-            clip (bool): whether to clip the result to [0, 255].
         """
         op_name, self.tensor_name = get_op_tensor_name(tensor_name)
         self.func = map_func
@@ -41,7 +40,6 @@ class DumpParamAsImage(Callback):
             self.prefix = prefix
         self.log_dir = logger.LOG_DIR
         self.scale = scale
-        self.clip = clip

     def _before_train(self):
         # TODO might not work for multiGPU?
@@ -56,6 +54,7 @@ class DumpParamAsImage(Callback):
                 self._dump_image(im, idx)
         else:
             self._dump_image(val)
+        self.trainer.monitors.put_image(self.prefix, val)

     def _dump_image(self, im, idx=None):
         assert im.ndim in [2, 3], str(im.ndim)
@@ -64,6 +63,5 @@ class DumpParamAsImage(Callback):
             self.prefix + '-ep{:03d}{}.png'.format(
                 self.epoch_num, '-' + str(idx) if idx else ''))
         res = im * self.scale
-        if self.clip:
-            res = np.clip(res, 0, 255)
+        res = np.clip(res, 0, 255)
         cv2.imwrite(fname, res.astype('uint8'))
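For orientation, a hedged usage sketch of this callback; the tensor name and value range are hypothetical. After this commit the dumped image is always clipped to [0, 255] and is also forwarded to the monitors via `put_image`, so it reaches TensorBoard as well:

```python
from tensorpack.callbacks import DumpParamAsImage

# Hypothetical: a tensor named 'gen/viz' holding one [H, W, 3] image with float
# values in [0, 1].  scale=255 maps it to pixel range; clipping to [0, 255] is
# now unconditional, so the removed clip= argument is no longer passed.
dump_viz = DumpParamAsImage('gen/viz', prefix='viz', scale=255)

# Typically added to the callbacks list of a training config; after every epoch
# it writes viz-epXXX.png under logger.LOG_DIR and calls
# trainer.monitors.put_image('viz', value).
```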
@@ -47,10 +47,10 @@ def summary_inferencer(trainer, infs):
         for k, v in six.iteritems(ret):
             try:
                 v = float(v)
+                trainer.monitors.put_scalar(k, v)
             except:
                 logger.warn("{} returns a non-scalar statistics!".format(type(inf).__name__))
                 continue
-            trainer.monitors.put(k, v)


 @six.add_metaclass(ABCMeta)
@@ -4,6 +4,7 @@
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>

 import os
+import numpy as np
 import shutil
 import operator
 from collections import defaultdict
@@ -13,12 +14,28 @@ import re
 import tensorflow as tf

 from ..utils import logger
+from ..tfutils.summary import create_scalar_summary, create_image_summary
 from .base import Callback

 __all__ = ['TrainingMonitor', 'Monitors',
            'TFSummaryWriter', 'JSONWriter', 'ScalarPrinter', 'SendMonitorData']


+def image_to_nhwc(arr):
+    if arr.ndim == 4:
+        pass
+    elif arr.ndim == 3:
+        if arr.shape[-1] in [1, 3, 4]:
+            arr = arr[np.newaxis, :]
+        else:
+            arr = arr[:, :, :, np.newaxis]
+    elif arr.ndim == 2:
+        arr = arr[np.newaxis, :, :, np.newaxis]
+    else:
+        raise ValueError("Array of shape {} is not an image!".format(arr.shape))
+    return arr
+
+
 class TrainingMonitor(Callback):
     """
     Monitor a training progress, by processing different types of
@@ -48,7 +65,15 @@ class TrainingMonitor(Callback):
         pass

     def put_scalar(self, name, val):
-        self.put(name, val)
+        pass
+
+    def put_image(self, name, val):
+        """
+        Args:
+            val (np.ndarray): 4D (NHWC) numpy array of images.
+                If channel is 3, assumed to be RGB.
+        """
+        pass

     # TODO put other types
@@ -76,11 +101,16 @@ class Monitors(TrainingMonitor):
         for m in self._monitors:
             m.put_scalar(name, val)

+    def _dispatch_put_image(self, name, val):
+        for m in self._monitors:
+            m.put_image(name, val)
+
     def put_summary(self, summary):
         if isinstance(summary, six.binary_type):
             summary = tf.Summary.FromString(summary)
         assert isinstance(summary, tf.Summary), type(summary)
+        # TODO remove -summary suffix for summary
         self._dispatch_put_summary(summary)

     # TODO other types
@@ -98,8 +128,20 @@ class Monitors(TrainingMonitor):
     def put_scalar(self, name, val):
         self._dispatch_put_scalar(name, val)
-        s = tf.Summary()
-        s.value.add(tag=name, simple_value=val)
+        s = create_scalar_summary(name, val)
+        self._dispatch_put_summary(s)
+
+    def put_image(self, name, val):
+        """
+        Args:
+            name (str):
+            val (np.ndarray): 2D, 3D (HWC) or 4D (NHWC) numpy array of images.
+                If channel is 3, assumed to be RGB.
+        """
+        assert isinstance(val, np.ndarray)
+        arr = image_to_nhwc(val)
+        self._dispatch_put_image(name, arr)
+        s = create_image_summary(name, arr)
         self._dispatch_put_summary(s)

     def get_latest(self, name):
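To make the shape handling concrete, here is a small sketch of what the new `image_to_nhwc` helper and `Monitors.put_image` accept. The import path is an assumption based on the imports in this hunk (the callbacks monitor module), and the arrays are dummies:

```python
import numpy as np

# Assumed module path for the helper added above.
from tensorpack.callbacks.monitor import image_to_nhwc

examples = {
    'HW   (grayscale image)':  np.zeros((32, 32), dtype='uint8'),
    'HWC  (one RGB image)':    np.zeros((32, 32, 3), dtype='uint8'),
    'NHW  (batch, 1 channel)': np.zeros((8, 32, 32), dtype='uint8'),
    'NHWC (batch of RGB)':     np.zeros((8, 32, 32, 3), dtype='uint8'),
}
for desc, arr in examples.items():
    print('{}: {} -> {}'.format(desc, arr.shape, image_to_nhwc(arr).shape))
# Prints (1, 32, 32, 1), (1, 32, 32, 3), (8, 32, 32, 1), (8, 32, 32, 3) respectively.

# Inside a callback, the same normalization happens automatically:
#     self.trainer.monitors.put_image('viz', some_array)
```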
@@ -4,7 +4,9 @@
 import six
 import tensorflow as tf
+import cv2
 import re
+from six.moves import range

 from ..utils import logger
 from ..utils.develop import log_deprecated
@@ -29,6 +31,33 @@ def create_scalar_summary(name, v):
     return s


+def create_image_summary(name, val):
+    """
+    Args:
+        name(str):
+        val(np.ndarray): 4D tensor of NHWC
+
+    Returns:
+        tf.Summary:
+    """
+    assert isinstance(name, six.string_types), type(name)
+    n, h, w, c = val.shape
+    s = tf.Summary()
+    for k in range(n):
+        tag = name if n == 1 else '{}/{}'.format(name, k)
+        ret, buf = cv2.imencode('.png', val[k])
+        assert ret, "imencode failed!"
+        img = tf.Summary.Image()
+        img.height = h
+        img.width = w
+        # 1 - grayscale 3 - RGB 4 - RGBA
+        img.colorspace = c
+        img.encoded_image_string = buf.tostring()
+        s.value.add(tag=tag, image=img)
+    return s
+
+
 def add_activation_summary(x, name=None):
     """
     Add summary for an activation tensor x. If name is None, use x.name.
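As a rough usage sketch (not part of this commit): the `tf.Summary` returned by `create_image_summary` can be written to an event file with a plain TF 1.x `FileWriter`, which is what makes the encoded images show up in TensorBoard. The log directory, image data, and step value are placeholders:

```python
import numpy as np
import tensorflow as tf
from tensorpack.tfutils.summary import create_image_summary

# A dummy batch of 4 RGB images in NHWC layout, uint8 as cv2.imencode expects.
imgs = (np.random.rand(4, 64, 64, 3) * 255).astype('uint8')

# One tf.Summary containing 4 PNG-encoded images tagged 'samples/0' .. 'samples/3'.
summ = create_image_summary('samples', imgs)

writer = tf.summary.FileWriter('/tmp/image-summary-demo')   # placeholder log dir
writer.add_summary(summ, global_step=0)
writer.close()
```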
@@ -87,7 +87,7 @@ def get_dataset_path(*args):
         old_d = os.path.abspath(os.path.join(
             os.path.dirname(__file__), '..', 'dataflow', 'dataset'))
         old_d_ret = os.path.join(old_d, *args)
-        new_d = os.path.expanduser('~/tensorpack_data')
+        new_d = os.path.join(os.path.expanduser('~'), 'tensorpack_data')
         if os.path.isdir(old_d_ret):
             # there is an old dir containing data, use it for back-compat
             logger.warn("You seem to have old data at {}. This is no longer \