Commit 7eb73782 authored by Yuxin Wu's avatar Yuxin Wu

shuffle cifar/mnist data

parent ef4a15ca
...@@ -18,7 +18,7 @@ from tensorpack.dataflow import * ...@@ -18,7 +18,7 @@ from tensorpack.dataflow import *
from tensorpack.dataflow import imgaug from tensorpack.dataflow import imgaug
""" """
CIFAR10 90% validation accuracy after 100k step, 91% after 160k step. CIFAR10 90% validation accuracy after 100k step
""" """
BATCH_SIZE = 128 BATCH_SIZE = 128
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
import os, sys import os, sys
import pickle import pickle
import numpy as np import numpy as np
import random
from six.moves import urllib, range from six.moves import urllib, range
import copy import copy
import tarfile import tarfile
...@@ -65,10 +66,11 @@ class Cifar10(DataFlow): ...@@ -65,10 +66,11 @@ class Cifar10(DataFlow):
Return [image, label], Return [image, label],
image is 32x32x3 in the range [0,255] image is 32x32x3 in the range [0,255]
""" """
def __init__(self, train_or_test, dir=None): def __init__(self, train_or_test, shuffle=True, dir=None):
""" """
Args: Args:
train_or_test: string either 'train' or 'test' train_or_test: string either 'train' or 'test'
shuffle: default to True
""" """
assert train_or_test in ['train', 'test'] assert train_or_test in ['train', 'test']
if dir is None: if dir is None:
...@@ -86,13 +88,17 @@ class Cifar10(DataFlow): ...@@ -86,13 +88,17 @@ class Cifar10(DataFlow):
self.train_or_test = train_or_test self.train_or_test = train_or_test
self.dir = dir self.dir = dir
self.data = read_cifar10(self.fs) self.data = read_cifar10(self.fs)
self.shuffle = shuffle
def size(self): def size(self):
return 50000 if self.train_or_test == 'train' else 10000 return 50000 if self.train_or_test == 'train' else 10000
def get_data(self): def get_data(self):
for k in self.data: idxs = list(range(len(self.data)))
yield k if self.shuffle:
random.shuffle(idxs)
for k in idxs:
yield self.data[k]
def get_per_pixel_mean(self): def get_per_pixel_mean(self):
""" """
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
import os import os
import gzip import gzip
import random
import numpy import numpy
from six.moves import urllib, range from six.moves import urllib, range
...@@ -100,7 +100,7 @@ class Mnist(DataFlow): ...@@ -100,7 +100,7 @@ class Mnist(DataFlow):
Return [image, label], Return [image, label],
image is 28x28 in the range [0,1] image is 28x28 in the range [0,1]
""" """
def __init__(self, train_or_test, dir=None): def __init__(self, train_or_test, shuffle=True, dir=None):
""" """
Args: Args:
train_or_test: string either 'train' or 'test' train_or_test: string either 'train' or 'test'
...@@ -109,6 +109,7 @@ class Mnist(DataFlow): ...@@ -109,6 +109,7 @@ class Mnist(DataFlow):
dir = os.path.join(os.path.dirname(__file__), 'mnist_data') dir = os.path.join(os.path.dirname(__file__), 'mnist_data')
assert train_or_test in ['train', 'test'] assert train_or_test in ['train', 'test']
self.train_or_test = train_or_test self.train_or_test = train_or_test
self.shuffle = shuffle
TRAIN_IMAGES = 'train-images-idx3-ubyte.gz' TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'
TRAIN_LABELS = 'train-labels-idx1-ubyte.gz' TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'
...@@ -136,7 +137,10 @@ class Mnist(DataFlow): ...@@ -136,7 +137,10 @@ class Mnist(DataFlow):
def get_data(self): def get_data(self):
ds = self.train if self.train_or_test == 'train' else self.test ds = self.train if self.train_or_test == 'train' else self.test
for k in range(ds.num_examples): idxs = list(range(ds.num_examples))
if self.shuffle:
random.shuffle(idxs)
for k in idxs:
img = ds.images[k].reshape((28, 28)) img = ds.images[k].reshape((28, 28))
label = ds.labels[k] label = ds.labels[k]
yield [img, label] yield [img, label]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment