Commit f4507d45 authored by Yuxin Wu's avatar Yuxin Wu

dump script & prefetch size

parent f18314d6
......@@ -91,10 +91,6 @@ class Model(ModelDesc):
return tf.add_n([cost, wd_cost], name='cost')
def get_config():
basename = os.path.basename(__file__)
logger.set_logger_dir(
os.path.join('train_log', basename[:basename.rfind('.')]))
# prepare dataset
dataset_train = dataset.Cifar10('train')
augmentors = [
......@@ -102,10 +98,13 @@ def get_config():
imgaug.Flip(horiz=True),
imgaug.BrightnessAdd(63),
imgaug.Contrast((0.2,1.8)),
#imgaug.GaussianDeform([(0.2, 0.2), (0.2, 0.8), (0.8,0.8), (0.8,0.2)],
#(30,30), 0.2, 3),
imgaug.MeanVarianceNormalize(all_channel=True)
]
dataset_train = AugmentImageComponent(dataset_train, augmentors)
dataset_train = BatchData(dataset_train, 128)
#dataset_train = PrefetchData(dataset_train, 3, 2)
step_per_epoch = dataset_train.size()
augmentors = [
......@@ -145,6 +144,11 @@ if __name__ == '__main__':
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.') # nargs='*' in multi mode
parser.add_argument('--load', help='load model')
args = parser.parse_args()
basename = os.path.basename(__file__)
logger.set_logger_dir(
os.path.join('train_log', basename[:basename.rfind('.')]))
if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
else:
......
......@@ -9,10 +9,10 @@ import tensorflow as tf
import imp
import tqdm
import os
from tensorpack.utils import logger
from tensorpack.utils.utils import mkdir_p
parser = argparse.ArgumentParser()
parser.add_argument(dest='config')
parser.add_argument('-o', '--output', help='output directory to dump dataset image')
......@@ -23,8 +23,6 @@ parser.add_argument('-n', '--number', help='number of images to dump',
default=10, type=int)
args = parser.parse_args()
get_config_func = imp.load_source('config_script', args.config).get_config
config = get_config_func()
......@@ -39,7 +37,7 @@ if args.output:
for bi, img in enumerate(imgbatch):
cnt += 1
fname = os.path.join(args.output, '{:03d}-{}.png'.format(cnt, bi))
cv2.imwrite(fname, img * 255)
cv2.imwrite(fname, img)
NR_DP_TEST = 100
logger.info("Testing dataflow speed:")
......
......@@ -15,11 +15,11 @@ __all__ = ['PeriodicSaver']
class PeriodicSaver(PeriodicCallback):
def __init__(self, period=1, keep_recent=10, keep_freq=0.5):
super(PeriodicSaver, self).__init__(period)
self.path = os.path.join(logger.LOG_DIR, 'model')
self.keep_recent = keep_recent
self.keep_freq = keep_freq
def _before_train(self):
self.path = os.path.join(logger.LOG_DIR, 'model')
self.saver = tf.train.Saver(
max_to_keep=self.keep_recent,
keep_checkpoint_every_n_hours=self.keep_freq)
......
......@@ -39,6 +39,9 @@ class PrefetchData(DataFlow):
self.nr_proc = nr_proc
self.nr_prefetch = nr_prefetch
def size(self):
return self.ds.size() * self.nr_proc
def get_data(self):
queue = multiprocessing.Queue(self.nr_prefetch)
procs = [PrefetchProcess(self.ds, queue) for _ in range(self.nr_proc)]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment