Commit b8349bcf authored by Patrick Wieschollek's avatar Patrick Wieschollek Committed by Yuxin Wu

Printdata (#657)

* a simple generic unction to visualize tensors allowing the copy-past
in other examples

* PrintData was buggy

This fixes the output order of PrintData and supports a constrained
recursion to inspect nested objects. Further, it supports more types
and keeps the output minimal by providing only necessary information
(no shape information for scalars).

* fix doc

* cleanup
parent 68edaa0c
......@@ -45,6 +45,23 @@ def BNLReLU(x, name=None):
return tf.nn.leaky_relu(x, alpha=0.2, name=name)
def visualize_tensors(name, imgs, scale_func=lambda x: (x + 1.) * 128., max_outputs=1):
"""Generate tensor for TensorBoard (casting, clipping)
Args:
name: name for visualization operation
*imgs: multiple tensors as list
scale_func: scale input tensors to fit range [0, 255]
Example:
visualize_tensors('viz1', [img1])
visualize_tensors('viz2', [img1, img2, img3], max_outputs=max(30, BATCH))
"""
xy = scale_func(tf.concat(imgs, axis=2))
xy = tf.cast(tf.clip_by_value(xy, 0, 255), tf.uint8, name='viz')
tf.summary.image(name, xy, max_outputs=30)
class Model(GANModelDesc):
def _get_inputs(self):
SHAPE = 256
......@@ -122,9 +139,8 @@ class Model(GANModelDesc):
if OUT_CH == 1:
output = tf.image.grayscale_to_rgb(output)
fake_output = tf.image.grayscale_to_rgb(fake_output)
viz = (tf.concat([input, output, fake_output], 2) + 1.0) * 128.0
viz = tf.cast(tf.clip_by_value(viz, 0, 255), tf.uint8, name='viz')
tf.summary.image('input,output,fake', viz, max_outputs=max(30, BATCH))
visualize_tensors('input,output,fake', [input, output, fake_output], max_outputs=max(30, BATCH))
self.collect_variables()
......
......@@ -52,7 +52,7 @@ class Model(ModelDesc):
.MaxPooling('pool1', 2)
.Conv2D('conv3')
.FullyConnected('fc0', 512, activation=tf.nn.relu)
.Dropout('dropout', 0.5)
.Dropout('dropout', rate=0.5)
.FullyConnected('fc1', 10, activation=tf.identity)())
tf.nn.softmax(logits, name='prob') # a Bx10 with probabilities
......@@ -97,6 +97,9 @@ class Model(ModelDesc):
def get_data():
train = BatchData(dataset.Mnist('train'), 128)
test = BatchData(dataset.Mnist('test'), 256, remainder=True)
train = PrintData(train)
return train, test
......
......@@ -6,6 +6,7 @@ from __future__ import division
import numpy as np
from copy import copy
import pprint
import itertools
from termcolor import colored
from collections import deque, defaultdict
from six.moves import range, map
......@@ -645,9 +646,9 @@ class PrintData(ProxyDataFlow):
.. code-block:: python
def get_data():
ds = CaffeLMDB('path/to/lmdb')
ds = SomeDataSource('path/to/lmdb')
ds = SomeInscrutableMappings(ds)
ds = PrintData(ds, num=2)
ds = PrintData(ds, num=2, max_list=2)
return ds
ds = get_data()
......@@ -657,23 +658,31 @@ class PrintData(ProxyDataFlow):
[0110 09:22:21 @common.py:589] DataFlow Info:
datapoint 0<2 with 4 components consists of
dp 0: is float of shape () with range [0.0816501893251]
dp 1: is ndarray of shape (64, 64) with range [0.1300, 0.6895]
dp 2: is ndarray of shape (64, 64) with range [-1.2248, 1.2177]
dp 3: is ndarray of shape (9, 9) with range [-0.6045, 0.6045]
0: float with value 0.0816501893251
1: ndarray:int32 of shape (64,) in range [0, 10]
2: ndarray:float32 of shape (64, 64) in range [-1.2248, 1.2177]
3: list of len 50
0: ndarray:int32 of shape (64, 64) in range [-128, 80]
1: ndarray:float32 of shape (64, 64) in range [0.8400, 0.6845]
...
datapoint 1<2 with 4 components consists of
dp 0: is float of shape () with range [5.88252075399]
dp 1: is ndarray of shape (64, 64) with range [0.0072, 0.9371]
dp 2: is ndarray of shape (64, 64) with range [-0.9011, 0.8491]
dp 3: is ndarray of shape (9, 9) with range [-0.5585, 0.5585]
0: float with value 5.88252075399
1: ndarray:int32 of shape (64,) in range [0, 10]
2: ndarray:float32 of shape (64, 64) with range [-0.9011, 0.8491]
3: list of len 50
0: ndarray:int32 of shape (64, 64) in range [-70, 50]
1: ndarray:float32 of shape (64, 64) in range [0.7400, 0.3545]
...
"""
def __init__(self, ds, num=1, label=None, name=None):
def __init__(self, ds, num=1, label=None, name=None, max_depth=3, max_list=3):
"""
Args:
ds(DataFlow): input DataFlow.
num(int): number of dataflow points to print.
name(str, optional): name to identify this DataFlow.
ds (DataFlow): input DataFlow.
num (int): number of dataflow points to print.
name (str, optional): name to identify this DataFlow.
max_depth (int, optional): stop output when too deep recursion in sub elements
max_list (int, optional): stop output when too many sub elements
"""
super(PrintData, self).__init__(ds)
self.num = num
......@@ -684,8 +693,10 @@ class PrintData(ProxyDataFlow):
else:
self.name = name
self.cnt = 0
self.max_depth = max_depth
self.max_list = max_list
def _analyze_input_data(self, entry, k, depth=1):
def _analyze_input_data(self, entry, k, depth=1, max_depth=3, max_list=3):
"""
Gather useful debug information from a datapoint.
......@@ -693,52 +704,72 @@ class PrintData(ProxyDataFlow):
entry: the datapoint component
k (int): index of this compoennt in current datapoint
depth (int, optional): recursion depth
Todo:
* call this recursively and stop when depth>n for some n if an element is a list
max_depth, max_list: same as in :meth:`__init__`.
Returns:
string: debug message
"""
el = entry
if isinstance(el, list):
return "%s is list of %i elements" % (" " * (depth * 2), len(el))
else:
el_type = el.__class__.__name__
class _elementInfo(object):
def __init__(self, el, pos, depth=0, max_list=3):
self.shape = ""
self.type = type(el).__name__
self.dtype = ""
self.range = ""
self.sub_elements = []
self.ident = " " * (depth * 2)
self.pos = pos
numpy_scalar_types = list(itertools.chain(*np.sctypes.values()))
if isinstance(el, (int, float, bool)):
el_max = el_min = el
el_shape = "()"
el_range = el
self.range = " with value {}".format(el)
elif type(el) is np.ndarray:
self.shape = " of shape {}".format(el.shape)
self.dtype = ":{}".format(str(el.dtype))
self.range = " in range [{}, {}]".format(el.min(), el.max())
elif type(el) in numpy_scalar_types:
self.range = " with value {}".format(el)
elif isinstance(el, (list)):
self.shape = " of len {}".format(len(el))
if depth < max_depth:
for k, subel in enumerate(el):
if k < max_list:
self.sub_elements.append(_elementInfo(subel, k, depth + 1, max_list))
else:
self.sub_elements.append(" " * ((depth + 1) * 2) + '...')
break
else:
el_shape = "n.A."
if hasattr(el, 'shape'):
el_shape = el.shape
if len(el) > 0:
self.sub_elements.append(" " * ((depth + 1) * 2) + ' ...')
el_max, el_min = None, None
if hasattr(el, 'max'):
el_max = el.max()
if hasattr(el, 'min'):
el_min = el.min()
def __str__(self):
strings = []
vals = (self.ident, self.pos, self.type, self.dtype, self.shape, self.range)
strings.append("{}{}: {}{}{}{}".format(*vals))
el_range = ("None, None")
if el_max is not None or el_min is not None:
el_range = "%.4f, %.4f" % (el_min, el_max)
for k, el in enumerate(self.sub_elements):
strings.append(str(el))
return "\n".join(strings)
return ("%s dp %i: is %s of shape %s with range [%s]" % (" " * (depth * 2), k, el_type, el_shape, el_range))
return str(_elementInfo(entry, k, depth, max_list))
def _get_msg(self, dp):
msg = [u"datapoint %i<%i with %i components consists of" % (self.cnt, self.num, len(dp))]
for k, entry in enumerate(dp):
msg.append(self._analyze_input_data(entry, k))
msg.append(self._analyze_input_data(entry, k, max_depth=self.max_depth, max_list=self.max_list))
return u'\n'.join(msg)
def get_data(self):
for dp in self.ds.get_data():
# it is important to place this here! otherwise it mixes the output of multiple PrintData
if self.cnt == 0:
label = "" if self.name is None else " (" + self.label + ")"
label = ' (%s)' % self.name if self.name is not None else ""
logger.info(colored("DataFlow Info%s:" % label, 'cyan'))
for dp in self.ds.get_data():
if self.cnt < self.num:
print(self._get_msg(dp))
self.cnt += 1
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment