Commit d24a9230 authored by Yuxin Wu's avatar Yuxin Wu

Only use the mean of training set (fix #1072)

parent 497f25b1
...@@ -118,7 +118,7 @@ class Model(ModelDesc): ...@@ -118,7 +118,7 @@ class Model(ModelDesc):
def get_data(train_or_test): def get_data(train_or_test):
isTrain = train_or_test == 'train' isTrain = train_or_test == 'train'
ds = dataset.Cifar10(train_or_test) ds = dataset.Cifar10(train_or_test)
pp_mean = ds.get_per_pixel_mean() pp_mean = ds.get_per_pixel_mean(('train',))
if isTrain: if isTrain:
augmentors = [ augmentors = [
imgaug.CenterPaste((40, 40)), imgaug.CenterPaste((40, 40)),
......
...@@ -132,13 +132,23 @@ class CifarBase(RNGDataFlow): ...@@ -132,13 +132,23 @@ class CifarBase(RNGDataFlow):
# since cifar is quite small, just do it for safety # since cifar is quite small, just do it for safety
yield self.data[k] yield self.data[k]
def get_per_pixel_mean(self): def get_per_pixel_mean(self, names=('train', 'test')):
""" """
Args:
names (tuple[str]): the names ('train' or 'test') of the datasets
Returns: Returns:
a mean image of all (train and test) images of size 32x32x3 a mean image of all images in the given datasets, with size 32x32x3
""" """
for name in names:
assert name in ['train', 'test'], name
train_files, test_files, _ = get_filenames(self.dir, self.cifar_classnum) train_files, test_files, _ = get_filenames(self.dir, self.cifar_classnum)
all_imgs = [x[0] for x in read_cifar(train_files + test_files, self.cifar_classnum)] all_files = []
if 'train' in names:
all_files.extend(train_files)
if 'test' in names:
all_files.extend(test_files)
all_imgs = [x[0] for x in read_cifar(all_files, self.cifar_classnum)]
arr = np.array(all_imgs, dtype='float32') arr = np.array(all_imgs, dtype='float32')
mean = np.mean(arr, axis=0) mean = np.mean(arr, axis=0)
return mean return mean
...@@ -150,11 +160,15 @@ class CifarBase(RNGDataFlow): ...@@ -150,11 +160,15 @@ class CifarBase(RNGDataFlow):
""" """
return self._label_names return self._label_names
def get_per_channel_mean(self): def get_per_channel_mean(self, names=('train', 'test')):
""" """
return three values as mean of each channel Args:
names (tuple[str]): the names ('train' or 'test') of the datasets
Returns:
An array of three values as mean of each channel, for all images in the given datasets.
""" """
mean = self.get_per_pixel_mean() mean = self.get_per_pixel_mean(names)
return np.mean(mean, axis=(0, 1)) return np.mean(mean, axis=(0, 1))
......
...@@ -62,15 +62,18 @@ class SVHNDigit(RNGDataFlow): ...@@ -62,15 +62,18 @@ class SVHNDigit(RNGDataFlow):
yield [self.X[k], self.Y[k]] yield [self.X[k], self.Y[k]]
@staticmethod @staticmethod
def get_per_pixel_mean(): def get_per_pixel_mean(names=('train', 'test', 'extra')):
""" """
Args:
names (tuple[str]): names of the dataset split
Returns: Returns:
a 32x32x3 image a 32x32x3 image, the mean of all images in the given datasets
""" """
a = SVHNDigit('train') for name in names:
b = SVHNDigit('test') assert name in ['train', 'test', 'extra'], name
c = SVHNDigit('extra') images = [SVHNDigit(x).X for x in names]
return np.concatenate((a.X, b.X, c.X)).mean(axis=0) return np.concatenate(tuple(images)).mean(axis=0)
try: try:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment