Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
c759f211
Commit
c759f211
authored
Mar 21, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
use official batchnorm op
parent
df89a95f
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
13 additions
and
13 deletions
+13
-13
examples/cifar10_convnet.py
examples/cifar10_convnet.py
+4
-4
tensorpack/models/batch_norm.py
tensorpack/models/batch_norm.py
+9
-9
No files found.
examples/cifar10_convnet.py
View file @
c759f211
...
...
@@ -19,8 +19,7 @@ from tensorpack.dataflow import *
from
tensorpack.dataflow
import
imgaug
"""
CIFAR10 90
%
validation accuracy after 100k step.
91
%
after 160k step
CIFAR10 90
%
validation accuracy after 70k step.
"""
BATCH_SIZE
=
128
...
...
@@ -43,6 +42,7 @@ class Model(ModelDesc):
num_threads
=
6
,
enqueue_many
=
True
)
tf
.
image_summary
(
"train_image"
,
image
,
10
)
image
=
image
/
4.0
# just to make range smaller
l
=
Conv2D
(
'conv1.1'
,
image
,
out_channel
=
64
,
kernel_shape
=
3
)
l
=
Conv2D
(
'conv1.2'
,
l
,
out_channel
=
64
,
kernel_shape
=
3
,
nl
=
tf
.
identity
)
l
=
BatchNorm
(
'bn1'
,
l
,
is_training
)
...
...
@@ -112,7 +112,7 @@ def get_data(train_or_test):
ds
=
AugmentImageComponent
(
ds
,
augmentors
)
ds
=
BatchData
(
ds
,
128
,
remainder
=
not
isTrain
)
if
isTrain
:
ds
=
PrefetchData
(
ds
,
3
,
2
)
ds
=
PrefetchData
(
ds
,
10
,
5
)
return
ds
...
...
@@ -120,7 +120,7 @@ def get_data(train_or_test):
def
get_config
():
# prepare dataset
dataset_train
=
get_data
(
'train'
)
step_per_epoch
=
dataset_train
.
size
()
/
2
step_per_epoch
=
dataset_train
.
size
()
dataset_test
=
get_data
(
'test'
)
sess_config
=
get_default_sess_config
()
...
...
tensorpack/models/batch_norm.py
View file @
c759f211
...
...
@@ -31,10 +31,7 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
"""
shape
=
x
.
get_shape
()
.
as_list
()
if
len
(
shape
)
==
2
:
x
=
tf
.
reshape
(
x
,
[
-
1
,
1
,
1
,
shape
[
1
]])
shape
=
x
.
get_shape
()
.
as_list
()
assert
len
(
shape
)
==
4
assert
len
(
shape
)
in
[
2
,
4
]
n_out
=
shape
[
-
1
]
# channel
beta
=
tf
.
get_variable
(
'beta'
,
[
n_out
])
...
...
@@ -42,7 +39,10 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
'gamma'
,
[
n_out
],
initializer
=
tf
.
constant_initializer
(
1.0
))
batch_mean
,
batch_var
=
tf
.
nn
.
moments
(
x
,
[
0
,
1
,
2
],
name
=
'moments'
)
if
len
(
shape
)
==
2
:
batch_mean
,
batch_var
=
tf
.
nn
.
moments
(
x
,
[
0
],
name
=
'moments'
,
keep_dims
=
False
)
else
:
batch_mean
,
batch_var
=
tf
.
nn
.
moments
(
x
,
[
0
,
1
,
2
],
name
=
'moments'
,
keep_dims
=
False
)
ema
=
tf
.
train
.
ExponentialMovingAverage
(
decay
=
decay
)
ema_apply_op
=
ema
.
apply
([
batch_mean
,
batch_var
])
...
...
@@ -50,10 +50,10 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
if
use_local_stat
:
with
tf
.
control_dependencies
([
ema_apply_op
]):
return
tf
.
nn
.
batch_norm
_with_global_norm
alization
(
x
,
batch_mean
,
batch_var
,
beta
,
gamma
,
epsilon
,
True
)
return
tf
.
nn
.
batch_normalization
(
x
,
batch_mean
,
batch_var
,
beta
,
gamma
,
epsilon
,
'bn'
)
else
:
batch
=
tf
.
cast
(
tf
.
shape
(
x
)[
0
],
tf
.
float32
)
mean
,
var
=
ema_mean
,
ema_var
*
batch
/
(
batch
-
1
)
# unbiased variance estimator
return
tf
.
nn
.
batch_norm
_with_global_norm
alization
(
x
,
mean
,
var
,
beta
,
gamma
,
epsilon
,
True
)
return
tf
.
nn
.
batch_normalization
(
x
,
mean
,
var
,
beta
,
gamma
,
epsilon
,
'bn'
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment