Commit ff0f2cf7 authored by Patrick Wieschollek's avatar Patrick Wieschollek Committed by Yuxin Wu

More verbose information about dataflow (#90)

Add PrintData. To verify the shapes of incoming data from the dataflow, it is good
to actually output them for debugging reasons.
parent 827791cb
......@@ -4,6 +4,7 @@
from __future__ import division
import numpy as np
from termcolor import colored
from collections import deque, defaultdict
from six.moves import range, map
from .base import DataFlow, ProxyDataFlow, RNGDataFlow
......@@ -12,7 +13,7 @@ from ..utils import logger, get_tqdm
__all__ = ['TestDataSpeed', 'BatchData', 'BatchDataByShape', 'FixedSizeData', 'MapData',
'MapDataComponent', 'RepeatedData', 'RandomChooseData',
'RandomMixData', 'JoinData', 'ConcatData', 'SelectComponent',
'LocallyShuffleData']
'LocallyShuffleData', 'PrintData']
class TestDataSpeed(ProxyDataFlow):
......@@ -471,3 +472,130 @@ class LocallyShuffleData(ProxyDataFlow, RNGDataFlow):
for v in self.q:
yield v
return
class PrintData(ProxyDataFlow):
    """
    Behave like an identity mapping but print shapes of produced datapoints once during construction.

    Attributes:
        label (str): label to identify the data when using this debugging on multiple places
        num (int): number of datapoints to inspect and print

    Example:
        To enable this debugging output, you should place it somewhere in your dataflow like

        .. code-block:: python

            def get_data():
                ds = CaffeLMDB('path/to/lmdb')
                ds = SomeInscrutableMappings(ds)
                ds = PrintData(ds, num=2)
                return ds
            ds = get_data()

        The output looks like:

        .. code-block:: none

            [0110 09:22:21 @common.py:589] DataFlow Info:
            datapoint 0<2 with 4 elements consists of
               dp 0: is float of shape () with range [0.0816501893251]
               dp 1: is ndarray of shape (64, 64) with range [0.1300, 0.6895]
               dp 2: is ndarray of shape (64, 64) with range [-1.2248, 1.2177]
               dp 3: is ndarray of shape (9, 9) with range [-0.6045, 0.6045]
            datapoint 1<2 with 4 elements consists of
               dp 0: is float of shape () with range [5.88252075399]
               dp 1: is ndarray of shape (64, 64) with range [0.0072, 0.9371]
               dp 2: is ndarray of shape (64, 64) with range [-0.9011, 0.8491]
               dp 3: is ndarray of shape (9, 9) with range [-0.5585, 0.5585]
    """

    def __init__(self, ds, num=1, label=""):
        """
        Args:
            ds (DataFlow): input DataFlow.
            num (int): number of datapoints to inspect and print.
            label (str, optional): label to identify this call, when using multiple times
        """
        super(PrintData, self).__init__(ds)
        self.num = num
        self.label = label
        # The report is produced eagerly, once, while the dataflow is being built.
        self.print_info()

    def analyze_input_data(self, el, k, depth=1):
        """
        Gather useful debug information from one component of a datapoint.

        Args:
            el: component to describe; scalars, lists and objects exposing
                ``shape``/``min``/``max`` get dedicated treatment
            k (int): position in current datapoint
            depth (int, optional): recursion depth, only used to indent the message

        Todo:
            * call this recursively and stop when depth>n for some n if an element is a list

        Returns:
            string: debug message
        """
        indent = " " * (depth * 2)
        if isinstance(el, list):
            return "%s is list of %i elements " % (indent, len(el))

        el_type = el.__class__.__name__
        if isinstance(el, (int, float, bool)):
            # A plain scalar is its own "range".
            el_shape = "()"
            el_range = el
        else:
            el_shape = getattr(el, 'shape', "n.A.")
            el_range = "None, None"
            # Both bounds are required; formatting a missing one with %.4f
            # would raise TypeError.
            if hasattr(el, 'min') and hasattr(el, 'max'):
                try:
                    el_range = "%.4f, %.4f" % (el.min(), el.max())
                except (TypeError, ValueError):
                    # Empty or non-numeric component: this is only a debugging
                    # aid, so keep the placeholder instead of aborting.
                    el_range = "None, None"
        return "%s dp %i: is %s of shape %s with range [%s]" % (indent, k, el_type, el_shape, el_range)

    def print_info(self):
        """
        Describe the first ``self.num`` datapoints of the wrapped dataflow via ``logger.info``.

        The wrapped dataflow is reset before and after the inspection.
        """
        ds = self.ds
        ds.reset_state()

        msg = [""]
        # Guard against num <= 0, which would otherwise never hit the stop
        # condition and walk through the entire dataflow.
        if self.num > 0:
            for i, dp in enumerate(ds.get_data()):
                if isinstance(dp, list):
                    msg.append("datapoint %i<%i with %i elements consists of" % (i, self.num, len(dp)))
                    for k, entry in enumerate(dp):
                        msg.append(self.analyze_input_data(entry, k))
                # Stop right after the last requested datapoint so no extra
                # datapoint is pulled from the underlying generator.
                if i + 1 >= self.num:
                    break

        # Truthiness instead of `is ""`: identity comparison of strings is
        # implementation-dependent.
        label = " (" + self.label + ")" if self.label else ""
        logger.info(colored("DataFlow Info%s:" % label, 'cyan') + '\n'.join(msg))

        self.ds.reset_state()

    def get_data(self):
        """
        Return the datapoints of the wrapped dataflow unchanged.
        """
        return self.ds.get_data()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment