Commit b6aced91 authored by Patrick Wieschollek's avatar Patrick Wieschollek Committed by Yuxin Wu

Fix PrintData argument (#303)

It was always confusing, that most functions use `name` and
PrintData expected `label`.
parent 6be159d3
...@@ -11,6 +11,7 @@ from collections import deque, defaultdict ...@@ -11,6 +11,7 @@ from collections import deque, defaultdict
from six.moves import range, map from six.moves import range, map
from .base import DataFlow, ProxyDataFlow, RNGDataFlow from .base import DataFlow, ProxyDataFlow, RNGDataFlow
from ..utils import logger, get_tqdm, get_rng from ..utils import logger, get_tqdm, get_rng
from ..utils.develop import log_deprecated
__all__ = ['TestDataSpeed', 'PrintData', 'BatchData', 'BatchDataByShape', 'FixedSizeData', 'MapData', __all__ = ['TestDataSpeed', 'PrintData', 'BatchData', 'BatchDataByShape', 'FixedSizeData', 'MapData',
'MapDataComponent', 'RepeatedData', 'RepeatedDataPoint', 'RandomChooseData', 'MapDataComponent', 'RepeatedData', 'RepeatedDataPoint', 'RandomChooseData',
...@@ -614,7 +615,7 @@ class PrintData(ProxyDataFlow): ...@@ -614,7 +615,7 @@ class PrintData(ProxyDataFlow):
dp 3: is ndarray of shape (9, 9) with range [-0.5585, 0.5585] dp 3: is ndarray of shape (9, 9) with range [-0.5585, 0.5585]
""" """
def __init__(self, ds, num=1, label=""): def __init__(self, ds, num=1, label=None, name=None):
""" """
Args: Args:
ds (DataFlow): input DataFlow. ds (DataFlow): input DataFlow.
...@@ -623,7 +624,12 @@ class PrintData(ProxyDataFlow): ...@@ -623,7 +624,12 @@ class PrintData(ProxyDataFlow):
""" """
super(PrintData, self).__init__(ds) super(PrintData, self).__init__(ds)
self.num = num self.num = num
self.label = label
if label:
log_deprecated("PrintData(label, ...", "Use PrintData(name, ... instead.")
self.name = label
else:
self.name = name
def _analyze_input_data(self, el, k, depth=1): def _analyze_input_data(self, el, k, depth=1):
""" """
...@@ -676,7 +682,7 @@ class PrintData(ProxyDataFlow): ...@@ -676,7 +682,7 @@ class PrintData(ProxyDataFlow):
msg.append("datapoint %i<%i with %i components consists of" % (i, self.num, len(dummy))) msg.append("datapoint %i<%i with %i components consists of" % (i, self.num, len(dummy)))
for k, entry in enumerate(dummy): for k, entry in enumerate(dummy):
msg.append(self._analyze_input_data(entry, k)) msg.append(self._analyze_input_data(entry, k))
label = "" if self.label is "" else " (" + self.label + ")" label = "" if self.name is None else " (" + self.label + ")"
logger.info(colored("DataFlow Info%s:" % label, 'cyan') + '\n'.join(msg)) logger.info(colored("DataFlow Info%s:" % label, 'cyan') + '\n'.join(msg))
# reset again after print # reset again after print
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment