Commit 7eb73782 authored by Yuxin Wu's avatar Yuxin Wu

shuffle cifar/mnist data

parent ef4a15ca
......@@ -18,7 +18,7 @@ from tensorpack.dataflow import *
from tensorpack.dataflow import imgaug
"""
CIFAR10 90% validation accuracy after 100k step, 91% after 160k step.
CIFAR10 90% validation accuracy after 100k step
"""
BATCH_SIZE = 128
......
......@@ -5,6 +5,7 @@
import os, sys
import pickle
import numpy as np
import random
from six.moves import urllib, range
import copy
import tarfile
......@@ -65,10 +66,11 @@ class Cifar10(DataFlow):
Return [image, label],
image is 32x32x3 in the range [0,255]
"""
def __init__(self, train_or_test, dir=None):
def __init__(self, train_or_test, shuffle=True, dir=None):
"""
Args:
train_or_test: string either 'train' or 'test'
shuffle: default to True
"""
assert train_or_test in ['train', 'test']
if dir is None:
......@@ -86,13 +88,17 @@ class Cifar10(DataFlow):
self.train_or_test = train_or_test
self.dir = dir
self.data = read_cifar10(self.fs)
self.shuffle = shuffle
def size(self):
return 50000 if self.train_or_test == 'train' else 10000
def get_data(self):
for k in self.data:
yield k
idxs = list(range(len(self.data)))
if self.shuffle:
random.shuffle(idxs)
for k in idxs:
yield self.data[k]
def get_per_pixel_mean(self):
"""
......
......@@ -5,7 +5,7 @@
import os
import gzip
import random
import numpy
from six.moves import urllib, range
......@@ -100,7 +100,7 @@ class Mnist(DataFlow):
Return [image, label],
image is 28x28 in the range [0,1]
"""
def __init__(self, train_or_test, dir=None):
def __init__(self, train_or_test, shuffle=True, dir=None):
"""
Args:
train_or_test: string either 'train' or 'test'
......@@ -109,6 +109,7 @@ class Mnist(DataFlow):
dir = os.path.join(os.path.dirname(__file__), 'mnist_data')
assert train_or_test in ['train', 'test']
self.train_or_test = train_or_test
self.shuffle = shuffle
TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'
TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'
......@@ -136,7 +137,10 @@ class Mnist(DataFlow):
def get_data(self):
ds = self.train if self.train_or_test == 'train' else self.test
for k in range(ds.num_examples):
idxs = list(range(ds.num_examples))
if self.shuffle:
random.shuffle(idxs)
for k in idxs:
img = ds.images[k].reshape((28, 28))
label = ds.labels[k]
yield [img, label]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment