Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
f227f45f
Commit
f227f45f
authored
Aug 15, 2018
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Class names for cifar/fashion mnist (#863)
parent
ea173d09
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
64 additions
and
28 deletions
+64
-28
examples/FasterRCNN/basemodel.py
examples/FasterRCNN/basemodel.py
+1
-2
tensorpack/dataflow/dataset/cifar.py
tensorpack/dataflow/dataset/cifar.py
+29
-12
tensorpack/dataflow/dataset/mnist.py
tensorpack/dataflow/dataset/mnist.py
+20
-7
tensorpack/dataflow/dataset/svhn.py
tensorpack/dataflow/dataset/svhn.py
+2
-1
tensorpack/models/batch_norm.py
tensorpack/models/batch_norm.py
+2
-3
tensorpack/tfutils/varreplace.py
tensorpack/tfutils/varreplace.py
+10
-3
No files found.
examples/FasterRCNN/basemodel.py
View file @
f227f45f
...
@@ -4,7 +4,6 @@
...
@@ -4,7 +4,6 @@
from
contextlib
import
contextmanager
,
ExitStack
from
contextlib
import
contextmanager
,
ExitStack
import
numpy
as
np
import
numpy
as
np
import
tensorflow
as
tf
import
tensorflow
as
tf
from
tensorflow.contrib.framework
import
add_model_variable
from
tensorpack.tfutils
import
argscope
from
tensorpack.tfutils
import
argscope
from
tensorpack.tfutils.scope_utils
import
auto_reuse_variable_scope
from
tensorpack.tfutils.scope_utils
import
auto_reuse_variable_scope
...
@@ -49,7 +48,7 @@ def freeze_affine_getter(getter, *args, **kwargs):
...
@@ -49,7 +48,7 @@ def freeze_affine_getter(getter, *args, **kwargs):
if
name
.
endswith
(
'/gamma'
)
or
name
.
endswith
(
'/beta'
):
if
name
.
endswith
(
'/gamma'
)
or
name
.
endswith
(
'/beta'
):
kwargs
[
'trainable'
]
=
False
kwargs
[
'trainable'
]
=
False
ret
=
getter
(
*
args
,
**
kwargs
)
ret
=
getter
(
*
args
,
**
kwargs
)
add_model_variable
(
ret
)
tf
.
add_to_collection
(
tf
.
GraphKeys
.
MODEL_VARIABLES
,
ret
)
else
:
else
:
ret
=
getter
(
*
args
,
**
kwargs
)
ret
=
getter
(
*
args
,
**
kwargs
)
return
ret
return
ret
...
...
tensorpack/dataflow/dataset/cifar.py
View file @
f227f45f
...
@@ -66,14 +66,22 @@ def read_cifar(filenames, cifar_classnum):
...
@@ -66,14 +66,22 @@ def read_cifar(filenames, cifar_classnum):
def
get_filenames
(
dir
,
cifar_classnum
):
def
get_filenames
(
dir
,
cifar_classnum
):
assert
cifar_classnum
==
10
or
cifar_classnum
==
100
assert
cifar_classnum
==
10
or
cifar_classnum
==
100
if
cifar_classnum
==
10
:
if
cifar_classnum
==
10
:
filenam
es
=
[
os
.
path
.
join
(
train_fil
es
=
[
os
.
path
.
join
(
dir
,
'cifar-10-batches-py'
,
'data_batch_
%
d'
%
i
)
for
i
in
range
(
1
,
6
)]
dir
,
'cifar-10-batches-py'
,
'data_batch_
%
d'
%
i
)
for
i
in
range
(
1
,
6
)]
filenames
.
append
(
os
.
path
.
join
(
test_files
=
[
os
.
path
.
join
(
dir
,
'cifar-10-batches-py'
,
'test_batch'
))
dir
,
'cifar-10-batches-py'
,
'test_batch'
)]
meta_file
=
os
.
path
.
join
(
dir
,
'cifar-10-batches-py'
,
'batches.meta'
)
elif
cifar_classnum
==
100
:
elif
cifar_classnum
==
100
:
filenames
=
[
os
.
path
.
join
(
dir
,
'cifar-100-python'
,
'train'
),
train_files
=
[
os
.
path
.
join
(
dir
,
'cifar-100-python'
,
'train'
)]
os
.
path
.
join
(
dir
,
'cifar-100-python'
,
'test'
)]
test_files
=
[
os
.
path
.
join
(
dir
,
'cifar-100-python'
,
'test'
)]
return
filenames
meta_file
=
os
.
path
.
join
(
dir
,
'cifar-100-python'
,
'meta'
)
return
train_files
,
test_files
,
meta_file
def
_parse_meta
(
filename
,
cifar_classnum
):
with
open
(
filename
,
'rb'
)
as
f
:
obj
=
pickle
.
load
(
f
)
return
obj
[
'label_names'
if
cifar_classnum
==
10
else
'fine_label_names'
]
class
CifarBase
(
RNGDataFlow
):
class
CifarBase
(
RNGDataFlow
):
...
@@ -84,14 +92,15 @@ class CifarBase(RNGDataFlow):
...
@@ -84,14 +92,15 @@ class CifarBase(RNGDataFlow):
if
dir
is
None
:
if
dir
is
None
:
dir
=
get_dataset_path
(
'cifar{}_data'
.
format
(
cifar_classnum
))
dir
=
get_dataset_path
(
'cifar{}_data'
.
format
(
cifar_classnum
))
maybe_download_and_extract
(
dir
,
self
.
cifar_classnum
)
maybe_download_and_extract
(
dir
,
self
.
cifar_classnum
)
fnames
=
get_filenames
(
dir
,
cifar_classnum
)
train_files
,
test_files
,
meta_file
=
get_filenames
(
dir
,
cifar_classnum
)
if
train_or_test
==
'train'
:
if
train_or_test
==
'train'
:
self
.
fs
=
fnames
[:
-
1
]
self
.
fs
=
train_files
else
:
else
:
self
.
fs
=
[
fnames
[
-
1
]]
self
.
fs
=
test_files
for
f
in
self
.
fs
:
for
f
in
self
.
fs
:
if
not
os
.
path
.
isfile
(
f
):
if
not
os
.
path
.
isfile
(
f
):
raise
ValueError
(
'Failed to find file: '
+
f
)
raise
ValueError
(
'Failed to find file: '
+
f
)
self
.
_label_names
=
_parse_meta
(
meta_file
,
cifar_classnum
)
self
.
train_or_test
=
train_or_test
self
.
train_or_test
=
train_or_test
self
.
data
=
read_cifar
(
self
.
fs
,
cifar_classnum
)
self
.
data
=
read_cifar
(
self
.
fs
,
cifar_classnum
)
self
.
dir
=
dir
self
.
dir
=
dir
...
@@ -110,14 +119,22 @@ class CifarBase(RNGDataFlow):
...
@@ -110,14 +119,22 @@ class CifarBase(RNGDataFlow):
def
get_per_pixel_mean
(
self
):
def
get_per_pixel_mean
(
self
):
"""
"""
return a mean image of all (train and test) images of size 32x32x3
Returns:
a mean image of all (train and test) images of size 32x32x3
"""
"""
fnames
=
get_filenames
(
self
.
dir
,
self
.
cifar_classnum
)
train_files
,
test_files
,
_
=
get_filenames
(
self
.
dir
,
self
.
cifar_classnum
)
all_imgs
=
[
x
[
0
]
for
x
in
read_cifar
(
fnam
es
,
self
.
cifar_classnum
)]
all_imgs
=
[
x
[
0
]
for
x
in
read_cifar
(
train_files
+
test_fil
es
,
self
.
cifar_classnum
)]
arr
=
np
.
array
(
all_imgs
,
dtype
=
'float32'
)
arr
=
np
.
array
(
all_imgs
,
dtype
=
'float32'
)
mean
=
np
.
mean
(
arr
,
axis
=
0
)
mean
=
np
.
mean
(
arr
,
axis
=
0
)
return
mean
return
mean
def
get_label_names
(
self
):
"""
Returns:
[str]: name of each class.
"""
return
self
.
_label_names
def
get_per_channel_mean
(
self
):
def
get_per_channel_mean
(
self
):
"""
"""
return three values as mean of each channel
return three values as mean of each channel
...
...
tensorpack/dataflow/dataset/mnist.py
View file @
f227f45f
...
@@ -67,8 +67,8 @@ class Mnist(RNGDataFlow):
...
@@ -67,8 +67,8 @@ class Mnist(RNGDataFlow):
image is 28x28 in the range [0,1], label is an int.
image is 28x28 in the range [0,1], label is an int.
"""
"""
DIR_NAME
=
'mnist_data'
_
DIR_NAME
=
'mnist_data'
SOURCE_URL
=
'http://yann.lecun.com/exdb/mnist/'
_
SOURCE_URL
=
'http://yann.lecun.com/exdb/mnist/'
def
__init__
(
self
,
train_or_test
,
shuffle
=
True
,
dir
=
None
):
def
__init__
(
self
,
train_or_test
,
shuffle
=
True
,
dir
=
None
):
"""
"""
...
@@ -77,15 +77,15 @@ class Mnist(RNGDataFlow):
...
@@ -77,15 +77,15 @@ class Mnist(RNGDataFlow):
shuffle (bool): shuffle the dataset
shuffle (bool): shuffle the dataset
"""
"""
if
dir
is
None
:
if
dir
is
None
:
dir
=
get_dataset_path
(
self
.
DIR_NAME
)
dir
=
get_dataset_path
(
self
.
_
DIR_NAME
)
assert
train_or_test
in
[
'train'
,
'test'
]
assert
train_or_test
in
[
'train'
,
'test'
]
self
.
train_or_test
=
train_or_test
self
.
train_or_test
=
train_or_test
self
.
shuffle
=
shuffle
self
.
shuffle
=
shuffle
def
get_images_and_labels
(
image_file
,
label_file
):
def
get_images_and_labels
(
image_file
,
label_file
):
f
=
maybe_download
(
self
.
SOURCE_URL
+
image_file
,
dir
)
f
=
maybe_download
(
self
.
_
SOURCE_URL
+
image_file
,
dir
)
images
=
extract_images
(
f
)
images
=
extract_images
(
f
)
f
=
maybe_download
(
self
.
SOURCE_URL
+
label_file
,
dir
)
f
=
maybe_download
(
self
.
_
SOURCE_URL
+
label_file
,
dir
)
labels
=
extract_labels
(
f
)
labels
=
extract_labels
(
f
)
assert
images
.
shape
[
0
]
==
labels
.
shape
[
0
]
assert
images
.
shape
[
0
]
==
labels
.
shape
[
0
]
return
images
,
labels
return
images
,
labels
...
@@ -113,8 +113,21 @@ class Mnist(RNGDataFlow):
...
@@ -113,8 +113,21 @@ class Mnist(RNGDataFlow):
class
FashionMnist
(
Mnist
):
class
FashionMnist
(
Mnist
):
DIR_NAME
=
'fashion_mnist_data'
"""
SOURCE_URL
=
'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/'
Same API as :class:`Mnist`, but more fashion.
"""
_DIR_NAME
=
'fashion_mnist_data'
_SOURCE_URL
=
'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/'
def
get_label_names
(
self
):
"""
Returns:
[str]: the name of each class
"""
# copied from https://github.com/zalandoresearch/fashion-mnist
return
[
'T-shirt/top'
,
'Trouser'
,
'Pullover'
,
'Dress'
,
'Coat'
,
'Sandal'
,
'Shirt'
,
'Sneaker'
,
'Bag'
,
'Ankle boot'
]
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
tensorpack/dataflow/dataset/svhn.py
View file @
f227f45f
...
@@ -64,7 +64,8 @@ class SVHNDigit(RNGDataFlow):
...
@@ -64,7 +64,8 @@ class SVHNDigit(RNGDataFlow):
@
staticmethod
@
staticmethod
def
get_per_pixel_mean
():
def
get_per_pixel_mean
():
"""
"""
return 32x32x3 image
Returns:
a 32x32x3 image
"""
"""
a
=
SVHNDigit
(
'train'
)
a
=
SVHNDigit
(
'train'
)
b
=
SVHNDigit
(
'test'
)
b
=
SVHNDigit
(
'test'
)
...
...
tensorpack/models/batch_norm.py
View file @
f227f45f
...
@@ -3,7 +3,6 @@
...
@@ -3,7 +3,6 @@
import
tensorflow
as
tf
import
tensorflow
as
tf
from
tensorflow.contrib.framework
import
add_model_variable
from
tensorflow.python.training
import
moving_averages
from
tensorflow.python.training
import
moving_averages
import
re
import
re
import
six
import
six
...
@@ -191,7 +190,7 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
...
@@ -191,7 +190,7 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
if
ctx
.
is_main_training_tower
:
if
ctx
.
is_main_training_tower
:
for
v
in
layer
.
non_trainable_variables
:
for
v
in
layer
.
non_trainable_variables
:
if
isinstance
(
v
,
tf
.
Variable
):
if
isinstance
(
v
,
tf
.
Variable
):
add_model_variable
(
v
)
tf
.
add_to_collection
(
tf
.
GraphKeys
.
MODEL_VARIABLES
,
v
)
if
not
ctx
.
is_main_training_tower
or
internal_update
:
if
not
ctx
.
is_main_training_tower
or
internal_update
:
restore_collection
(
coll_bk
)
restore_collection
(
coll_bk
)
...
@@ -354,7 +353,7 @@ def BatchRenorm(x, rmax, dmax, momentum=0.9, epsilon=1e-5,
...
@@ -354,7 +353,7 @@ def BatchRenorm(x, rmax, dmax, momentum=0.9, epsilon=1e-5,
if
ctx
.
is_main_training_tower
:
if
ctx
.
is_main_training_tower
:
for
v
in
layer
.
non_trainable_variables
:
for
v
in
layer
.
non_trainable_variables
:
if
isinstance
(
v
,
tf
.
Variable
):
if
isinstance
(
v
,
tf
.
Variable
):
add_model_variable
(
v
)
tf
.
add_to_collection
(
tf
.
GraphKeys
.
MODEL_VARIABLES
,
v
)
else
:
else
:
# only run UPDATE_OPS in the first tower
# only run UPDATE_OPS in the first tower
restore_collection
(
coll_bk
)
restore_collection
(
coll_bk
)
...
...
tensorpack/tfutils/varreplace.py
View file @
f227f45f
...
@@ -3,7 +3,6 @@
...
@@ -3,7 +3,6 @@
# Credit: Qinyao He
# Credit: Qinyao He
import
tensorflow
as
tf
import
tensorflow
as
tf
from
tensorflow.contrib.framework
import
add_model_variable
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
.common
import
get_tf_version_tuple
from
.common
import
get_tf_version_tuple
...
@@ -13,6 +12,13 @@ __all__ = ['custom_getter_scope', 'freeze_variables', 'remap_variables']
...
@@ -13,6 +12,13 @@ __all__ = ['custom_getter_scope', 'freeze_variables', 'remap_variables']
@
contextmanager
@
contextmanager
def
custom_getter_scope
(
custom_getter
):
def
custom_getter_scope
(
custom_getter
):
"""
Args:
custom_getter: the same as in :func:`tf.get_variable`
Returns:
The current variable scope with a custom_getter.
"""
scope
=
tf
.
get_variable_scope
()
scope
=
tf
.
get_variable_scope
()
if
get_tf_version_tuple
()
>=
(
1
,
5
):
if
get_tf_version_tuple
()
>=
(
1
,
5
):
with
tf
.
variable_scope
(
with
tf
.
variable_scope
(
...
@@ -35,7 +41,8 @@ def remap_variables(fn):
...
@@ -35,7 +41,8 @@ def remap_variables(fn):
fn (tf.Variable -> tf.Tensor)
fn (tf.Variable -> tf.Tensor)
Returns:
Returns:
a context where all the variables will be mapped by fn.
The current variable scope with a custom_getter that maps
all the variables by fn.
Example:
Example:
.. code-block:: python
.. code-block:: python
...
@@ -83,7 +90,7 @@ def freeze_variables(stop_gradient=True, skip_collection=False):
...
@@ -83,7 +90,7 @@ def freeze_variables(stop_gradient=True, skip_collection=False):
kwargs
[
'trainable'
]
=
False
kwargs
[
'trainable'
]
=
False
v
=
getter
(
*
args
,
**
kwargs
)
v
=
getter
(
*
args
,
**
kwargs
)
if
skip_collection
:
if
skip_collection
:
add_model_variable
(
v
)
tf
.
add_to_collection
(
tf
.
GraphKeys
.
MODEL_VARIABLES
,
v
)
if
trainable
and
stop_gradient
:
if
trainable
and
stop_gradient
:
v
=
tf
.
stop_gradient
(
v
,
name
=
'freezed_'
+
name
)
v
=
tf
.
stop_gradient
(
v
,
name
=
'freezed_'
+
name
)
return
v
return
v
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment