Commit f4507d45 authored by Yuxin Wu

dump script & prefetch size

parent f18314d6
...@@ -91,10 +91,6 @@ class Model(ModelDesc): ...@@ -91,10 +91,6 @@ class Model(ModelDesc):
return tf.add_n([cost, wd_cost], name='cost') return tf.add_n([cost, wd_cost], name='cost')
def get_config(): def get_config():
basename = os.path.basename(__file__)
logger.set_logger_dir(
os.path.join('train_log', basename[:basename.rfind('.')]))
# prepare dataset # prepare dataset
dataset_train = dataset.Cifar10('train') dataset_train = dataset.Cifar10('train')
augmentors = [ augmentors = [
...@@ -102,10 +98,13 @@ def get_config(): ...@@ -102,10 +98,13 @@ def get_config():
imgaug.Flip(horiz=True), imgaug.Flip(horiz=True),
imgaug.BrightnessAdd(63), imgaug.BrightnessAdd(63),
imgaug.Contrast((0.2,1.8)), imgaug.Contrast((0.2,1.8)),
#imgaug.GaussianDeform([(0.2, 0.2), (0.2, 0.8), (0.8,0.8), (0.8,0.2)],
#(30,30), 0.2, 3),
imgaug.MeanVarianceNormalize(all_channel=True) imgaug.MeanVarianceNormalize(all_channel=True)
] ]
dataset_train = AugmentImageComponent(dataset_train, augmentors) dataset_train = AugmentImageComponent(dataset_train, augmentors)
dataset_train = BatchData(dataset_train, 128) dataset_train = BatchData(dataset_train, 128)
#dataset_train = PrefetchData(dataset_train, 3, 2)
step_per_epoch = dataset_train.size() step_per_epoch = dataset_train.size()
augmentors = [ augmentors = [
...@@ -145,6 +144,11 @@ if __name__ == '__main__': ...@@ -145,6 +144,11 @@ if __name__ == '__main__':
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.') # nargs='*' in multi mode parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.') # nargs='*' in multi mode
parser.add_argument('--load', help='load model') parser.add_argument('--load', help='load model')
args = parser.parse_args() args = parser.parse_args()
basename = os.path.basename(__file__)
logger.set_logger_dir(
os.path.join('train_log', basename[:basename.rfind('.')]))
if args.gpu: if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
else: else:
......
...@@ -9,10 +9,10 @@ import tensorflow as tf ...@@ -9,10 +9,10 @@ import tensorflow as tf
import imp import imp
import tqdm import tqdm
import os import os
from tensorpack.utils import logger from tensorpack.utils import logger
from tensorpack.utils.utils import mkdir_p from tensorpack.utils.utils import mkdir_p
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument(dest='config') parser.add_argument(dest='config')
parser.add_argument('-o', '--output', help='output directory to dump dataset image') parser.add_argument('-o', '--output', help='output directory to dump dataset image')
...@@ -23,8 +23,6 @@ parser.add_argument('-n', '--number', help='number of images to dump', ...@@ -23,8 +23,6 @@ parser.add_argument('-n', '--number', help='number of images to dump',
default=10, type=int) default=10, type=int)
args = parser.parse_args() args = parser.parse_args()
get_config_func = imp.load_source('config_script', args.config).get_config get_config_func = imp.load_source('config_script', args.config).get_config
config = get_config_func() config = get_config_func()
...@@ -39,7 +37,7 @@ if args.output: ...@@ -39,7 +37,7 @@ if args.output:
for bi, img in enumerate(imgbatch): for bi, img in enumerate(imgbatch):
cnt += 1 cnt += 1
fname = os.path.join(args.output, '{:03d}-{}.png'.format(cnt, bi)) fname = os.path.join(args.output, '{:03d}-{}.png'.format(cnt, bi))
cv2.imwrite(fname, img * 255) cv2.imwrite(fname, img)
NR_DP_TEST = 100 NR_DP_TEST = 100
logger.info("Testing dataflow speed:") logger.info("Testing dataflow speed:")
......
...@@ -15,11 +15,11 @@ __all__ = ['PeriodicSaver'] ...@@ -15,11 +15,11 @@ __all__ = ['PeriodicSaver']
class PeriodicSaver(PeriodicCallback): class PeriodicSaver(PeriodicCallback):
def __init__(self, period=1, keep_recent=10, keep_freq=0.5): def __init__(self, period=1, keep_recent=10, keep_freq=0.5):
super(PeriodicSaver, self).__init__(period) super(PeriodicSaver, self).__init__(period)
self.path = os.path.join(logger.LOG_DIR, 'model')
self.keep_recent = keep_recent self.keep_recent = keep_recent
self.keep_freq = keep_freq self.keep_freq = keep_freq
def _before_train(self): def _before_train(self):
self.path = os.path.join(logger.LOG_DIR, 'model')
self.saver = tf.train.Saver( self.saver = tf.train.Saver(
max_to_keep=self.keep_recent, max_to_keep=self.keep_recent,
keep_checkpoint_every_n_hours=self.keep_freq) keep_checkpoint_every_n_hours=self.keep_freq)
......
...@@ -39,6 +39,9 @@ class PrefetchData(DataFlow): ...@@ -39,6 +39,9 @@ class PrefetchData(DataFlow):
self.nr_proc = nr_proc self.nr_proc = nr_proc
self.nr_prefetch = nr_prefetch self.nr_prefetch = nr_prefetch
def size(self):
return self.ds.size() * self.nr_proc
def get_data(self): def get_data(self):
queue = multiprocessing.Queue(self.nr_prefetch) queue = multiprocessing.Queue(self.nr_prefetch)
procs = [PrefetchProcess(self.ds, queue) for _ in range(self.nr_proc)] procs = [PrefetchProcess(self.ds, queue) for _ in range(self.nr_proc)]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment