Commit 0a012166 authored by Yuxin Wu's avatar Yuxin Wu

cifar10 mean

parent bac11ae3
......@@ -4,7 +4,7 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import os, sys
import pickle
import numpy
import numpy as np
from six.moves import urllib
import tarfile
import logging
......@@ -39,6 +39,7 @@ def maybe_download_and_extract(dest_directory):
tarfile.open(filepath, 'r:gz').extractall(dest_directory)
def read_cifar10(filenames):
ret = []
for fname in filenames:
fo = open(fname, 'rb')
dic = pickle.load(fo)
......@@ -47,8 +48,16 @@ def read_cifar10(filenames):
fo.close()
for k in xrange(10000):
img = data[k].reshape(3, 32, 32)
img = numpy.transpose(img, [1, 2, 0])
yield [img, label[k]]
img = np.transpose(img, [1, 2, 0])
ret.append([img, label[k]])
return ret
def get_filenames(dir):
filenames = [os.path.join(
dir, 'cifar-10-batches-py', 'data_batch_%d' % i) for i in xrange(1, 6)]
filenames.append(os.path.join(
dir, 'cifar-10-batches-py', 'test_batch'))
return filenames
class Cifar10(DataFlow):
"""
......@@ -65,27 +74,49 @@ class Cifar10(DataFlow):
dir = os.path.join(os.path.dirname(__file__), 'cifar10_data')
maybe_download_and_extract(dir)
fnames = get_filenames(dir)
if train_or_test == 'train':
self.fs = [os.path.join(
dir, 'cifar-10-batches-py', 'data_batch_%d' % i) for i in xrange(1, 6)]
self.fs = fnames[:5]
else:
self.fs = [os.path.join(dir, 'cifar-10-batches-py', 'test_batch')]
self.fs = fnames[-1]
for f in self.fs:
if not os.path.isfile(f):
raise ValueError('Failed to find file: ' + f)
self.train_or_test = train_or_test
self.dir = dir
self.data = read_cifar10(self.fs)
def size(self):
return 50000 if self.train_or_test == 'train' else 10000
def get_data(self):
for k in read_cifar10(self.fs):
for k in self.data:
yield k
def get_per_pixel_mean(self):
"""
return a mean image of all (train and test) images of size 32x32x3
"""
fnames = get_filenames(self.dir)
all_imgs = [x[0] for x in read_cifar10(fnames)]
arr = np.array(all_imgs, dtype='float32')
mean = np.mean(arr, axis=0)
return mean
def get_per_channel_mean(self):
"""
return three values as mean of each channel
"""
mean = self.get_per_pixel_mean()
return np.mean(mean, axis=(0,1))
if __name__ == '__main__':
ds = Cifar10('train')
from dataflow.dftools import dump_dataset_images
from tensorpack.dataflow.dftools import dump_dataset_images
mean = ds.get_per_channel_mean()
print mean
dump_dataset_images(ds, '/tmp/cifar', 100)
#for (img, label) in ds.get_data():
#from IPython import embed; embed()
#break
......
......@@ -6,8 +6,9 @@
import sys, os
from scipy.misc import imsave
from utils.utils import mkdir_p
from ..utils.utils import mkdir_p
# TODO name_func to write label?
def dump_dataset_images(ds, dirname, max_count=None, index=0):
""" dump images to a folder
index: the index of the image in a data point
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment