Commit b8349bcf authored by Patrick Wieschollek, committed by Yuxin Wu

Printdata (#657)

* a simple generic function to visualize tensors, allowing copy-paste
into other examples

* PrintData was buggy

This fixes the output order of PrintData and supports a constrained
recursion to inspect nested objects. Further, it supports more types
and keeps the output minimal by providing only necessary information
(no shape information for scalars).

* fix doc

* cleanup
parent 68edaa0c
...@@ -45,6 +45,23 @@ def BNLReLU(x, name=None): ...@@ -45,6 +45,23 @@ def BNLReLU(x, name=None):
return tf.nn.leaky_relu(x, alpha=0.2, name=name) return tf.nn.leaky_relu(x, alpha=0.2, name=name)
def visualize_tensors(name, imgs, scale_func=lambda x: (x + 1.) * 128., max_outputs=1):
    """Generate an image summary for TensorBoard (scaling, clipping, casting).

    The tensors in ``imgs`` are concatenated along axis 2, passed through
    ``scale_func``, clipped to [0, 255], cast to uint8 and registered via
    ``tf.summary.image``.

    Args:
        name: name for the visualization (summary) operation
        imgs: list of tensors to concatenate along axis 2
        scale_func: scales the concatenated tensor to fit the range [0, 255]
        max_outputs: maximum number of batch elements to emit as images

    Example:
        visualize_tensors('viz1', [img1])
        visualize_tensors('viz2', [img1, img2, img3], max_outputs=max(30, BATCH))
    """
    xy = scale_func(tf.concat(imgs, axis=2))
    xy = tf.cast(tf.clip_by_value(xy, 0, 255), tf.uint8, name='viz')
    # Forward the caller's limit; it used to be hard-coded to 30, which
    # silently ignored the ``max_outputs`` argument.
    tf.summary.image(name, xy, max_outputs=max_outputs)
class Model(GANModelDesc): class Model(GANModelDesc):
def _get_inputs(self): def _get_inputs(self):
SHAPE = 256 SHAPE = 256
...@@ -122,9 +139,8 @@ class Model(GANModelDesc): ...@@ -122,9 +139,8 @@ class Model(GANModelDesc):
if OUT_CH == 1: if OUT_CH == 1:
output = tf.image.grayscale_to_rgb(output) output = tf.image.grayscale_to_rgb(output)
fake_output = tf.image.grayscale_to_rgb(fake_output) fake_output = tf.image.grayscale_to_rgb(fake_output)
viz = (tf.concat([input, output, fake_output], 2) + 1.0) * 128.0
viz = tf.cast(tf.clip_by_value(viz, 0, 255), tf.uint8, name='viz') visualize_tensors('input,output,fake', [input, output, fake_output], max_outputs=max(30, BATCH))
tf.summary.image('input,output,fake', viz, max_outputs=max(30, BATCH))
self.collect_variables() self.collect_variables()
......
...@@ -52,7 +52,7 @@ class Model(ModelDesc): ...@@ -52,7 +52,7 @@ class Model(ModelDesc):
.MaxPooling('pool1', 2) .MaxPooling('pool1', 2)
.Conv2D('conv3') .Conv2D('conv3')
.FullyConnected('fc0', 512, activation=tf.nn.relu) .FullyConnected('fc0', 512, activation=tf.nn.relu)
.Dropout('dropout', 0.5) .Dropout('dropout', rate=0.5)
.FullyConnected('fc1', 10, activation=tf.identity)()) .FullyConnected('fc1', 10, activation=tf.identity)())
tf.nn.softmax(logits, name='prob') # a Bx10 with probabilities tf.nn.softmax(logits, name='prob') # a Bx10 with probabilities
...@@ -97,6 +97,9 @@ class Model(ModelDesc): ...@@ -97,6 +97,9 @@ class Model(ModelDesc):
def get_data(): def get_data():
train = BatchData(dataset.Mnist('train'), 128) train = BatchData(dataset.Mnist('train'), 128)
test = BatchData(dataset.Mnist('test'), 256, remainder=True) test = BatchData(dataset.Mnist('test'), 256, remainder=True)
train = PrintData(train)
return train, test return train, test
......
...@@ -6,6 +6,7 @@ from __future__ import division ...@@ -6,6 +6,7 @@ from __future__ import division
import numpy as np import numpy as np
from copy import copy from copy import copy
import pprint import pprint
import itertools
from termcolor import colored from termcolor import colored
from collections import deque, defaultdict from collections import deque, defaultdict
from six.moves import range, map from six.moves import range, map
...@@ -645,9 +646,9 @@ class PrintData(ProxyDataFlow): ...@@ -645,9 +646,9 @@ class PrintData(ProxyDataFlow):
.. code-block:: python .. code-block:: python
def get_data(): def get_data():
ds = CaffeLMDB('path/to/lmdb') ds = SomeDataSource('path/to/lmdb')
ds = SomeInscrutableMappings(ds) ds = SomeInscrutableMappings(ds)
ds = PrintData(ds, num=2) ds = PrintData(ds, num=2, max_list=2)
return ds return ds
ds = get_data() ds = get_data()
...@@ -657,23 +658,31 @@ class PrintData(ProxyDataFlow): ...@@ -657,23 +658,31 @@ class PrintData(ProxyDataFlow):
[0110 09:22:21 @common.py:589] DataFlow Info: [0110 09:22:21 @common.py:589] DataFlow Info:
datapoint 0<2 with 4 components consists of datapoint 0<2 with 4 components consists of
dp 0: is float of shape () with range [0.0816501893251] 0: float with value 0.0816501893251
dp 1: is ndarray of shape (64, 64) with range [0.1300, 0.6895] 1: ndarray:int32 of shape (64,) in range [0, 10]
dp 2: is ndarray of shape (64, 64) with range [-1.2248, 1.2177] 2: ndarray:float32 of shape (64, 64) in range [-1.2248, 1.2177]
dp 3: is ndarray of shape (9, 9) with range [-0.6045, 0.6045] 3: list of len 50
0: ndarray:int32 of shape (64, 64) in range [-128, 80]
1: ndarray:float32 of shape (64, 64) in range [0.6845, 0.8400]
...
datapoint 1<2 with 4 components consists of datapoint 1<2 with 4 components consists of
dp 0: is float of shape () with range [5.88252075399] 0: float with value 5.88252075399
dp 1: is ndarray of shape (64, 64) with range [0.0072, 0.9371] 1: ndarray:int32 of shape (64,) in range [0, 10]
dp 2: is ndarray of shape (64, 64) with range [-0.9011, 0.8491] 2: ndarray:float32 of shape (64, 64) in range [-0.9011, 0.8491]
dp 3: is ndarray of shape (9, 9) with range [-0.5585, 0.5585] 3: list of len 50
0: ndarray:int32 of shape (64, 64) in range [-70, 50]
1: ndarray:float32 of shape (64, 64) in range [0.3545, 0.7400]
...
""" """
def __init__(self, ds, num=1, label=None, name=None): def __init__(self, ds, num=1, label=None, name=None, max_depth=3, max_list=3):
""" """
Args: Args:
ds(DataFlow): input DataFlow. ds (DataFlow): input DataFlow.
num(int): number of dataflow points to print. num (int): number of dataflow points to print.
name(str, optional): name to identify this DataFlow. name (str, optional): name to identify this DataFlow.
max_depth (int, optional): stop output when too deep recursion in sub elements
max_list (int, optional): stop output when too many sub elements
""" """
super(PrintData, self).__init__(ds) super(PrintData, self).__init__(ds)
self.num = num self.num = num
...@@ -684,8 +693,10 @@ class PrintData(ProxyDataFlow): ...@@ -684,8 +693,10 @@ class PrintData(ProxyDataFlow):
else: else:
self.name = name self.name = name
self.cnt = 0 self.cnt = 0
self.max_depth = max_depth
self.max_list = max_list
def _analyze_input_data(self, entry, k, depth=1): def _analyze_input_data(self, entry, k, depth=1, max_depth=3, max_list=3):
""" """
Gather useful debug information from a datapoint. Gather useful debug information from a datapoint.
...@@ -693,52 +704,72 @@ class PrintData(ProxyDataFlow): ...@@ -693,52 +704,72 @@ class PrintData(ProxyDataFlow):
entry: the datapoint component entry: the datapoint component
k (int): index of this component in current datapoint k (int): index of this component in current datapoint
depth (int, optional): recursion depth depth (int, optional): recursion depth
max_depth, max_list: same as in :meth:`__init__`.
Todo:
* call this recursively and stop when depth>n for some n if an element is a list
Returns: Returns:
string: debug message string: debug message
""" """
el = entry
if isinstance(el, list): class _elementInfo(object):
return "%s is list of %i elements" % (" " * (depth * 2), len(el)) def __init__(self, el, pos, depth=0, max_list=3):
else: self.shape = ""
el_type = el.__class__.__name__ self.type = type(el).__name__
self.dtype = ""
self.range = ""
self.sub_elements = []
self.ident = " " * (depth * 2)
self.pos = pos
numpy_scalar_types = list(itertools.chain(*np.sctypes.values()))
if isinstance(el, (int, float, bool)): if isinstance(el, (int, float, bool)):
el_max = el_min = el self.range = " with value {}".format(el)
el_shape = "()" elif type(el) is np.ndarray:
el_range = el self.shape = " of shape {}".format(el.shape)
self.dtype = ":{}".format(str(el.dtype))
self.range = " in range [{}, {}]".format(el.min(), el.max())
elif type(el) in numpy_scalar_types:
self.range = " with value {}".format(el)
elif isinstance(el, (list)):
self.shape = " of len {}".format(len(el))
if depth < max_depth:
for k, subel in enumerate(el):
if k < max_list:
self.sub_elements.append(_elementInfo(subel, k, depth + 1, max_list))
else:
self.sub_elements.append(" " * ((depth + 1) * 2) + '...')
break
else: else:
el_shape = "n.A." if len(el) > 0:
if hasattr(el, 'shape'): self.sub_elements.append(" " * ((depth + 1) * 2) + ' ...')
el_shape = el.shape
el_max, el_min = None, None def __str__(self):
if hasattr(el, 'max'): strings = []
el_max = el.max() vals = (self.ident, self.pos, self.type, self.dtype, self.shape, self.range)
if hasattr(el, 'min'): strings.append("{}{}: {}{}{}{}".format(*vals))
el_min = el.min()
el_range = ("None, None") for k, el in enumerate(self.sub_elements):
if el_max is not None or el_min is not None: strings.append(str(el))
el_range = "%.4f, %.4f" % (el_min, el_max) return "\n".join(strings)
return ("%s dp %i: is %s of shape %s with range [%s]" % (" " * (depth * 2), k, el_type, el_shape, el_range)) return str(_elementInfo(entry, k, depth, max_list))
def _get_msg(self, dp): def _get_msg(self, dp):
msg = [u"datapoint %i<%i with %i components consists of" % (self.cnt, self.num, len(dp))] msg = [u"datapoint %i<%i with %i components consists of" % (self.cnt, self.num, len(dp))]
for k, entry in enumerate(dp): for k, entry in enumerate(dp):
msg.append(self._analyze_input_data(entry, k)) msg.append(self._analyze_input_data(entry, k, max_depth=self.max_depth, max_list=self.max_list))
return u'\n'.join(msg) return u'\n'.join(msg)
def get_data(self): def get_data(self):
for dp in self.ds.get_data():
# it is important to place this here! otherwise it mixes the output of multiple PrintData
if self.cnt == 0: if self.cnt == 0:
label = "" if self.name is None else " (" + self.label + ")" label = ' (%s)' % self.name if self.name is not None else ""
logger.info(colored("DataFlow Info%s:" % label, 'cyan')) logger.info(colored("DataFlow Info%s:" % label, 'cyan'))
for dp in self.ds.get_data():
if self.cnt < self.num: if self.cnt < self.num:
print(self._get_msg(dp)) print(self._get_msg(dp))
self.cnt += 1 self.cnt += 1
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment