Commit c653458c authored by Yuxin Wu's avatar Yuxin Wu

fix logdir

parent d2262d1d
......@@ -88,7 +88,7 @@ def get_model(inputs, is_training):
def get_config():
basename = os.path.basename(__file__)
log_dir = os.path.join('train_log', basename[:basename.rfind('.')])
logger.set_logger_dir(log_dir)
logger.set_logger_file(os.path.join(log_dir, 'training.log'))
dataset_train = dataset.Cifar10('train')
augmentors = [
......
......@@ -7,29 +7,44 @@ import argparse
import cv2
import tensorflow as tf
import imp
import tqdm
import os
from tensorpack.utils import logger
from tensorpack.utils.utils import mkdir_p
parser = argparse.ArgumentParser()
parser.add_argument(dest='config')
parser.add_argument(dest='output')
parser.add_argument('-n', '--number', help='number of images to take',
parser.add_argument('-o', '--output', help='output directory to dump dataset image')
parser.add_argument('-n', '--number', help='number of images to dump',
default=10, type=int)
args = parser.parse_args()
mkdir_p(args.output)
index = 0 # TODO: as an argument?
get_config_func = imp.load_source('config_script', args.config).get_config
config = get_config_func()
cnt = 0
for dp in config.dataset.get_data():
if args.output:
mkdir_p(args.output)
cnt = 0
index = 0 # TODO: as an argument?
for dp in config.dataset.get_data():
imgbatch = dp[index]
if cnt > args.number:
break
for bi, img in enumerate(imgbatch):
cnt += 1
fname = os.path.join(args.output, '{:03d}-{}.png'.format(cnt, bi))
cv2.imwrite(fname, img * 255.0)
cv2.imwrite(fname, img)
NR_DP_TEST = 100
logger.info("Testing dataflow speed:")
with tqdm.tqdm(total=NR_DP_TEST, leave=True, unit='data points') as pbar:
for idx, dp in enumerate(config.dataset.get_data()):
if idx > NR_DP_TEST:
break
pbar.update()
......@@ -4,6 +4,7 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf
import itertools
from tqdm import tqdm
from ..utils import *
......@@ -47,7 +48,7 @@ class ValidationError(PeriodicCallback):
cost_sum = 0
with tqdm(total=self.ds.size()) as pbar:
for dp in self.ds.get_data():
feed = dict(zip(self.input_vars, dp))
feed = dict(itertools.izip(self.input_vars, dp))
batch_size = dp[0].shape[0] # assume batched input
......
......@@ -17,18 +17,17 @@ class Flip(ImageAugmentor):
horiz, vert: True/False
"""
if horiz and vert:
self.code = -1
raise ValueError("Please use two Flip, with both 0.5 prob")
elif horiz:
self.code = 1
elif vert:
self.code = 0
else:
raise RuntimeError("Are you kidding?")
raise ValueError("Are you kidding?")
self.prob = prob
self._init()
def _augment(self, img):
# TODO XXX prob is wrong for both mode
if self._rand_range() < self.prob:
img.arr = cv2.flip(img.arr, self.code)
if img.coords:
......
......@@ -52,11 +52,11 @@ def set_file(path):
filename=path, encoding='utf-8', mode='w')
logger.addHandler(hdl)
global LOG_DIR
LOG_DIR = "train_log"
def set_logger_dir(dirname):
global LOG_DIR
LOG_DIR = dirname
mkdir_p(LOG_DIR)
set_file(os.path.join(LOG_DIR, 'training.log'))
global LOG_FILE
LOG_FILE = "train_log/log.log"
def set_logger_file(filename):
global LOG_FILE
LOG_FILE = filename
mkdir_p(os.path.dirname(LOG_FILE))
set_file(LOG_FILE)
......@@ -20,6 +20,7 @@ def expand_dim_if_necessary(var, dp):
def mkdir_p(dirname):
assert dirname is not None
if dirname == '':
return
try:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment