Commit f227f45f authored by Yuxin Wu

Class names for cifar/fashion mnist (#863)

parent ea173d09
......@@ -4,7 +4,6 @@
from contextlib import contextmanager, ExitStack
import numpy as np
import tensorflow as tf
from tensorflow.contrib.framework import add_model_variable
from tensorpack.tfutils import argscope
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
......@@ -49,7 +48,7 @@ def freeze_affine_getter(getter, *args, **kwargs):
if name.endswith('/gamma') or name.endswith('/beta'):
kwargs['trainable'] = False
ret = getter(*args, **kwargs)
add_model_variable(ret)
tf.add_to_collection(tf.GraphKeys.MODEL_VARIABLES, ret)
else:
ret = getter(*args, **kwargs)
return ret
......
......@@ -66,14 +66,22 @@ def read_cifar(filenames, cifar_classnum):
def get_filenames(dir, cifar_classnum):
assert cifar_classnum == 10 or cifar_classnum == 100
if cifar_classnum == 10:
filenames = [os.path.join(
train_files = [os.path.join(
dir, 'cifar-10-batches-py', 'data_batch_%d' % i) for i in range(1, 6)]
filenames.append(os.path.join(
dir, 'cifar-10-batches-py', 'test_batch'))
test_files = [os.path.join(
dir, 'cifar-10-batches-py', 'test_batch')]
meta_file = os.path.join(dir, 'cifar-10-batches-py', 'batches.meta')
elif cifar_classnum == 100:
filenames = [os.path.join(dir, 'cifar-100-python', 'train'),
os.path.join(dir, 'cifar-100-python', 'test')]
return filenames
train_files = [os.path.join(dir, 'cifar-100-python', 'train')]
test_files = [os.path.join(dir, 'cifar-100-python', 'test')]
meta_file = os.path.join(dir, 'cifar-100-python', 'meta')
return train_files, test_files, meta_file
def _parse_meta(filename, cifar_classnum):
    """
    Read the class names stored in a pickled CIFAR metadata file.

    Args:
        filename (str): path to the metadata file; it is opened in binary mode.
        cifar_classnum (int): 10 selects the ``'label_names'`` entry,
            any other value selects ``'fine_label_names'``.

    Returns:
        The object stored under the selected key
        (presumably a list of class-name strings -- TODO confirm against the dataset).

    Raises:
        KeyError: if the selected key is missing from the unpickled object.
    """
    key = 'label_names' if cifar_classnum == 10 else 'fine_label_names'
    # NOTE(review): unpickling can execute arbitrary code; only use this on
    # metadata files that come from a trusted dataset archive.
    with open(filename, 'rb') as f:
        obj = pickle.load(f)
    return obj[key]
class CifarBase(RNGDataFlow):
......@@ -84,14 +92,15 @@ class CifarBase(RNGDataFlow):
if dir is None:
dir = get_dataset_path('cifar{}_data'.format(cifar_classnum))
maybe_download_and_extract(dir, self.cifar_classnum)
fnames = get_filenames(dir, cifar_classnum)
train_files, test_files, meta_file = get_filenames(dir, cifar_classnum)
if train_or_test == 'train':
self.fs = fnames[:-1]
self.fs = train_files
else:
self.fs = [fnames[-1]]
self.fs = test_files
for f in self.fs:
if not os.path.isfile(f):
raise ValueError('Failed to find file: ' + f)
self._label_names = _parse_meta(meta_file, cifar_classnum)
self.train_or_test = train_or_test
self.data = read_cifar(self.fs, cifar_classnum)
self.dir = dir
......@@ -110,14 +119,22 @@ class CifarBase(RNGDataFlow):
def get_per_pixel_mean(self):
"""
return a mean image of all (train and test) images of size 32x32x3
Returns:
a mean image of all (train and test) images of size 32x32x3
"""
fnames = get_filenames(self.dir, self.cifar_classnum)
all_imgs = [x[0] for x in read_cifar(fnames, self.cifar_classnum)]
train_files, test_files, _ = get_filenames(self.dir, self.cifar_classnum)
all_imgs = [x[0] for x in read_cifar(train_files + test_files, self.cifar_classnum)]
arr = np.array(all_imgs, dtype='float32')
mean = np.mean(arr, axis=0)
return mean
def get_label_names(self):
    """
    Returns:
        [str]: name of each class.
    """
    # The names are read once in the constructor via _parse_meta(meta_file, ...)
    # and cached on the instance; this accessor only hands them back.
    return self._label_names
def get_per_channel_mean(self):
"""
return three values as mean of each channel
......
......@@ -67,8 +67,8 @@ class Mnist(RNGDataFlow):
image is 28x28 in the range [0,1], label is an int.
"""
DIR_NAME = 'mnist_data'
SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'
_DIR_NAME = 'mnist_data'
_SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'
def __init__(self, train_or_test, shuffle=True, dir=None):
"""
......@@ -77,15 +77,15 @@ class Mnist(RNGDataFlow):
shuffle (bool): shuffle the dataset
"""
if dir is None:
dir = get_dataset_path(self.DIR_NAME)
dir = get_dataset_path(self._DIR_NAME)
assert train_or_test in ['train', 'test']
self.train_or_test = train_or_test
self.shuffle = shuffle
def get_images_and_labels(image_file, label_file):
f = maybe_download(self.SOURCE_URL + image_file, dir)
f = maybe_download(self._SOURCE_URL + image_file, dir)
images = extract_images(f)
f = maybe_download(self.SOURCE_URL + label_file, dir)
f = maybe_download(self._SOURCE_URL + label_file, dir)
labels = extract_labels(f)
assert images.shape[0] == labels.shape[0]
return images, labels
......@@ -113,8 +113,21 @@ class Mnist(RNGDataFlow):
class FashionMnist(Mnist):
DIR_NAME = 'fashion_mnist_data'
SOURCE_URL = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/'
"""
Same API as :class:`Mnist`, but more fashion.
"""
_DIR_NAME = 'fashion_mnist_data'
_SOURCE_URL = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/'
def get_label_names(self):
    """
    Returns:
        [str]: the name of each class
    """
    # copied from https://github.com/zalandoresearch/fashion-mnist
    # NOTE(review): presumably ordered so that index i is the name of
    # integer label i -- confirm against the upstream label table.
    return ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
            'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
if __name__ == '__main__':
......
......@@ -64,7 +64,8 @@ class SVHNDigit(RNGDataFlow):
@staticmethod
def get_per_pixel_mean():
"""
return 32x32x3 image
Returns:
a 32x32x3 image
"""
a = SVHNDigit('train')
b = SVHNDigit('test')
......
......@@ -3,7 +3,6 @@
import tensorflow as tf
from tensorflow.contrib.framework import add_model_variable
from tensorflow.python.training import moving_averages
import re
import six
......@@ -191,7 +190,7 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
if ctx.is_main_training_tower:
for v in layer.non_trainable_variables:
if isinstance(v, tf.Variable):
add_model_variable(v)
tf.add_to_collection(tf.GraphKeys.MODEL_VARIABLES, v)
if not ctx.is_main_training_tower or internal_update:
restore_collection(coll_bk)
......@@ -354,7 +353,7 @@ def BatchRenorm(x, rmax, dmax, momentum=0.9, epsilon=1e-5,
if ctx.is_main_training_tower:
for v in layer.non_trainable_variables:
if isinstance(v, tf.Variable):
add_model_variable(v)
tf.add_to_collection(tf.GraphKeys.MODEL_VARIABLES, v)
else:
# only run UPDATE_OPS in the first tower
restore_collection(coll_bk)
......
......@@ -3,7 +3,6 @@
# Credit: Qinyao He
import tensorflow as tf
from tensorflow.contrib.framework import add_model_variable
from contextlib import contextmanager
from .common import get_tf_version_tuple
......@@ -13,6 +12,13 @@ __all__ = ['custom_getter_scope', 'freeze_variables', 'remap_variables']
@contextmanager
def custom_getter_scope(custom_getter):
"""
Args:
custom_getter: the same as in :func:`tf.get_variable`
Returns:
The current variable scope with a custom_getter.
"""
scope = tf.get_variable_scope()
if get_tf_version_tuple() >= (1, 5):
with tf.variable_scope(
......@@ -35,7 +41,8 @@ def remap_variables(fn):
fn (tf.Variable -> tf.Tensor)
Returns:
a context where all the variables will be mapped by fn.
The current variable scope with a custom_getter that maps
all the variables by fn.
Example:
.. code-block:: python
......@@ -83,7 +90,7 @@ def freeze_variables(stop_gradient=True, skip_collection=False):
kwargs['trainable'] = False
v = getter(*args, **kwargs)
if skip_collection:
add_model_variable(v)
tf.add_to_collection(tf.GraphKeys.MODEL_VARIABLES, v)
if trainable and stop_gradient:
v = tf.stop_gradient(v, name='freezed_' + name)
return v
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment