Commit ff0f2cf7 authored by Patrick Wieschollek, committed by Yuxin Wu

More verbose information about dataflow (#90)

Add PrintData: to verify the shapes of the data coming out of a dataflow, it is helpful to actually print them for debugging.
parent 827791cb
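
For context, a minimal self-contained sketch (not part of the commit) of how the new PrintData stage might be dropped into a pipeline. The RandomImages flow, the batch size, and the label text are invented for illustration; the imports assume the usual package-level re-exports from tensorpack.dataflow.

import numpy as np
from tensorpack.dataflow import DataFlow, BatchData, PrintData

class RandomImages(DataFlow):
    """Toy flow for this sketch only: yields [image, label] datapoints."""
    def size(self):
        return 100

    def get_data(self):
        for _ in range(self.size()):
            yield [np.random.rand(64, 64).astype('float32'), np.random.randint(10)]

def get_data():
    ds = RandomImages()
    ds = BatchData(ds, 8)
    # Logs shape/range info for the first 2 datapoints once, at construction time;
    # afterwards PrintData behaves like an identity mapping.
    ds = PrintData(ds, num=2, label="after batching")
    return ds
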
@@ -4,6 +4,7 @@
from __future__ import division
import numpy as np
from termcolor import colored
from collections import deque, defaultdict
from six.moves import range, map
from .base import DataFlow, ProxyDataFlow, RNGDataFlow
@@ -12,7 +13,7 @@ from ..utils import logger, get_tqdm
__all__ = ['TestDataSpeed', 'BatchData', 'BatchDataByShape', 'FixedSizeData', 'MapData',
           'MapDataComponent', 'RepeatedData', 'RandomChooseData',
           'RandomMixData', 'JoinData', 'ConcatData', 'SelectComponent',
           'LocallyShuffleData', 'PrintData']
class TestDataSpeed(ProxyDataFlow):
@@ -471,3 +472,130 @@ class LocallyShuffleData(ProxyDataFlow, RNGDataFlow):
        for v in self.q:
            yield v
        return

class PrintData(ProxyDataFlow):
    """
    Behave like an identity mapping, but print the shapes of produced datapoints once during construction.

    Attributes:
        label (str): label to identify the data when using this debugging in multiple places.
        num (int): number of iterations

    Example:
        To enable this debugging output, place it somewhere in your dataflow like

            def get_data():
                ds = CaffeLMDB('path/to/lmdb')
                ds = SomeInscrutableMappings(ds)
                ds = PrintData(ds, num=2)
                return ds
            ds = get_data()

        The output looks like:

            [0110 09:22:21 @common.py:589] DataFlow Info:
            datapoint 0<2 with 4 elements consists of
               dp 0: is float of shape () with range [0.0816501893251]
               dp 1: is ndarray of shape (64, 64) with range [0.1300, 0.6895]
               dp 2: is ndarray of shape (64, 64) with range [-1.2248, 1.2177]
               dp 3: is ndarray of shape (9, 9) with range [-0.6045, 0.6045]
            datapoint 1<2 with 4 elements consists of
               dp 0: is float of shape () with range [5.88252075399]
               dp 1: is ndarray of shape (64, 64) with range [0.0072, 0.9371]
               dp 2: is ndarray of shape (64, 64) with range [-0.9011, 0.8491]
               dp 3: is ndarray of shape (9, 9) with range [-0.5585, 0.5585]
    """

    def __init__(self, ds, num=1, label=""):
        """
        Args:
            ds (DataFlow): input DataFlow.
            num (int): number of datapoints to print.
            label (str, optional): label to identify this call, when using PrintData multiple times.
        """
        super(PrintData, self).__init__(ds)
        self.num = num
        self.label = label

        self.print_info()

    def analyze_input_data(self, el, k, depth=1):
        """
        Gather useful debug information from a datapoint.

        Args:
            el: one element (component) of the datapoint.
            k (int): position of this element in the current datapoint.
            depth (int, optional): recursion depth.

        Todo:
            * call this recursively and stop when depth>n for some n, if an element is a list

        Returns:
            string: debug message
        """
        if isinstance(el, list):
            return "%s is list of %i elements" % (" " * (depth * 2), len(el))
        else:
            el_type = el.__class__.__name__

            if isinstance(el, (int, float, bool)):
                el_max = el_min = el
                el_shape = "()"
                el_range = el
            else:
                el_shape = "n.A."
                if hasattr(el, 'shape'):
                    el_shape = el.shape

                el_max, el_min = None, None
                if hasattr(el, 'max'):
                    el_max = el.max()
                if hasattr(el, 'min'):
                    el_min = el.min()

                el_range = "None, None"
                if el_max is not None or el_min is not None:
                    el_range = "%.4f, %.4f" % (el_min, el_max)

            return "%s dp %i: is %s of shape %s with range [%s]" % (" " * (depth * 2), k, el_type, el_shape, el_range)

    def print_info(self):
        """
        Dump gathered debugging information to stdout.
        """
        def cutoff(gen, num=1):
            """
            Stop a generator after `num` iterations.

            Args:
                gen (PyGenObject): arbitrary generator.
                num (int, optional): maximal number of iterations.

            Yields:
                elements from the generator object
            """
            c = 0
            for el in gen:
                yield el
                c += 1
                if c == num:
                    break

        ds = self.ds
        ds.reset_state()

        msg = [""]
        for i, dummy in enumerate(cutoff(ds.get_data(), self.num)):
            if isinstance(dummy, list):
                msg.append("datapoint %i<%i with %i elements consists of" % (i, self.num, len(dummy)))
                for k, entry in enumerate(dummy):
                    msg.append(self.analyze_input_data(entry, k))
        label = "" if self.label == "" else " (" + self.label + ")"
        logger.info(colored("DataFlow Info%s:" % label, 'cyan') + '\n'.join(msg))
        self.ds.reset_state()

    def get_data(self):
        return self.ds.get_data()
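
A small, made-up illustration of the per-element messages that analyze_input_data returns; the _TwoPoints flow exists only so a PrintData instance can be constructed here, and the expected strings in the comments follow the format shown in the class docstring above.

import numpy as np
from tensorpack.dataflow import DataFlow, PrintData

class _TwoPoints(DataFlow):
    """Tiny flow, defined only for this example."""
    def get_data(self):
        yield [np.random.rand(64, 64), 0.5]
        yield [np.random.rand(64, 64), 0.7]

pd = PrintData(_TwoPoints(), num=2)   # logs "DataFlow Info:" for both datapoints at construction
print(pd.analyze_input_data(np.zeros((9, 9)), 3))
# "   dp 3: is ndarray of shape (9, 9) with range [0.0000, 0.0000]"
print(pd.analyze_input_data(3.14, 0))
# "   dp 0: is float of shape () with range [3.14]"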