Commit 6c905896 authored by Yuxin Wu

Make PrintData equivalent to the underlying ds.

parent d7f92444
@@ -5,10 +5,10 @@
 from __future__ import division
 import numpy as np
 from copy import copy
-import itertools
 from termcolor import colored
 from collections import deque, defaultdict
 from six.moves import range, map
 from .base import DataFlow, ProxyDataFlow, RNGDataFlow
 from ..utils import logger
 from ..utils.utils import get_tqdm, get_rng
@@ -586,11 +586,7 @@ class CacheData(ProxyDataFlow):
 class PrintData(ProxyDataFlow):
     """
-    Behave like an identity mapping but print shape and range of the first datapoint once during construction.
-
-    Attributes:
-        label (str): label to identify the data when using this debugging on multiple places.
-        num (int): number of iterations
+    Behave like an identity mapping, but print shape and range of the first few datapoints.
 
     Example:
         To enable this debugging output, you should place it somewhere in your dataflow like
@@ -624,9 +620,9 @@ class PrintData(ProxyDataFlow):
     def __init__(self, ds, num=1, label=None, name=None):
         """
         Args:
             ds (DataFlow): input DataFlow.
-            num (int): number of dataflow points to print.
-            label (str, optional): label to identify this call, when using multiple times
+            num (int): number of datapoints to print.
+            name (str, optional): name to identify this DataFlow.
         """
         super(PrintData, self).__init__(ds)
         self.num = num
@@ -636,14 +632,15 @@ class PrintData(ProxyDataFlow):
             self.name = label
         else:
             self.name = name
+        self.cnt = 0
-    def _analyze_input_data(self, el, k, depth=1):
+    def _analyze_input_data(self, entry, k, depth=1):
         """
         Gather useful debug information from a datapoint.
 
         Args:
-            el: Description
-            k (int): position in current datapoint
+            entry: the datapoint component
+            k (int): index of this component in current datapoint
             depth (int, optional): recursion depth
 
         Todo:
@@ -652,6 +649,7 @@ class PrintData(ProxyDataFlow):
         Returns:
             string: debug message
         """
+        el = entry
         if isinstance(el, list):
             return "%s is list of %i elements" % (" " * (depth * 2), len(el))
         else:
@@ -678,22 +676,23 @@ class PrintData(ProxyDataFlow):
             return ("%s dp %i: is %s of shape %s with range [%s]" % (" " * (depth * 2), k, el_type, el_shape, el_range))
-    def print_info(self):
-        """
-        Dump gathered debugging information to stdout.
-        """
-        label = "" if self.name is None else " (" + self.label + ")"
-        logger.info(colored("DataFlow Info%s:" % label, 'cyan'))
-        for i, dummy in enumerate(itertools.islice(self.ds.get_data(), self.num)):
-            if isinstance(dummy, list):
-                msg = "datapoint %i<%i with %i components consists of\n" % (i, self.num, len(dummy))
-                for k, entry in enumerate(dummy):
-                    msg += self._analyze_input_data(entry, k) + '\n'
-                print(msg)
-        # reset again after print
-        self.ds.reset_state()
+    def _get_msg(self, dp):
+        msg = [u"datapoint %i<%i with %i components consists of" % (self.cnt, self.num, len(dp))]
+        for k, entry in enumerate(dp):
+            msg.append(self._analyze_input_data(entry, k))
+        return u'\n'.join(msg)
+
+    def get_data(self):
+        if self.cnt == 0:
+            label = "" if self.name is None else " (" + self.name + ")"
+            logger.info(colored("DataFlow Info%s:" % label, 'cyan'))
+        for dp in self.ds.get_data():
+            if self.cnt < self.num:
+                print(self._get_msg(dp))
+                self.cnt += 1
+            yield dp
 
     def reset_state(self):
         super(PrintData, self).reset_state()
-        self.print_info()
+        self.cnt = 0
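
The rewrite above is easiest to see in use: `PrintData` now yields every datapoint unmodified and prints the first `num` of them lazily while the flow is iterated, instead of consuming the wrapped dataflow once up front. A minimal sketch of the new behavior, assuming tensorpack's `FakeData` source with its `(shapes, size)` signature; the shapes chosen here are purely illustrative:

```python
from tensorpack.dataflow import FakeData, PrintData

# FakeData yields random datapoints; [(3, 2)] means one component of shape (3, 2).
df = FakeData([(3, 2)], size=10)
df = PrintData(df, num=2)   # identity mapping: prints only the first 2 datapoints
df.reset_state()            # resets self.cnt, so the next epoch prints again

for dp in df.get_data():    # dp passes through exactly as without PrintData
    pass
```
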
@@ -52,12 +52,13 @@ class PrefetchData(ProxyDataFlow):
     Note:
         1. The underlying dataflow worker will be forked multiple times when ``nr_proc>1``.
            As a result, unless the underlying dataflow is fully shuffled, the data distribution
-           produced by this dataflow will be wrong.
+           produced by this dataflow will be different.
            (e.g. you are likely to see duplicated datapoints at the beginning)
         2. This is significantly slower than :class:`PrefetchDataZMQ` when data is large.
         3. When nesting like this: ``PrefetchDataZMQ(PrefetchData(df, nr_proc=a), nr_proc=b)``,
            a total of ``a`` instances of ``df`` worker processes will be created.
            This is different from the behavior of :class:`PrefetchDataZMQ`.
+        4. `reset_state()` is a no-op. The worker processes won't be restarted.
     """
 
     def __init__(self, ds, nr_prefetch, nr_proc):
         """
@@ -123,7 +124,7 @@ class PrefetchDataZMQ(ProxyDataFlow):
     Note:
         1. The underlying dataflow worker will be forked multiple times when ``nr_proc>1``.
            As a result, unless the underlying dataflow is fully shuffled, the data distribution
-           produced by this dataflow will be wrong.
+           produced by this dataflow will be different.
            (e.g. you are likely to see duplicated datapoints at the beginning)
         2. Once :meth:`reset_state` is called, this dataflow is no longer fork-safe.
            i.e., if you fork an already reset instance of this dataflow,
@@ -133,8 +134,10 @@ class PrefetchDataZMQ(ProxyDataFlow):
            Also in this case, some zmq pipes cannot be cleaned at exit.
         4. A local directory is needed to put the ZMQ pipes.
            You can set this with env var ``$TENSORPACK_PIPEDIR`` if you're
-           running on certain non-local FS that may not support pipes, such as NFS or GlusterFS.
+           running on a non-local FS that doesn't support pipes very well, such as NFS or GlusterFS.
+           Please note that some non-local FS may appear to support pipes:
+           the code may appear to run, but then crash with bizarre errors.
+        5. Calling `reset_state()` more than once is a no-op, i.e. the worker processes won't be restarted.
     """
     def __init__(self, ds, nr_proc=1, hwm=50):
         """
...
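
The ``$TENSORPACK_PIPEDIR`` note in the last hunk can be exercised as below. This is a sketch, not part of the commit: the directory is only an example of a local, pipe-friendly location, and `build_my_df()` is again a hypothetical DataFlow factory.

```python
import os

# Point the ZMQ pipes at local disk *before* constructing PrefetchDataZMQ.
# Per the note above, NFS/GlusterFS may appear to accept the pipes and
# still crash later with bizarre errors.
os.environ['TENSORPACK_PIPEDIR'] = '/tmp'

from tensorpack.dataflow import PrefetchDataZMQ

df = PrefetchDataZMQ(build_my_df(), nr_proc=2, hwm=50)
df.reset_state()  # first call starts the workers; further calls are a no-op (note 5)
```
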