Commit 6c905896 authored by Yuxin Wu

Make PrintData equivalent to the underlying ds.

parent d7f92444
@@ -5,10 +5,10 @@
from __future__ import division
import numpy as np
from copy import copy
import itertools
from termcolor import colored
from collections import deque, defaultdict
from six.moves import range, map
from .base import DataFlow, ProxyDataFlow, RNGDataFlow
from ..utils import logger
from ..utils.utils import get_tqdm, get_rng
@@ -586,11 +586,7 @@ class CacheData(ProxyDataFlow):
class PrintData(ProxyDataFlow):
"""
Behave like an identity mapping but print shape and range of the first datapoint once during construction.
Attributes:
label (str): label to identify the data when using this debugging on multiple places.
num (int): number of iterations
Behave like an identity mapping, but print shape and range of the first few datapoints.
Example:
To enable this debugging output, you should place it somewhere in your dataflow like
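For illustration, a minimal sketch of how this wrapper is typically placed in a pipeline; ``get_some_dataflow`` is a hypothetical stand-in for whatever dataflow you already have:

```python
from tensorpack.dataflow import BatchData, PrintData

df = get_some_dataflow()   # hypothetical: any DataFlow you want to inspect
# Identity mapping: data passes through unchanged, but the shape and range
# of the first 2 datapoints are printed, tagged with the given name.
df = PrintData(df, num=2, name="after-augmentation")
df = BatchData(df, 32)
```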
@@ -624,9 +620,9 @@ class PrintData(ProxyDataFlow):
def __init__(self, ds, num=1, label=None, name=None):
"""
Args:
ds (DataFlow): input DataFlow.
num (int): number of dataflow points to print.
label (str, optional): label to identify this call, when using multiple times
ds (DataFlow): input DataFlow.
num (int): number of datapoints to print.
name (str, optional): name to identify this DataFlow.
"""
super(PrintData, self).__init__(ds)
self.num = num
@@ -636,14 +632,15 @@ class PrintData(ProxyDataFlow):
self.name = label
else:
self.name = name
self.cnt = 0
def _analyze_input_data(self, el, k, depth=1):
def _analyze_input_data(self, entry, k, depth=1):
"""
Gather useful debug information from a datapoint.
Args:
el: Description
k (int): position in current datapoint
entry: the datapoint component
k (int): index of this component in the current datapoint
depth (int, optional): recursion depth
Todo:
@@ -652,6 +649,7 @@ class PrintData(ProxyDataFlow):
Returns:
string: debug message
"""
el = entry
if isinstance(el, list):
return "%s is list of %i elements" % (" " * (depth * 2), len(el))
else:
@@ -678,22 +676,23 @@ class PrintData(ProxyDataFlow):
return ("%s dp %i: is %s of shape %s with range [%s]" % (" " * (depth * 2), k, el_type, el_shape, el_range))
def print_info(self):
"""
Dump gathered debugging information to stdout.
"""
label = "" if self.name is None else " (" + self.label + ")"
logger.info(colored("DataFlow Info%s:" % label, 'cyan'))
for i, dummy in enumerate(itertools.islice(self.ds.get_data(), self.num)):
if isinstance(dummy, list):
msg = "datapoint %i<%i with %i components consists of\n" % (i, self.num, len(dummy))
for k, entry in enumerate(dummy):
msg += self._analyze_input_data(entry, k) + '\n'
print(msg)
def _get_msg(self, dp):
msg = [u"datapoint %i<%i with %i components consists of" % (self.cnt, self.num, len(dp))]
for k, entry in enumerate(dp):
msg.append(self._analyze_input_data(entry, k))
return u'\n'.join(msg)
# reset again after print
self.ds.reset_state()
def get_data(self):
if self.cnt == 0:
label = "" if self.name is None else " (" + self.label + ")"
logger.info(colored("DataFlow Info%s:" % label, 'cyan'))
for dp in self.ds.get_data():
if self.cnt < self.num:
print(self._get_msg(dp))
self.cnt += 1
yield dp
def reset_state(self):
super(PrintData, self).reset_state()
self.print_info()
self.cnt = 0
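Net effect of the rewrite, sketched with a toy source; ``DataFromList`` is an existing tensorpack dataflow, and the data here is arbitrary:

```python
import numpy as np
from tensorpack.dataflow import DataFromList, PrintData

df = PrintData(DataFromList([[np.zeros((2, 2))], [np.ones(3)]], shuffle=False), num=1)
df.reset_state()             # resets the underlying ds and the print counter
for dp in df.get_data():     # debug info is printed lazily, for the first datapoint only
    pass                     # every dp is yielded unchanged
```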
@@ -52,12 +52,13 @@ class PrefetchData(ProxyDataFlow):
Note:
1. The underlying dataflow worker will be forked multiple times when ``nr_proc>1``.
As a result, unless the underlying dataflow is fully shuffled, the data distribution
produced by this dataflow will be wrong.
produced by this dataflow will be different.
(e.g. you are likely to see duplicated datapoints at the beginning)
2. This is significantly slower than :class:`PrefetchDataZMQ` when data is large.
3. When nesting like this: ``PrefetchDataZMQ(PrefetchData(df, nr_proc=a), nr_proc=b)``,
a total of ``a`` worker processes of ``df`` will be created.
This is different from the behavior of :class:`PrefetchDataZMQ`.
4. ``reset_state()`` is a no-op; ``reset_state()`` of the underlying dataflow won't be called in the worker processes.
"""
def __init__(self, ds, nr_prefetch, nr_proc):
"""
@@ -123,7 +124,7 @@ class PrefetchDataZMQ(ProxyDataFlow):
Note:
1. The underlying dataflow worker will be forked multiple times when ``nr_proc>1``.
As a result, unless the underlying dataflow is fully shuffled, the data distribution
produced by this dataflow will be wrong.
produced by this dataflow will be different.
(e.g. you are likely to see duplicated datapoints at the beginning)
2. Once :meth:`reset_state` is called, this dataflow becomes not fork-safe.
i.e., if you fork an already reset instance of this dataflow,
@@ -133,8 +134,10 @@ class PrefetchDataZMQ(ProxyDataFlow):
Also in this case, some zmq pipes cannot be cleaned at exit.
4. A local directory is needed to put the ZMQ pipes.
You can set this with env var ``$TENSORPACK_PIPEDIR`` if you're
running on certain non-local FS that may not support pipes, such as NFS or GlusterFS.
running on a non-local FS that doesn't support pipes very well, such as NFS or GlusterFS.
Please note that some non-local FS may appear to support pipes; the code
may appear to run but then crash with bizarre errors.
5. Calling ``reset_state()`` more than once is a no-op, i.e. the worker processes won't be restarted.
"""
def __init__(self, ds, nr_proc=1, hwm=50):
"""