Commit c653458c authored by Yuxin Wu's avatar Yuxin Wu

fix logdir

parent d2262d1d
...@@ -88,7 +88,7 @@ def get_model(inputs, is_training): ...@@ -88,7 +88,7 @@ def get_model(inputs, is_training):
def get_config(): def get_config():
basename = os.path.basename(__file__) basename = os.path.basename(__file__)
log_dir = os.path.join('train_log', basename[:basename.rfind('.')]) log_dir = os.path.join('train_log', basename[:basename.rfind('.')])
logger.set_logger_dir(log_dir) logger.set_logger_file(os.path.join(log_dir, 'training.log'))
dataset_train = dataset.Cifar10('train') dataset_train = dataset.Cifar10('train')
augmentors = [ augmentors = [
......
...@@ -7,29 +7,44 @@ import argparse ...@@ -7,29 +7,44 @@ import argparse
import cv2 import cv2
import tensorflow as tf import tensorflow as tf
import imp import imp
import tqdm
import os import os
from tensorpack.utils import logger
from tensorpack.utils.utils import mkdir_p from tensorpack.utils.utils import mkdir_p
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument(dest='config') parser.add_argument(dest='config')
parser.add_argument(dest='output') parser.add_argument('-o', '--output', help='output directory to dump dataset image')
parser.add_argument('-n', '--number', help='number of images to take', parser.add_argument('-n', '--number', help='number of images to dump',
default=10, type=int) default=10, type=int)
args = parser.parse_args() args = parser.parse_args()
mkdir_p(args.output)
index = 0 # TODO: as an argument?
get_config_func = imp.load_source('config_script', args.config).get_config get_config_func = imp.load_source('config_script', args.config).get_config
config = get_config_func() config = get_config_func()
cnt = 0 if args.output:
for dp in config.dataset.get_data(): mkdir_p(args.output)
imgbatch = dp[index] cnt = 0
if cnt > args.number: index = 0 # TODO: as an argument?
break for dp in config.dataset.get_data():
for bi, img in enumerate(imgbatch): imgbatch = dp[index]
cnt += 1 if cnt > args.number:
fname = os.path.join(args.output, '{:03d}-{}.png'.format(cnt, bi)) break
cv2.imwrite(fname, img * 255.0) for bi, img in enumerate(imgbatch):
cnt += 1
fname = os.path.join(args.output, '{:03d}-{}.png'.format(cnt, bi))
cv2.imwrite(fname, img)
NR_DP_TEST = 100
logger.info("Testing dataflow speed:")
with tqdm.tqdm(total=NR_DP_TEST, leave=True, unit='data points') as pbar:
for idx, dp in enumerate(config.dataset.get_data()):
if idx > NR_DP_TEST:
break
pbar.update()
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf import tensorflow as tf
import itertools
from tqdm import tqdm from tqdm import tqdm
from ..utils import * from ..utils import *
...@@ -47,7 +48,7 @@ class ValidationError(PeriodicCallback): ...@@ -47,7 +48,7 @@ class ValidationError(PeriodicCallback):
cost_sum = 0 cost_sum = 0
with tqdm(total=self.ds.size()) as pbar: with tqdm(total=self.ds.size()) as pbar:
for dp in self.ds.get_data(): for dp in self.ds.get_data():
feed = dict(zip(self.input_vars, dp)) feed = dict(itertools.izip(self.input_vars, dp))
batch_size = dp[0].shape[0] # assume batched input batch_size = dp[0].shape[0] # assume batched input
......
...@@ -17,18 +17,17 @@ class Flip(ImageAugmentor): ...@@ -17,18 +17,17 @@ class Flip(ImageAugmentor):
horiz, vert: True/False horiz, vert: True/False
""" """
if horiz and vert: if horiz and vert:
self.code = -1 raise ValueError("Please use two Flip, with both 0.5 prob")
elif horiz: elif horiz:
self.code = 1 self.code = 1
elif vert: elif vert:
self.code = 0 self.code = 0
else: else:
raise RuntimeError("Are you kidding?") raise ValueError("Are you kidding?")
self.prob = prob self.prob = prob
self._init() self._init()
def _augment(self, img): def _augment(self, img):
# TODO XXX prob is wrong for both mode
if self._rand_range() < self.prob: if self._rand_range() < self.prob:
img.arr = cv2.flip(img.arr, self.code) img.arr = cv2.flip(img.arr, self.code)
if img.coords: if img.coords:
......
...@@ -52,11 +52,11 @@ def set_file(path): ...@@ -52,11 +52,11 @@ def set_file(path):
filename=path, encoding='utf-8', mode='w') filename=path, encoding='utf-8', mode='w')
logger.addHandler(hdl) logger.addHandler(hdl)
global LOG_DIR global LOG_FILE
LOG_DIR = "train_log" LOG_FILE = "train_log/log.log"
def set_logger_dir(dirname): def set_logger_file(filename):
global LOG_DIR global LOG_FILE
LOG_DIR = dirname LOG_FILE = filename
mkdir_p(LOG_DIR) mkdir_p(os.path.dirname(LOG_FILE))
set_file(os.path.join(LOG_DIR, 'training.log')) set_file(LOG_FILE)
...@@ -20,6 +20,7 @@ def expand_dim_if_necessary(var, dp): ...@@ -20,6 +20,7 @@ def expand_dim_if_necessary(var, dp):
def mkdir_p(dirname): def mkdir_p(dirname):
assert dirname is not None
if dirname == '': if dirname == '':
return return
try: try:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment