Commit f227f45f authored by Yuxin Wu's avatar Yuxin Wu

Class names for cifar/fashion mnist (#863)

parent ea173d09
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
from contextlib import contextmanager, ExitStack from contextlib import contextmanager, ExitStack
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
from tensorflow.contrib.framework import add_model_variable
from tensorpack.tfutils import argscope from tensorpack.tfutils import argscope
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
...@@ -49,7 +48,7 @@ def freeze_affine_getter(getter, *args, **kwargs): ...@@ -49,7 +48,7 @@ def freeze_affine_getter(getter, *args, **kwargs):
if name.endswith('/gamma') or name.endswith('/beta'): if name.endswith('/gamma') or name.endswith('/beta'):
kwargs['trainable'] = False kwargs['trainable'] = False
ret = getter(*args, **kwargs) ret = getter(*args, **kwargs)
add_model_variable(ret) tf.add_to_collection(tf.GraphKeys.MODEL_VARIABLES, ret)
else: else:
ret = getter(*args, **kwargs) ret = getter(*args, **kwargs)
return ret return ret
......
...@@ -66,14 +66,22 @@ def read_cifar(filenames, cifar_classnum): ...@@ -66,14 +66,22 @@ def read_cifar(filenames, cifar_classnum):
def get_filenames(dir, cifar_classnum): def get_filenames(dir, cifar_classnum):
assert cifar_classnum == 10 or cifar_classnum == 100 assert cifar_classnum == 10 or cifar_classnum == 100
if cifar_classnum == 10: if cifar_classnum == 10:
filenames = [os.path.join( train_files = [os.path.join(
dir, 'cifar-10-batches-py', 'data_batch_%d' % i) for i in range(1, 6)] dir, 'cifar-10-batches-py', 'data_batch_%d' % i) for i in range(1, 6)]
filenames.append(os.path.join( test_files = [os.path.join(
dir, 'cifar-10-batches-py', 'test_batch')) dir, 'cifar-10-batches-py', 'test_batch')]
meta_file = os.path.join(dir, 'cifar-10-batches-py', 'batches.meta')
elif cifar_classnum == 100: elif cifar_classnum == 100:
filenames = [os.path.join(dir, 'cifar-100-python', 'train'), train_files = [os.path.join(dir, 'cifar-100-python', 'train')]
os.path.join(dir, 'cifar-100-python', 'test')] test_files = [os.path.join(dir, 'cifar-100-python', 'test')]
return filenames meta_file = os.path.join(dir, 'cifar-100-python', 'meta')
return train_files, test_files, meta_file
def _parse_meta(filename, cifar_classnum):
with open(filename, 'rb') as f:
obj = pickle.load(f)
return obj['label_names' if cifar_classnum == 10 else 'fine_label_names']
class CifarBase(RNGDataFlow): class CifarBase(RNGDataFlow):
...@@ -84,14 +92,15 @@ class CifarBase(RNGDataFlow): ...@@ -84,14 +92,15 @@ class CifarBase(RNGDataFlow):
if dir is None: if dir is None:
dir = get_dataset_path('cifar{}_data'.format(cifar_classnum)) dir = get_dataset_path('cifar{}_data'.format(cifar_classnum))
maybe_download_and_extract(dir, self.cifar_classnum) maybe_download_and_extract(dir, self.cifar_classnum)
fnames = get_filenames(dir, cifar_classnum) train_files, test_files, meta_file = get_filenames(dir, cifar_classnum)
if train_or_test == 'train': if train_or_test == 'train':
self.fs = fnames[:-1] self.fs = train_files
else: else:
self.fs = [fnames[-1]] self.fs = test_files
for f in self.fs: for f in self.fs:
if not os.path.isfile(f): if not os.path.isfile(f):
raise ValueError('Failed to find file: ' + f) raise ValueError('Failed to find file: ' + f)
self._label_names = _parse_meta(meta_file, cifar_classnum)
self.train_or_test = train_or_test self.train_or_test = train_or_test
self.data = read_cifar(self.fs, cifar_classnum) self.data = read_cifar(self.fs, cifar_classnum)
self.dir = dir self.dir = dir
...@@ -110,14 +119,22 @@ class CifarBase(RNGDataFlow): ...@@ -110,14 +119,22 @@ class CifarBase(RNGDataFlow):
def get_per_pixel_mean(self): def get_per_pixel_mean(self):
""" """
return a mean image of all (train and test) images of size 32x32x3 Returns:
a mean image of all (train and test) images of size 32x32x3
""" """
fnames = get_filenames(self.dir, self.cifar_classnum) train_files, test_files, _ = get_filenames(self.dir, self.cifar_classnum)
all_imgs = [x[0] for x in read_cifar(fnames, self.cifar_classnum)] all_imgs = [x[0] for x in read_cifar(train_files + test_files, self.cifar_classnum)]
arr = np.array(all_imgs, dtype='float32') arr = np.array(all_imgs, dtype='float32')
mean = np.mean(arr, axis=0) mean = np.mean(arr, axis=0)
return mean return mean
def get_label_names(self):
"""
Returns:
[str]: name of each class.
"""
return self._label_names
def get_per_channel_mean(self): def get_per_channel_mean(self):
""" """
return three values as mean of each channel return three values as mean of each channel
......
...@@ -67,8 +67,8 @@ class Mnist(RNGDataFlow): ...@@ -67,8 +67,8 @@ class Mnist(RNGDataFlow):
image is 28x28 in the range [0,1], label is an int. image is 28x28 in the range [0,1], label is an int.
""" """
DIR_NAME = 'mnist_data' _DIR_NAME = 'mnist_data'
SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/' _SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'
def __init__(self, train_or_test, shuffle=True, dir=None): def __init__(self, train_or_test, shuffle=True, dir=None):
""" """
...@@ -77,15 +77,15 @@ class Mnist(RNGDataFlow): ...@@ -77,15 +77,15 @@ class Mnist(RNGDataFlow):
shuffle (bool): shuffle the dataset shuffle (bool): shuffle the dataset
""" """
if dir is None: if dir is None:
dir = get_dataset_path(self.DIR_NAME) dir = get_dataset_path(self._DIR_NAME)
assert train_or_test in ['train', 'test'] assert train_or_test in ['train', 'test']
self.train_or_test = train_or_test self.train_or_test = train_or_test
self.shuffle = shuffle self.shuffle = shuffle
def get_images_and_labels(image_file, label_file): def get_images_and_labels(image_file, label_file):
f = maybe_download(self.SOURCE_URL + image_file, dir) f = maybe_download(self._SOURCE_URL + image_file, dir)
images = extract_images(f) images = extract_images(f)
f = maybe_download(self.SOURCE_URL + label_file, dir) f = maybe_download(self._SOURCE_URL + label_file, dir)
labels = extract_labels(f) labels = extract_labels(f)
assert images.shape[0] == labels.shape[0] assert images.shape[0] == labels.shape[0]
return images, labels return images, labels
...@@ -113,8 +113,21 @@ class Mnist(RNGDataFlow): ...@@ -113,8 +113,21 @@ class Mnist(RNGDataFlow):
class FashionMnist(Mnist): class FashionMnist(Mnist):
DIR_NAME = 'fashion_mnist_data' """
SOURCE_URL = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/' Same API as :class:`Mnist`, but more fashion.
"""
_DIR_NAME = 'fashion_mnist_data'
_SOURCE_URL = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/'
def get_label_names(self):
"""
Returns:
[str]: the name of each class
"""
# copied from https://github.com/zalandoresearch/fashion-mnist
return ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -64,7 +64,8 @@ class SVHNDigit(RNGDataFlow): ...@@ -64,7 +64,8 @@ class SVHNDigit(RNGDataFlow):
@staticmethod @staticmethod
def get_per_pixel_mean(): def get_per_pixel_mean():
""" """
return 32x32x3 image Returns:
a 32x32x3 image
""" """
a = SVHNDigit('train') a = SVHNDigit('train')
b = SVHNDigit('test') b = SVHNDigit('test')
......
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
import tensorflow as tf import tensorflow as tf
from tensorflow.contrib.framework import add_model_variable
from tensorflow.python.training import moving_averages from tensorflow.python.training import moving_averages
import re import re
import six import six
...@@ -191,7 +190,7 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5, ...@@ -191,7 +190,7 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
if ctx.is_main_training_tower: if ctx.is_main_training_tower:
for v in layer.non_trainable_variables: for v in layer.non_trainable_variables:
if isinstance(v, tf.Variable): if isinstance(v, tf.Variable):
add_model_variable(v) tf.add_to_collection(tf.GraphKeys.MODEL_VARIABLES, v)
if not ctx.is_main_training_tower or internal_update: if not ctx.is_main_training_tower or internal_update:
restore_collection(coll_bk) restore_collection(coll_bk)
...@@ -354,7 +353,7 @@ def BatchRenorm(x, rmax, dmax, momentum=0.9, epsilon=1e-5, ...@@ -354,7 +353,7 @@ def BatchRenorm(x, rmax, dmax, momentum=0.9, epsilon=1e-5,
if ctx.is_main_training_tower: if ctx.is_main_training_tower:
for v in layer.non_trainable_variables: for v in layer.non_trainable_variables:
if isinstance(v, tf.Variable): if isinstance(v, tf.Variable):
add_model_variable(v) tf.add_to_collection(tf.GraphKeys.MODEL_VARIABLES, v)
else: else:
# only run UPDATE_OPS in the first tower # only run UPDATE_OPS in the first tower
restore_collection(coll_bk) restore_collection(coll_bk)
......
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
# Credit: Qinyao He # Credit: Qinyao He
import tensorflow as tf import tensorflow as tf
from tensorflow.contrib.framework import add_model_variable
from contextlib import contextmanager from contextlib import contextmanager
from .common import get_tf_version_tuple from .common import get_tf_version_tuple
...@@ -13,6 +12,13 @@ __all__ = ['custom_getter_scope', 'freeze_variables', 'remap_variables'] ...@@ -13,6 +12,13 @@ __all__ = ['custom_getter_scope', 'freeze_variables', 'remap_variables']
@contextmanager @contextmanager
def custom_getter_scope(custom_getter): def custom_getter_scope(custom_getter):
"""
Args:
custom_getter: the same as in :func:`tf.get_variable`
Returns:
The current variable scope with a custom_getter.
"""
scope = tf.get_variable_scope() scope = tf.get_variable_scope()
if get_tf_version_tuple() >= (1, 5): if get_tf_version_tuple() >= (1, 5):
with tf.variable_scope( with tf.variable_scope(
...@@ -35,7 +41,8 @@ def remap_variables(fn): ...@@ -35,7 +41,8 @@ def remap_variables(fn):
fn (tf.Variable -> tf.Tensor) fn (tf.Variable -> tf.Tensor)
Returns: Returns:
a context where all the variables will be mapped by fn. The current variable scope with a custom_getter that maps
all the variables by fn.
Example: Example:
.. code-block:: python .. code-block:: python
...@@ -83,7 +90,7 @@ def freeze_variables(stop_gradient=True, skip_collection=False): ...@@ -83,7 +90,7 @@ def freeze_variables(stop_gradient=True, skip_collection=False):
kwargs['trainable'] = False kwargs['trainable'] = False
v = getter(*args, **kwargs) v = getter(*args, **kwargs)
if skip_collection: if skip_collection:
add_model_variable(v) tf.add_to_collection(tf.GraphKeys.MODEL_VARIABLES, v)
if trainable and stop_gradient: if trainable and stop_gradient:
v = tf.stop_gradient(v, name='freezed_' + name) v = tf.stop_gradient(v, name='freezed_' + name)
return v return v
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment