Commit 2b6dce55 authored by ppwwyyxx

cifar10 dataset

parent e1b19c5d
*.gz
train_log
# Byte-compiled / optimized / DLL files
__pycache__/
......
......@@ -7,7 +7,10 @@ from pkgutil import walk_packages
import os
import os.path
__SKIP = ['dftools', 'dataset']
def global_import(name):
if name in __SKIP:
return
p = __import__(name, globals(), locals())
lst = p.__all__ if '__all__' in dir(p) else dir(p)
for k in lst:
......
......@@ -6,7 +6,7 @@
import numpy as np
from .base import DataFlow
__all__ = ['BatchData', 'FixedSizeData', 'FakeData']
__all__ = ['BatchData', 'FixedSizeData', 'FakeData', 'MapData']
class BatchData(DataFlow):
def __init__(self, ds, batch_size, remainder=False):
......@@ -43,10 +43,16 @@ class BatchData(DataFlow):
size = len(data_holder[0])
result = []
for k in xrange(size):
dt = data_holder[0][k]
if type(dt) in [int, bool, long]:
tp = 'int32'
elif type(dt) == float:
tp = 'float32'
else:
tp = dt.dtype
result.append(
np.array([x[k] for x in data_holder],
dtype=data_holder[0][k].dtype))
return tuple(result)
np.array([x[k] for x in data_holder], dtype=tp))
return result
class FixedSizeData(DataFlow):
def __init__(self, ds, size):
......@@ -76,5 +82,21 @@ class FakeData(DataFlow):
def get_data(self):
for _ in xrange(self._size):
yield tuple((np.random.random(k) for k in self.shapes))
yield [np.random.random(k) for k in self.shapes]
class MapData(DataFlow):
    """Apply a function to one component of every datapoint.

    Wraps another DataFlow and, for each datapoint it produces, replaces
    the component at ``index`` with ``func(component)``.
    """

    def __init__(self, ds, func, index=0):
        """
        Args:
            ds: the wrapped DataFlow; ``size()`` and ``get_data()`` are
                called on it.
            func: callable applied to the selected component.
            index: position of that component inside a datapoint.
        """
        self.ds = ds
        self.func = func
        self.index = index

    def size(self):
        """Return the size reported by the wrapped DataFlow."""
        return self.ds.size()

    def get_data(self):
        """Yield every upstream datapoint as a list with ``func`` applied
        to the component at ``index``."""
        for dp in self.ds.get_data():
            # Work on a copy: the upstream datapoint may be an immutable
            # tuple, and it must not be modified behind the producer's back.
            mapped = list(dp)
            mapped[self.index] = self.func(mapped[self.index])
            yield mapped
mnist_data
cifar10_data
......@@ -8,6 +8,7 @@ import os
import os.path
def global_import(name):
print name
p = __import__(name, globals(), locals())
lst = p.__all__ if '__all__' in dir(p) else dir(p)
for k in lst:
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: cifar10.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: cifar10.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import os, sys
import cPickle
import numpy
from six.moves import urllib
import tarfile
from utils import logger
from dataflow.base import DataFlow
__all__ = ['Cifar10']
DATA_URL = 'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
def maybe_download_and_extract(dest_directory):
    """Download and extract the tarball from Alex's website.

    Adapted from the tensorflow example. Nothing is downloaded when
    ``dest_directory`` already contains a ``cifar-10-batches-py`` directory.

    Args:
        dest_directory: directory that receives the archive and its
            extracted contents; created if it does not exist.
    """
    if not os.path.exists(dest_directory):
        os.makedirs(dest_directory)
    if os.path.isdir(os.path.join(dest_directory, 'cifar-10-batches-py')):
        return

    filename = DATA_URL.split('/')[-1]
    filepath = os.path.join(dest_directory, filename)

    def _progress(count, block_size, total_size):
        # The reporthook may be handed a non-positive total size when the
        # server does not announce one; no percentage can be computed then.
        if total_size <= 0:
            return
        sys.stdout.write('\r>> Downloading %s %.1f%%' % (
            filename, float(count * block_size) / float(total_size) * 100.0))
        sys.stdout.flush()

    filepath, _ = urllib.request.urlretrieve(
        DATA_URL, filepath, reporthook=_progress)
    # Written through sys.stdout so the output is correct whether or not
    # print_function is in effect: as a Python 2 statement, `print(a, b)`
    # prints the repr of a tuple instead of the message.
    sys.stdout.write('\n')
    statinfo = os.stat(filepath)
    sys.stdout.write('Succesfully downloaded %s %d bytes.\n'
                     % (filename, statinfo.st_size))
    # Close the archive deterministically instead of leaking the handle.
    with tarfile.open(filepath, 'r:gz') as archive:
        archive.extractall(dest_directory)
def read_cifar10(filenames):
    """Yield ``[image, label]`` pairs from pickled CIFAR-10 batch files.

    Args:
        filenames: iterable of paths to batch files. Each file must
            unpickle to a dict with a ``'data'`` entry (one flat row per
            image, reshapable to ``(3, 32, 32)``) and a ``'labels'`` entry.

    Yields:
        ``[img, label]`` where ``img`` has shape ``(32, 32, 3)``.
    """
    for fname in filenames:
        # NOTE: unpickling can execute arbitrary code; only read batch
        # files that come from a trusted source.
        with open(fname, 'rb') as fo:
            dic = cPickle.load(fo)
        data = dic['data']
        label = dic['labels']
        # Walk over what the file actually contains instead of assuming
        # exactly 10000 samples per batch.
        for row, lab in zip(data, label):
            img = row.reshape(3, 32, 32)
            # channel-first (3, 32, 32) -> channel-last (32, 32, 3)
            img = numpy.transpose(img, [1, 2, 0])
            yield [img, lab]
class Cifar10(DataFlow):
    """DataFlow over the CIFAR-10 python batch files."""

    def __init__(self, train_or_test, dir=None):
        """
        Args:
            train_or_test: string either 'train' or 'test'
            dir: directory containing 'cifar-10-batches-py'; when None,
                the 'cifar10_data' directory next to this file is used.

        Raises:
            ValueError: if one of the expected batch files is missing.
        """
        assert train_or_test in ['train', 'test']
        if dir is None:
            dir = os.path.join(os.path.dirname(__file__), 'cifar10_data')
        batch_dir = os.path.join(dir, 'cifar-10-batches-py')
        if train_or_test == 'train':
            # the training set is split over data_batch_1 .. data_batch_5
            self.fs = [os.path.join(batch_dir, 'data_batch_%d' % i)
                       for i in xrange(1, 6)]
        else:
            self.fs = [os.path.join(batch_dir, 'test_batch')]
        # Report the first batch file that is not on disk, if any.
        first_missing = next(
            (path for path in self.fs if not os.path.isfile(path)), None)
        if first_missing is not None:
            raise ValueError('Failed to find file: ' + first_missing)
        self.train_or_test = train_or_test

    def size(self):
        """Return 50000 for the 'train' split and 10000 for 'test'."""
        if self.train_or_test == 'train':
            return 50000
        return 10000

    def get_data(self):
        """Yield every datapoint produced by read_cifar10 for this split."""
        for datapoint in read_cifar10(self.fs):
            yield datapoint
if __name__ == '__main__':
    # Manual check: dump images from the training split to /tmp/cifar
    # so they can be inspected by eye.
    train_set = Cifar10('train')
    from dataflow.dftools import dump_dataset_images
    dump_dataset_images(train_set, '/tmp/cifar', 100)
......@@ -10,10 +10,12 @@ import numpy
from six.moves import urllib
from utils import logger
from ..base import DataFlow
from dataflow.base import DataFlow
__all__ = ['Mnist']
""" This file is mostly copied from tensorflow example """
SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'
def maybe_download(filename, work_directory):
......@@ -133,7 +135,7 @@ class Mnist(DataFlow):
for k in xrange(ds.num_examples):
img = ds.images[k].reshape((28, 28))
label = ds.labels[k]
yield (img, label)
yield [img, label]
if __name__ == '__main__':
ds = Mnist('train')
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: dftools.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import sys, os
from scipy.misc import imsave
from utils.utils import mkdir_p
def dump_dataset_images(ds, dirname, max_count=None, index=0):
    """Dump images from a dataset to a folder.

    Each image is written as ``<i>.jpg`` where ``i`` is its zero-based
    position in the dataset; that position is also echoed to stdout.

    Args:
        ds: dataset whose ``get_data()`` yields the data points.
        dirname: output folder, created via ``mkdir_p``.
        max_count: maximum number of images to write; None means no limit.
        index: the index of the image in a data point.
    """
    mkdir_p(dirname)
    for i, dp in enumerate(ds.get_data()):
        # `>=`, not `>`: i is zero-based, so the item at i == max_count
        # would already be image number max_count + 1.
        if max_count is not None and i >= max_count:
            return
        sys.stdout.write('%d\n' % i)
        img = dp[index]
        imsave(os.path.join(dirname, "{}.jpg".format(i)), img)
......@@ -3,16 +3,11 @@
# File: example_mnist.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# use user-space protobuf
import sys
import os
sys.path.insert(0, os.path.expanduser('~/.local/lib/python2.7/site-packages'))
import tensorflow as tf
from tensorflow.python.ops import control_flow_ops
import numpy as np
import os
import os, sys
from models import *
from utils import *
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: usercustomize.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# use user-space protobuf
import sys, os
# Put the per-user site-packages directory first on sys.path so packages
# installed there are found before any later entry.
# expanduser() is used instead of os.environ['HOME'] so that importing
# this module cannot fail with KeyError when HOME is unset.
site = os.path.join(os.path.expanduser('~'),
                    '.local/lib/python2.7/site-packages')
sys.path.insert(0, site)
......@@ -2,6 +2,7 @@
# -*- coding: UTF-8 -*-
# File: utils.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import os
def expand_dim_if_necessary(var, dp):
"""
......@@ -17,3 +18,12 @@ def expand_dim_if_necessary(var, dp):
dp = dp.reshape(new_shape)
return dp
def mkdir_p(dirname):
    """Create ``dirname`` and any missing parents, like ``mkdir -p``.

    An empty ``dirname`` is a no-op, and an already existing directory is
    not an error.

    Raises:
        OSError: if the directory cannot be created, including when
            ``dirname`` exists but is not a directory.
    """
    import errno  # local import keeps this function self-contained

    if dirname == '':
        return
    try:
        os.makedirs(dirname)
    except OSError as e:
        # Only "already exists as a directory" is benign; anything else,
        # including a regular file sitting at that path, is a real error.
        if e.errno != errno.EEXIST or not os.path.isdir(dirname):
            # bare raise keeps the original traceback
            raise
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment