Commit 61b79a46 authored by Yuxin Wu's avatar Yuxin Wu

hdf5 data

parent b06fa732
......@@ -88,10 +88,10 @@ class Model(ModelDesc):
add_param_summary([('.*/W', ['histogram', 'sparsity'])]) # monitor W
return tf.add_n([cost, wd_cost], name='cost')
def get_config():
#anchors = np.mgrid[0:4,0:4][:,1:,1:].transpose(1,2,0).reshape((-1,2)) / 4.0
# prepare dataset
dataset_train = dataset.Cifar10('train')
def get_data(train_or_test):
isTrain = train_or_test == 'train'
ds = dataset.Cifar10(train_or_test)
if isTrain:
augmentors = [
imgaug.RandomCrop((30, 30)),
imgaug.Flip(horiz=True),
......@@ -102,19 +102,24 @@ def get_config():
(30,30), 0.2, 3),
imgaug.MeanVarianceNormalize(all_channel=True)
]
dataset_train = AugmentImageComponent(dataset_train, augmentors)
dataset_train = BatchData(dataset_train, 128)
dataset_train = PrefetchData(dataset_train, 3, 2)
step_per_epoch = dataset_train.size() / 2
step_per_epoch = 10
else:
augmentors = [
imgaug.CenterCrop((30, 30)),
imgaug.MeanVarianceNormalize(all_channel=True)
]
dataset_test = dataset.Cifar10('test')
dataset_test = AugmentImageComponent(dataset_test, augmentors)
dataset_test = BatchData(dataset_test, 128, remainder=True)
ds = AugmentImageComponent(ds, augmentors)
ds = BatchData(ds, 128, remainder=not isTrain)
if isTrain:
ds = PrefetchData(ds, 3, 2)
return ds
def get_config():
# prepare dataset
dataset_train = get_data('train')
step_per_epoch = dataset_train.size() / 2
dataset_test = get_data('test')
sess_config = get_default_sess_config()
sess_config.gpu_options.per_process_gpu_memory_fraction = 0.5
......
# -*- coding: utf-8 -*-
# File: format.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import h5py
import random
from six.moves import range
from .base import DataFlow
"""
Adapter for different data format.
"""
__all__ = ['HDF5Data']
class HDF5Data(DataFlow):
"""
Zip data from different paths in this HDF5 data file
"""
def __init__(self, filename, data_paths, shuffle=True):
self.f = h5py.File(filename, 'r')
self.dps = [self.f[k] for k in data_paths]
lens = [len(k) for k in self.dps]
assert all([k==lens[0] for k in lens])
self._size = lens[0]
self.shuffle = shuffle
def size(self):
return self._size
def get_data(self):
idxs = list(range(self._size))
if self.shuffle:
random.shuffle(idxs)
for k in idxs:
yield [dp[k] for dp in self.dps]
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment