Commit 54558074 authored by Yuxin Wu's avatar Yuxin Wu

fix summary, update dump scriptt

parent 322449d2
...@@ -15,7 +15,10 @@ from tensorpack.utils.utils import mkdir_p ...@@ -15,7 +15,10 @@ from tensorpack.utils.utils import mkdir_p
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument(dest='config') parser.add_argument(dest='config')
parser.add_argument('-o', '--output', help='output directory to dump dataset image') parser.add_argument('-o', '--output',
help='output directory to dump dataset image. If not given, will not dump images.')
parser.add_argument('-s', '--scale',
help='scale the image data (maybe by 255)', default=1, type=int)
parser.add_argument('--index', parser.add_argument('--index',
help='index of the image component in datapoint', help='index of the image component in datapoint',
default=0, type=int) default=0, type=int)
...@@ -37,7 +40,7 @@ if args.output: ...@@ -37,7 +40,7 @@ if args.output:
for bi, img in enumerate(imgbatch): for bi, img in enumerate(imgbatch):
cnt += 1 cnt += 1
fname = os.path.join(args.output, '{:03d}-{}.png'.format(cnt, bi)) fname = os.path.join(args.output, '{:03d}-{}.png'.format(cnt, bi))
cv2.imwrite(fname, img) cv2.imwrite(fname, img * args.scale)
NR_DP_TEST = 100 NR_DP_TEST = 100
logger.info("Testing dataflow speed:") logger.info("Testing dataflow speed:")
......
...@@ -16,7 +16,7 @@ __all__ = ['StatHolder', 'StatPrinter'] ...@@ -16,7 +16,7 @@ __all__ = ['StatHolder', 'StatPrinter']
class StatHolder(object): class StatHolder(object):
def __init__(self, log_dir, print_tag=None): def __init__(self, log_dir, print_tag=None):
self.print_tag = None if print_tag is None else set(print_tag) self.set_print_tag(print_tag)
self.stat_now = {} self.stat_now = {}
self.log_dir = log_dir self.log_dir = log_dir
...@@ -31,6 +31,9 @@ class StatHolder(object): ...@@ -31,6 +31,9 @@ class StatHolder(object):
def add_stat(self, k, v): def add_stat(self, k, v):
self.stat_now[k] = v self.stat_now[k] = v
def set_print_tag(self, print_tag):
self.print_tag = None if print_tag is None else set(print_tag)
def finalize(self): def finalize(self):
self._print_stat() self._print_stat()
self.stat_history.append(self.stat_now) self.stat_history.append(self.stat_now)
...@@ -56,4 +59,4 @@ class StatPrinter(Callback): ...@@ -56,4 +59,4 @@ class StatPrinter(Callback):
self.print_tag = print_tag self.print_tag = print_tag
def _before_train(self): def _before_train(self):
self.trainer.stat_holder = StatHolder(logger.LOG_DIR, self.print_tag) self.trainer.stat_holder.set_print_tag(self.print_tag)
...@@ -4,8 +4,8 @@ ...@@ -4,8 +4,8 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
# use user-space protobuf # use user-space protobuf
import sys, os #import sys, os
if not sys.version_info >= (3, 0): #if not sys.version_info >= (3, 0):
site = os.path.join(os.environ['HOME'], #site = os.path.join(os.environ['HOME'],
'.local/lib/python2.7/site-packages') #'.local/lib/python2.7/site-packages')
sys.path.insert(0, site) #sys.path.insert(0, site)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment