Commit d5d7270a authored by Yuxin Wu's avatar Yuxin Wu

some further simplification

parent 2d96baca
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: cifar10-convnet.py # File: cifar-convnet.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
import numpy import numpy
import tensorflow as tf import tensorflow as tf
...@@ -19,7 +19,8 @@ from tensorpack.dataflow import * ...@@ -19,7 +19,8 @@ from tensorpack.dataflow import *
""" """
A small convnet model for cifar 10 or cifar100 dataset. A small convnet model for cifar 10 or cifar100 dataset.
90% validation accuracy after 40k step.
For Cifar10: 90% validation accuracy after 40k step.
""" """
class Model(ModelDesc): class Model(ModelDesc):
...@@ -141,7 +142,8 @@ if __name__ == '__main__': ...@@ -141,7 +142,8 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.') # nargs='*' in multi mode parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.') # nargs='*' in multi mode
parser.add_argument('--load', help='load model') parser.add_argument('--load', help='load model')
parser.add_argument('--classnum', help='specify cifar10 or cifar100, input 10 for cifar10 or 100 for cifar100') parser.add_argument('--classnum', help='10 for cifar10 or 100 for cifar100',
type=int, default=10)
args = parser.parse_args() args = parser.parse_args()
basename = os.path.basename(__file__) basename = os.path.basename(__file__)
...@@ -153,13 +155,8 @@ if __name__ == '__main__': ...@@ -153,13 +155,8 @@ if __name__ == '__main__':
else: else:
os.environ['CUDA_VISIBLE_DEVICES'] = '0' os.environ['CUDA_VISIBLE_DEVICES'] = '0'
if args.classnum:
cifar_classnum = int(args.classnum)
else:
cifar_classnum = 10
with tf.Graph().as_default(): with tf.Graph().as_default():
config = get_config(cifar_classnum) config = get_config(args.classnum)
if args.load: if args.load:
config.session_init = SaverRestore(args.load) config.session_init = SaverRestore(args.load)
if args.gpu: if args.gpu:
......
#!/usr/bin/env python2 #!/usr/bin/env python2
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: cifar10.py # File: cifar.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Yukun Chen <cykustc@gmail.com>
import os, sys import os, sys
import pickle import pickle
import numpy as np import numpy as np
...@@ -19,7 +21,7 @@ __all__ = ['Cifar10', 'Cifar100'] ...@@ -19,7 +21,7 @@ __all__ = ['Cifar10', 'Cifar100']
DATA_URL_CIFAR_10 = 'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz' DATA_URL_CIFAR_10 = 'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
DATA_URL_CIFAR_100 = 'https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz' DATA_URL_CIFAR_100 = 'http://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz'
def maybe_download_and_extract(dest_directory, cifar_classnum): def maybe_download_and_extract(dest_directory, cifar_classnum):
"""Download and extract the tarball from Alex's website. """Download and extract the tarball from Alex's website.
...@@ -52,7 +54,7 @@ def read_cifar(filenames, cifar_classnum): ...@@ -52,7 +54,7 @@ def read_cifar(filenames, cifar_classnum):
data = dic[b'data'] data = dic[b'data']
if cifar_classnum == 10: if cifar_classnum == 10:
label = dic[b'labels'] label = dic[b'labels']
IMG_NUM = 10000 IMG_NUM = 10000 # cifar10 data are split into blocks of 10000
elif cifar_classnum == 100: elif cifar_classnum == 100:
label = dic[b'fine_labels'] label = dic[b'fine_labels']
IMG_NUM = 50000 if 'train' in fname else 10000 IMG_NUM = 50000 if 'train' in fname else 10000
...@@ -71,10 +73,8 @@ def get_filenames(dir, cifar_classnum): ...@@ -71,10 +73,8 @@ def get_filenames(dir, cifar_classnum):
filenames.append(os.path.join( filenames.append(os.path.join(
dir, 'cifar-10-batches-py', 'test_batch')) dir, 'cifar-10-batches-py', 'test_batch'))
elif cifar_classnum == 100: elif cifar_classnum == 100:
filenames = [os.path.join( filenames = [os.path.join(dir, 'cifar-100-python', 'train'),
dir, 'cifar-100-python', 'train')] os.path.join(dir, 'cifar-100-python', 'test')]
filenames.append(os.path.join(
dir, 'cifar-100-python', 'test'))
return filenames return filenames
class CifarBase(DataFlow): class CifarBase(DataFlow):
...@@ -92,15 +92,12 @@ class CifarBase(DataFlow): ...@@ -92,15 +92,12 @@ class CifarBase(DataFlow):
assert cifar_classnum == 10 or cifar_classnum == 100 assert cifar_classnum == 10 or cifar_classnum == 100
self.cifar_classnum = cifar_classnum self.cifar_classnum = cifar_classnum
if dir is None: if dir is None:
dir = os.path.join(os.path.dirname(__file__), 'cifar-10-batches-py' dir = os.path.join(os.path.dirname(__file__),
if cifar_classnum==10 else 'cifar100_data') 'cifar{}_data'.format(cifar_classnum))
maybe_download_and_extract(dir, self.cifar_classnum) maybe_download_and_extract(dir, self.cifar_classnum)
if self.cifar_classnum == 10: fnames = get_filenames(dir, cifar_classnum)
fnames = get_filenames(dir, 10)
else:
fnames = get_filenames(dir, 100)
if train_or_test == 'train': if train_or_test == 'train':
self.fs = fnames[:5] if cifar_classnum==10 else fnames[:1] self.fs = fnames[:-1]
else: else:
self.fs = [fnames[-1]] self.fs = [fnames[-1]]
for f in self.fs: for f in self.fs:
......
...@@ -125,7 +125,6 @@ class QueueInputTrainer(Trainer): ...@@ -125,7 +125,6 @@ class QueueInputTrainer(Trainer):
def _get_model_inputs(self): def _get_model_inputs(self):
""" Dequeue a datapoint from input_queue and return""" """ Dequeue a datapoint from input_queue and return"""
ret = self.input_queue.dequeue(name='input_deque') ret = self.input_queue.dequeue(name='input_deque')
print ret
if isinstance(ret, tf.Tensor): # only one input if isinstance(ret, tf.Tensor): # only one input
ret = [ret] ret = [ret]
assert len(ret) == len(self.input_vars) assert len(ret) == len(self.input_vars)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment