Commit 81127236 authored by Yuxin Wu

Fix table formatting in logs; Add support for dicts in PrintData

parent 1cd536a9
...@@ -22,7 +22,7 @@ from common import ( ...@@ -22,7 +22,7 @@ from common import (
filter_boxes_inside_shape, np_iou, point8_to_box, polygons_to_mask, filter_boxes_inside_shape, np_iou, point8_to_box, polygons_to_mask,
) )
from config import config as cfg from config import config as cfg
from dataset import DatasetRegistry from dataset import DatasetRegistry, register_coco
from utils.np_box_ops import area as np_area from utils.np_box_ops import area as np_area
from utils.np_box_ops import ioa as np_ioa from utils.np_box_ops import ioa as np_ioa
...@@ -50,7 +50,7 @@ def print_class_histogram(roidbs): ...@@ -50,7 +50,7 @@ def print_class_histogram(roidbs):
gt_classes = entry["class"][gt_inds] gt_classes = entry["class"][gt_inds]
gt_hist += np.histogram(gt_classes, bins=hist_bins)[0] gt_hist += np.histogram(gt_classes, bins=hist_bins)[0]
data = list(itertools.chain(*[[class_names[i + 1], v] for i, v in enumerate(gt_hist[1:])])) data = list(itertools.chain(*[[class_names[i + 1], v] for i, v in enumerate(gt_hist[1:])]))
COL = max(6, len(data)) COL = min(6, len(data))
total_instances = sum(data[1::2]) total_instances = sum(data[1::2])
data.extend([None] * (COL - len(data) % COL)) data.extend([None] * (COL - len(data) % COL))
data.extend(["total", total_instances]) data.extend(["total", total_instances])
...@@ -394,11 +394,12 @@ def get_eval_dataflow(name, shard=0, num_shards=1): ...@@ -394,11 +394,12 @@ def get_eval_dataflow(name, shard=0, num_shards=1):
if __name__ == "__main__": if __name__ == "__main__":
import os import os
from tensorpack.dataflow import PrintData from tensorpack.dataflow import PrintData
from config import finalize_configs
cfg.DATA.BASEDIR = os.path.expanduser("~/data/coco") register_coco(os.path.expanduser("~/data/coco"))
finalize_configs()
ds = get_train_dataflow() ds = get_train_dataflow()
ds = PrintData(ds, 100) ds = PrintData(ds, 10)
TestDataSpeed(ds, 50000).start() TestDataSpeed(ds, 50000).start()
ds.reset_state()
for k in ds: for k in ds:
pass pass
...@@ -5,7 +5,7 @@ from __future__ import division ...@@ -5,7 +5,7 @@ from __future__ import division
import itertools import itertools
import numpy as np import numpy as np
import pprint import pprint
from collections import defaultdict, deque from collections import defaultdict, deque, Mapping
from copy import copy from copy import copy
import six import six
import tqdm import tqdm
...@@ -748,7 +748,7 @@ class PrintData(ProxyDataFlow): ...@@ -748,7 +748,7 @@ class PrintData(ProxyDataFlow):
Gather useful debug information from a datapoint. Gather useful debug information from a datapoint.
Args: Args:
entry: the datapoint component entry: the datapoint component, either a list or a dict
k (int): index of this component in current datapoint k (int): index of this component in current datapoint
depth (int, optional): recursion depth depth (int, optional): recursion depth
max_depth, max_list: same as in :meth:`__init__`. max_depth, max_list: same as in :meth:`__init__`.
...@@ -779,7 +779,7 @@ class PrintData(ProxyDataFlow): ...@@ -779,7 +779,7 @@ class PrintData(ProxyDataFlow):
self.range = " in range [{}, {}]".format(el.min(), el.max()) self.range = " in range [{}, {}]".format(el.min(), el.max())
elif type(el) in numpy_scalar_types: elif type(el) in numpy_scalar_types:
self.range = " with value {}".format(el) self.range = " with value {}".format(el)
elif isinstance(el, (list)): elif isinstance(el, (list, tuple)):
self.shape = " of len {}".format(len(el)) self.shape = " of len {}".format(len(el))
if depth < max_depth: if depth < max_depth:
...@@ -805,9 +805,15 @@ class PrintData(ProxyDataFlow): ...@@ -805,9 +805,15 @@ class PrintData(ProxyDataFlow):
return str(_elementInfo(entry, k, depth, max_list)) return str(_elementInfo(entry, k, depth, max_list))
def _get_msg(self, dp): def _get_msg(self, dp):
msg = [u"datapoint %i<%i with %i components consists of" % (self.cnt, self.num, len(dp))] msg = [colored(u"datapoint %i/%i with %i components consists of" %
(self.cnt, self.num, len(dp)), "cyan")]
is_dict = isinstance(dp, Mapping)
for k, entry in enumerate(dp): for k, entry in enumerate(dp):
msg.append(self._analyze_input_data(entry, k, max_depth=self.max_depth, max_list=self.max_list)) if is_dict:
key, value = entry, dp[entry]
else:
key, value = k, entry
msg.append(self._analyze_input_data(value, key, max_depth=self.max_depth, max_list=self.max_list))
return u'\n'.join(msg) return u'\n'.join(msg)
def __iter__(self): def __iter__(self):
...@@ -815,7 +821,7 @@ class PrintData(ProxyDataFlow): ...@@ -815,7 +821,7 @@ class PrintData(ProxyDataFlow):
# it is important to place this here! otherwise it mixes the output of multiple PrintData # it is important to place this here! otherwise it mixes the output of multiple PrintData
if self.cnt == 0: if self.cnt == 0:
label = ' (%s)' % self.name if self.name is not None else "" label = ' (%s)' % self.name if self.name is not None else ""
logger.info(colored("DataFlow Info%s:" % label, 'cyan')) logger.info(colored("Contents of DataFlow%s:" % label, 'cyan'))
if self.cnt < self.num: if self.cnt < self.num:
print(self._get_msg(dp)) print(self._get_msg(dp))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment