Commit 81127236 authored by Yuxin Wu

Fix table formatting in logs; Add support for dicts in PrintData

parent 1cd536a9
......@@ -22,7 +22,7 @@ from common import (
filter_boxes_inside_shape, np_iou, point8_to_box, polygons_to_mask,
)
from config import config as cfg
from dataset import DatasetRegistry
from dataset import DatasetRegistry, register_coco
from utils.np_box_ops import area as np_area
from utils.np_box_ops import ioa as np_ioa
......@@ -50,7 +50,7 @@ def print_class_histogram(roidbs):
gt_classes = entry["class"][gt_inds]
gt_hist += np.histogram(gt_classes, bins=hist_bins)[0]
data = list(itertools.chain(*[[class_names[i + 1], v] for i, v in enumerate(gt_hist[1:])]))
COL = max(6, len(data))
COL = min(6, len(data))
total_instances = sum(data[1::2])
data.extend([None] * (COL - len(data) % COL))
data.extend(["total", total_instances])
......@@ -394,11 +394,12 @@ def get_eval_dataflow(name, shard=0, num_shards=1):
if __name__ == "__main__":
import os
from tensorpack.dataflow import PrintData
from config import finalize_configs
cfg.DATA.BASEDIR = os.path.expanduser("~/data/coco")
register_coco(os.path.expanduser("~/data/coco"))
finalize_configs()
ds = get_train_dataflow()
ds = PrintData(ds, 100)
ds = PrintData(ds, 10)
TestDataSpeed(ds, 50000).start()
ds.reset_state()
for k in ds:
pass
......@@ -5,7 +5,7 @@ from __future__ import division
import itertools
import numpy as np
import pprint
from collections import defaultdict, deque
from collections import defaultdict, deque, Mapping
from copy import copy
import six
import tqdm
......@@ -748,7 +748,7 @@ class PrintData(ProxyDataFlow):
Gather useful debug information from a datapoint.
Args:
entry: the datapoint component
entry: the datapoint component, either a list or a dict
k (int): index of this component in current datapoint
depth (int, optional): recursion depth
max_depth, max_list: same as in :meth:`__init__`.
......@@ -779,7 +779,7 @@ class PrintData(ProxyDataFlow):
self.range = " in range [{}, {}]".format(el.min(), el.max())
elif type(el) in numpy_scalar_types:
self.range = " with value {}".format(el)
elif isinstance(el, (list)):
elif isinstance(el, (list, tuple)):
self.shape = " of len {}".format(len(el))
if depth < max_depth:
......@@ -805,9 +805,15 @@ class PrintData(ProxyDataFlow):
return str(_elementInfo(entry, k, depth, max_list))
def _get_msg(self, dp):
msg = [u"datapoint %i<%i with %i components consists of" % (self.cnt, self.num, len(dp))]
msg = [colored(u"datapoint %i/%i with %i components consists of" %
(self.cnt, self.num, len(dp)), "cyan")]
is_dict = isinstance(dp, Mapping)
for k, entry in enumerate(dp):
msg.append(self._analyze_input_data(entry, k, max_depth=self.max_depth, max_list=self.max_list))
if is_dict:
key, value = entry, dp[entry]
else:
key, value = k, entry
msg.append(self._analyze_input_data(value, key, max_depth=self.max_depth, max_list=self.max_list))
return u'\n'.join(msg)
def __iter__(self):
......@@ -815,7 +821,7 @@ class PrintData(ProxyDataFlow):
# it is important to place this here! otherwise it mixes the output of multiple PrintData
if self.cnt == 0:
label = ' (%s)' % self.name if self.name is not None else ""
logger.info(colored("DataFlow Info%s:" % label, 'cyan'))
logger.info(colored("Contents of DataFlow%s:" % label, 'cyan'))
if self.cnt < self.num:
print(self._get_msg(dp))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment