Commit 0a012166 authored by Yuxin Wu's avatar Yuxin Wu

cifar10 mean

parent bac11ae3
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
import os, sys import os, sys
import pickle import pickle
import numpy import numpy as np
from six.moves import urllib from six.moves import urllib
import tarfile import tarfile
import logging import logging
...@@ -39,6 +39,7 @@ def maybe_download_and_extract(dest_directory): ...@@ -39,6 +39,7 @@ def maybe_download_and_extract(dest_directory):
tarfile.open(filepath, 'r:gz').extractall(dest_directory) tarfile.open(filepath, 'r:gz').extractall(dest_directory)
def read_cifar10(filenames): def read_cifar10(filenames):
ret = []
for fname in filenames: for fname in filenames:
fo = open(fname, 'rb') fo = open(fname, 'rb')
dic = pickle.load(fo) dic = pickle.load(fo)
...@@ -47,8 +48,16 @@ def read_cifar10(filenames): ...@@ -47,8 +48,16 @@ def read_cifar10(filenames):
fo.close() fo.close()
for k in xrange(10000): for k in xrange(10000):
img = data[k].reshape(3, 32, 32) img = data[k].reshape(3, 32, 32)
img = numpy.transpose(img, [1, 2, 0]) img = np.transpose(img, [1, 2, 0])
yield [img, label[k]] ret.append([img, label[k]])
return ret
def get_filenames(dir):
filenames = [os.path.join(
dir, 'cifar-10-batches-py', 'data_batch_%d' % i) for i in xrange(1, 6)]
filenames.append(os.path.join(
dir, 'cifar-10-batches-py', 'test_batch'))
return filenames
class Cifar10(DataFlow): class Cifar10(DataFlow):
""" """
...@@ -65,27 +74,49 @@ class Cifar10(DataFlow): ...@@ -65,27 +74,49 @@ class Cifar10(DataFlow):
dir = os.path.join(os.path.dirname(__file__), 'cifar10_data') dir = os.path.join(os.path.dirname(__file__), 'cifar10_data')
maybe_download_and_extract(dir) maybe_download_and_extract(dir)
fnames = get_filenames(dir)
if train_or_test == 'train': if train_or_test == 'train':
self.fs = [os.path.join( self.fs = fnames[:5]
dir, 'cifar-10-batches-py', 'data_batch_%d' % i) for i in xrange(1, 6)]
else: else:
self.fs = [os.path.join(dir, 'cifar-10-batches-py', 'test_batch')] self.fs = fnames[-1]
for f in self.fs: for f in self.fs:
if not os.path.isfile(f): if not os.path.isfile(f):
raise ValueError('Failed to find file: ' + f) raise ValueError('Failed to find file: ' + f)
self.train_or_test = train_or_test self.train_or_test = train_or_test
self.dir = dir
self.data = read_cifar10(self.fs)
def size(self): def size(self):
return 50000 if self.train_or_test == 'train' else 10000 return 50000 if self.train_or_test == 'train' else 10000
def get_data(self): def get_data(self):
for k in read_cifar10(self.fs): for k in self.data:
yield k yield k
def get_per_pixel_mean(self):
"""
return a mean image of all (train and test) images of size 32x32x3
"""
fnames = get_filenames(self.dir)
all_imgs = [x[0] for x in read_cifar10(fnames)]
arr = np.array(all_imgs, dtype='float32')
mean = np.mean(arr, axis=0)
return mean
def get_per_channel_mean(self):
"""
return three values as mean of each channel
"""
mean = self.get_per_pixel_mean()
return np.mean(mean, axis=(0,1))
if __name__ == '__main__': if __name__ == '__main__':
ds = Cifar10('train') ds = Cifar10('train')
from dataflow.dftools import dump_dataset_images from tensorpack.dataflow.dftools import dump_dataset_images
mean = ds.get_per_channel_mean()
print mean
dump_dataset_images(ds, '/tmp/cifar', 100) dump_dataset_images(ds, '/tmp/cifar', 100)
#for (img, label) in ds.get_data(): #for (img, label) in ds.get_data():
#from IPython import embed; embed() #from IPython import embed; embed()
#break #break
......
...@@ -6,8 +6,9 @@ ...@@ -6,8 +6,9 @@
import sys, os import sys, os
from scipy.misc import imsave from scipy.misc import imsave
from utils.utils import mkdir_p from ..utils.utils import mkdir_p
# TODO name_func to write label?
def dump_dataset_images(ds, dirname, max_count=None, index=0): def dump_dataset_images(ds, dirname, max_count=None, index=0):
""" dump images to a folder """ dump images to a folder
index: the index of the image in a data point index: the index of the image in a data point
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment