Shashank Suhas / seminar-breakout / Commits

Commit 236c78e0
authored Jul 05, 2016 by Yuxin Wu

better save/restore with towerp & batch_norm
parent 178f3611

Showing 5 changed files with 22 additions and 11 deletions (+22 / -11):

examples/DisturbLabel/README.md           +2  -1
examples/DoReFa-Net/svhn-digit-dorefa.py  +4  -0
opt-requirements.txt                      +1  -1
tensorpack/models/batch_norm.py           +6  -6
tensorpack/tfutils/sessinit.py            +9  -3

examples/DisturbLabel/README.md

 ## DisturbLabel
-I ran into the paper [DisturbLabel: Regularizing CNN on the Loss Layer](https://arxiv.org/abs/1605.00055) on CVPR16.
+I ran into the paper [DisturbLabel: Regularizing CNN on the Loss Layer](https://arxiv.org/abs/1605.00055) on CVPR16,
+which basically said that noisy data gives you better performance.
 As many, I didn't believe the method and the results at first.
 This is a simple mnist training script with DisturbLabel. It uses the architecture in the paper and
 ...
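
The script this README describes injects label noise during training. A minimal NumPy sketch of the idea; the function name and the disturb probability `alpha` are illustrative and not taken from this repository:

```python
import numpy as np

def disturb_labels(labels, num_classes, alpha, rng=np.random):
    # With probability alpha, replace a label with a class drawn uniformly
    # at random; otherwise keep the ground-truth label.
    labels = np.asarray(labels)
    mask = rng.rand(len(labels)) < alpha
    noisy = labels.copy()
    noisy[mask] = rng.randint(0, num_classes, size=mask.sum())
    return noisy

y = np.array([3, 1, 4, 1, 5, 9, 2, 6])
print(disturb_labels(y, num_classes=10, alpha=0.2))
```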

examples/DoReFa-Net/svhn-digit-dorefa.py

@@ -43,6 +43,10 @@ def get_dorefa(bitW, bitA, bitG):
     def fw(x):
         if bitW == 32:
             return x
+        if bitW == 1:   # BWN
+            with G.gradient_override_map({"Sign": "Identity"}):
+                E = tf.stop_gradient(tf.reduce_mean(tf.abs(x)))
+                return tf.sign(x / E) * E
         x = tf.tanh(x)
         x = x / tf.reduce_max(tf.abs(x)) * 0.5 + 0.5
         return 2 * quantize(x, bitW) - 1
 ...
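
The four added lines are the Binarized-Weight-Network (BWN) branch: weights are binarized to ±E, where E is their mean absolute value, and mapping the gradient of `Sign` to `Identity` lets gradients pass straight through the otherwise zero-gradient binarization. A standalone sketch of that trick against the TF1-style graph API the script uses; the toy variable and loss below are only for illustration:

```python
import tensorflow.compat.v1 as tf  # assumption: TF1 graph mode, as in the original script
tf.disable_v2_behavior()

G = tf.get_default_graph()

def bwn(x):
    # Scale E = mean |x|; stop_gradient keeps the scale out of backprop.
    E = tf.stop_gradient(tf.reduce_mean(tf.abs(x)))
    # tf.sign has zero gradient everywhere; overriding it with Identity
    # lets gradients flow straight through the binarization.
    with G.gradient_override_map({"Sign": "Identity"}):
        return tf.sign(x / E) * E

w = tf.Variable([[0.3, -1.2], [0.05, 0.7]], dtype=tf.float32)
wb = bwn(w)                              # entries are +E or -E
loss = tf.reduce_sum(tf.square(wb))
grad = tf.gradients(loss, w)[0]          # nonzero only because of the override

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run([wb, grad]))
```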

opt-requirements.txt

@@ -3,4 +3,4 @@ scipy
 nltk
 h5py
 pyzmq
-tornado
+tornado; python_version < '3.0'
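
The only change here is a PEP 508 environment marker on the tornado requirement, which tells pip to install it only on Python 2. A quick way to see what such a marker evaluates to, assuming the third-party `packaging` library is available:

```python
import sys
from packaging.markers import Marker  # assumption: 'packaging' is installed

marker = Marker("python_version < '3.0'")
# True on a Python 2 interpreter, False on Python 3; pip skips the
# requirement entirely when its marker evaluates to False.
print(sys.version_info[:2], marker.evaluate())
```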

tensorpack/models/batch_norm.py

@@ -64,20 +64,20 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
         assert not use_local_stat
         with tf.name_scope(None):
             ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
-            ema_apply_op = ema.apply([batch_mean, batch_var])
-            ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)
+            mean_var_name = ema.average_name(batch_mean) + ':0'
+            var_var_name = ema.average_name(batch_var) + ':0'
         G = tf.get_default_graph()
         # find training statistics in training tower
         try:
-            mean_name = re.sub('towerp[0-9]+/', '', ema_mean.name)
-            var_name = re.sub('towerp[0-9]+/', '', ema_var.name)
+            mean_name = re.sub('towerp[0-9]+/', '', mean_var_name)
+            var_name = re.sub('towerp[0-9]+/', '', var_var_name)
             #var_name = batch_var.op.name[prefixlen:] + '/' + emaname + ':0'
             ema_mean = G.get_tensor_by_name(mean_name)
             ema_var = G.get_tensor_by_name(var_name)
         except KeyError:
-            mean_name = re.sub('towerp[0-9]+/', 'tower0/', ema_mean.name)
-            var_name = re.sub('towerp[0-9]+/', 'tower0/', ema_var.name)
+            mean_name = re.sub('towerp[0-9]+/', 'tower0/', mean_var_name)
+            var_name = re.sub('towerp[0-9]+/', 'tower0/', var_var_name)
             ema_mean = G.get_tensor_by_name(mean_name)
             ema_var = G.get_tensor_by_name(var_name)
         #logger.info("In prediction, using {} instead of {} for {}".format(
 ...
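
The batch_norm change replaces the `ema.average()` handles with a lookup of the moving-average tensors by name, rewriting the prediction-tower prefix ('towerp*') so the prediction graph reuses the statistics recorded by the training graph. A minimal sketch of just the name rewriting; the tensor names below are hypothetical:

```python
import re

# Hypothetical EMA tensor name created inside a prediction tower.
name_in_pred_tower = 'towerp0/conv1/bn/mean/EMA:0'

# First attempt: the statistics may live outside any tower prefix.
print(re.sub('towerp[0-9]+/', '', name_in_pred_tower))         # conv1/bn/mean/EMA:0

# Fallback (the KeyError branch): look them up under the first training tower.
print(re.sub('towerp[0-9]+/', 'tower0/', name_in_pred_tower))  # tower0/conv1/bn/mean/EMA:0
```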

tensorpack/tfutils/sessinit.py

@@ -94,7 +94,11 @@ class SaverRestore(SessionInit):
     @staticmethod
     def _read_checkpoint_vars(model_path):
         reader = tf.train.NewCheckpointReader(model_path)
-        return set(reader.get_variable_to_shape_map().keys())
+        ckpt_vars = reader.get_variable_to_shape_map().keys()
+        for v in ckpt_vars:
+            if v.startswith('towerp'):
+                logger.warn("Found {} in checkpoint. Anything from prediction tower shouldn't be saved.".format(v.name))
+        return set(ckpt_vars)

     @staticmethod
     def _get_vars_to_restore_multimap(vars_available):
 ...
@@ -102,13 +106,14 @@ class SaverRestore(SessionInit):
         Get a dict of {var_name: [var, var]} to restore
         :param vars_available: varaibles available in the checkpoint, for existence checking
         """
-        # TODO warn if some variable in checkpoint is not used
         vars_to_restore = tf.all_variables()
         var_dict = defaultdict(list)
         for v in vars_to_restore:
             name = v.op.name
             if 'towerp' in name:
-                logger.warn("Anything from prediction tower shouldn't be saved.")
+                logger.warn("Variable {} in prediction tower shouldn't exist.".format(v.name))
+                # don't overwrite anything in the current prediction graph
+                continue
             if 'tower' in name:
                 new_name = re.sub('tower[p0-9]+/', '', name)
                 name = new_name
 ...
@@ -117,6 +122,7 @@ class SaverRestore(SessionInit):
                 vars_available.remove(name)
             else:
                 logger.warn("Param {} not found in checkpoint! Will not restore.".format(v.op.name))
+        # TODO warn if some variable in checkpoint is not used
         #for name in vars_available:
             #logger.warn("Param {} in checkpoint doesn't appear in the graph!".format(name))
         return var_dict
 ...
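
The sessinit changes make the restore logic ignore prediction-tower ('towerp') variables on both sides: they are reported if they somehow ended up in a checkpoint, and they are skipped when building the restore map. A simplified sketch of the name-mapping part of `_get_vars_to_restore_multimap`, using plain strings instead of TensorFlow variables; the example names are hypothetical:

```python
import re
from collections import defaultdict

def vars_to_restore_multimap(graph_var_names, vars_available):
    """Group graph variables by the checkpoint name they should be restored from."""
    var_dict = defaultdict(list)
    for name in graph_var_names:
        if 'towerp' in name:
            # prediction-tower variables are never saved; don't try to restore them
            continue
        # replicas inside training towers map back to the shared, prefix-free name
        ckpt_name = re.sub('tower[p0-9]+/', '', name) if 'tower' in name else name
        if ckpt_name in vars_available:
            var_dict[ckpt_name].append(name)
    return dict(var_dict)

graph_vars = ['tower0/conv1/W', 'tower1/conv1/W', 'towerp0/conv1/W', 'fc/W']
ckpt_vars = {'conv1/W', 'fc/W'}
print(vars_to_restore_multimap(graph_vars, ckpt_vars))
# {'conv1/W': ['tower0/conv1/W', 'tower1/conv1/W'], 'fc/W': ['fc/W']}
```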