Commit d5d7270a authored by Yuxin Wu's avatar Yuxin Wu

some further simplification

parent 2d96baca
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# File: cifar10-convnet.py
# File: cifar-convnet.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import numpy
import tensorflow as tf
......@@ -19,7 +19,8 @@ from tensorpack.dataflow import *
"""
A small convnet model for cifar 10 or cifar100 dataset.
90% validation accuracy after 40k step.
For Cifar10: 90% validation accuracy after 40k step.
"""
class Model(ModelDesc):
......@@ -141,7 +142,8 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.') # nargs='*' in multi mode
parser.add_argument('--load', help='load model')
parser.add_argument('--classnum', help='specify cifar10 or cifar100, input 10 for cifar10 or 100 for cifar100')
parser.add_argument('--classnum', help='10 for cifar10 or 100 for cifar100',
type=int, default=10)
args = parser.parse_args()
basename = os.path.basename(__file__)
......@@ -153,13 +155,8 @@ if __name__ == '__main__':
else:
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
if args.classnum:
cifar_classnum = int(args.classnum)
else:
cifar_classnum = 10
with tf.Graph().as_default():
config = get_config(cifar_classnum)
config = get_config(args.classnum)
if args.load:
config.session_init = SaverRestore(args.load)
if args.gpu:
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: cifar10.py
# File: cifar.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Yukun Chen <cykustc@gmail.com>
import os, sys
import pickle
import numpy as np
......@@ -19,7 +21,7 @@ __all__ = ['Cifar10', 'Cifar100']
DATA_URL_CIFAR_10 = 'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
DATA_URL_CIFAR_100 = 'https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz'
DATA_URL_CIFAR_100 = 'http://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz'
def maybe_download_and_extract(dest_directory, cifar_classnum):
"""Download and extract the tarball from Alex's website.
......@@ -52,7 +54,7 @@ def read_cifar(filenames, cifar_classnum):
data = dic[b'data']
if cifar_classnum == 10:
label = dic[b'labels']
IMG_NUM = 10000
IMG_NUM = 10000 # cifar10 data are split into blocks of 10000
elif cifar_classnum == 100:
label = dic[b'fine_labels']
IMG_NUM = 50000 if 'train' in fname else 10000
......@@ -71,10 +73,8 @@ def get_filenames(dir, cifar_classnum):
filenames.append(os.path.join(
dir, 'cifar-10-batches-py', 'test_batch'))
elif cifar_classnum == 100:
filenames = [os.path.join(
dir, 'cifar-100-python', 'train')]
filenames.append(os.path.join(
dir, 'cifar-100-python', 'test'))
filenames = [os.path.join(dir, 'cifar-100-python', 'train'),
os.path.join(dir, 'cifar-100-python', 'test')]
return filenames
class CifarBase(DataFlow):
......@@ -92,15 +92,12 @@ class CifarBase(DataFlow):
assert cifar_classnum == 10 or cifar_classnum == 100
self.cifar_classnum = cifar_classnum
if dir is None:
dir = os.path.join(os.path.dirname(__file__), 'cifar-10-batches-py'
if cifar_classnum==10 else 'cifar100_data')
dir = os.path.join(os.path.dirname(__file__),
'cifar{}_data'.format(cifar_classnum))
maybe_download_and_extract(dir, self.cifar_classnum)
if self.cifar_classnum == 10:
fnames = get_filenames(dir, 10)
else:
fnames = get_filenames(dir, 100)
fnames = get_filenames(dir, cifar_classnum)
if train_or_test == 'train':
self.fs = fnames[:5] if cifar_classnum==10 else fnames[:1]
self.fs = fnames[:-1]
else:
self.fs = [fnames[-1]]
for f in self.fs:
......
......@@ -125,7 +125,6 @@ class QueueInputTrainer(Trainer):
def _get_model_inputs(self):
""" Dequeue a datapoint from input_queue and return"""
ret = self.input_queue.dequeue(name='input_deque')
print ret
if isinstance(ret, tf.Tensor): # only one input
ret = [ret]
assert len(ret) == len(self.input_vars)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment