Commit d24a9230 authored by Yuxin Wu's avatar Yuxin Wu

Only use the mean of training set (fix #1072)

parent 497f25b1
......@@ -118,7 +118,7 @@ class Model(ModelDesc):
def get_data(train_or_test):
isTrain = train_or_test == 'train'
ds = dataset.Cifar10(train_or_test)
pp_mean = ds.get_per_pixel_mean()
pp_mean = ds.get_per_pixel_mean(('train',))
if isTrain:
augmentors = [
imgaug.CenterPaste((40, 40)),
......
......@@ -132,13 +132,23 @@ class CifarBase(RNGDataFlow):
# since cifar is quite small, just do it for safety
yield self.data[k]
def get_per_pixel_mean(self):
def get_per_pixel_mean(self, names=('train', 'test')):
"""
Args:
names (tuple[str]): the names ('train' or 'test') of the datasets
Returns:
a mean image of all (train and test) images of size 32x32x3
a mean image of all images in the given datasets, with size 32x32x3
"""
for name in names:
assert name in ['train', 'test'], name
train_files, test_files, _ = get_filenames(self.dir, self.cifar_classnum)
all_imgs = [x[0] for x in read_cifar(train_files + test_files, self.cifar_classnum)]
all_files = []
if 'train' in names:
all_files.extend(train_files)
if 'test' in names:
all_files.extend(test_files)
all_imgs = [x[0] for x in read_cifar(all_files, self.cifar_classnum)]
arr = np.array(all_imgs, dtype='float32')
mean = np.mean(arr, axis=0)
return mean
......@@ -150,11 +160,15 @@ class CifarBase(RNGDataFlow):
"""
return self._label_names
def get_per_channel_mean(self):
def get_per_channel_mean(self, names=('train', 'test')):
"""
return three values as mean of each channel
Args:
names (tuple[str]): the names ('train' or 'test') of the datasets
Returns:
An array of three values as mean of each channel, for all images in the given datasets.
"""
mean = self.get_per_pixel_mean()
mean = self.get_per_pixel_mean(names)
return np.mean(mean, axis=(0, 1))
......
......@@ -62,15 +62,18 @@ class SVHNDigit(RNGDataFlow):
yield [self.X[k], self.Y[k]]
@staticmethod
def get_per_pixel_mean():
def get_per_pixel_mean(names=('train', 'test', 'extra')):
"""
Args:
names (tuple[str]): names of the dataset split
Returns:
a 32x32x3 image
a 32x32x3 image, the mean of all images in the given datasets
"""
a = SVHNDigit('train')
b = SVHNDigit('test')
c = SVHNDigit('extra')
return np.concatenate((a.X, b.X, c.X)).mean(axis=0)
for name in names:
assert name in ['train', 'test', 'extra'], name
images = [SVHNDigit(x).X for x in names]
return np.concatenate(tuple(images)).mean(axis=0)
try:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment