Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
756dbc70
Commit
756dbc70
authored
Jun 28, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
update var name for batch_norm
parent
117fb29f
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
16 additions
and
14 deletions
+16
-14
examples/Atari2600/DQN.py
examples/Atari2600/DQN.py
+1
-2
tensorpack/models/batch_norm.py
tensorpack/models/batch_norm.py
+8
-8
tensorpack/tfutils/summary.py
tensorpack/tfutils/summary.py
+7
-4
No files found.
examples/Atari2600/DQN.py
View file @
756dbc70
...
@@ -71,7 +71,7 @@ class Model(ModelDesc):
...
@@ -71,7 +71,7 @@ class Model(ModelDesc):
""" image: [0,255]"""
""" image: [0,255]"""
image
=
image
/
255.0
image
=
image
/
255.0
with
argscope
(
Conv2D
,
nl
=
PReLU
.
f
,
use_bias
=
True
):
with
argscope
(
Conv2D
,
nl
=
PReLU
.
f
,
use_bias
=
True
):
l
=
(
LinearWrap
(
image
)
return
(
LinearWrap
(
image
)
.
Conv2D
(
'conv0'
,
out_channel
=
32
,
kernel_shape
=
5
)
.
Conv2D
(
'conv0'
,
out_channel
=
32
,
kernel_shape
=
5
)
.
MaxPooling
(
'pool0'
,
2
)
.
MaxPooling
(
'pool0'
,
2
)
.
Conv2D
(
'conv1'
,
out_channel
=
32
,
kernel_shape
=
5
)
.
Conv2D
(
'conv1'
,
out_channel
=
32
,
kernel_shape
=
5
)
...
@@ -87,7 +87,6 @@ class Model(ModelDesc):
...
@@ -87,7 +87,6 @@ class Model(ModelDesc):
.
FullyConnected
(
'fc0'
,
512
,
nl
=
lambda
x
,
name
:
LeakyReLU
.
f
(
x
,
0.01
,
name
))
.
FullyConnected
(
'fc0'
,
512
,
nl
=
lambda
x
,
name
:
LeakyReLU
.
f
(
x
,
0.01
,
name
))
.
FullyConnected
(
'fct'
,
NUM_ACTIONS
,
nl
=
tf
.
identity
)())
.
FullyConnected
(
'fct'
,
NUM_ACTIONS
,
nl
=
tf
.
identity
)())
return
l
def
_build_graph
(
self
,
inputs
,
is_training
):
def
_build_graph
(
self
,
inputs
,
is_training
):
state
,
action
,
reward
,
next_state
,
isOver
=
inputs
state
,
action
,
reward
,
next_state
,
isOver
=
inputs
...
...
tensorpack/models/batch_norm.py
View file @
756dbc70
...
@@ -55,17 +55,17 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
...
@@ -55,17 +55,17 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
emaname
=
'EMA'
emaname
=
'EMA'
in_main_tower
=
not
batch_mean
.
name
.
startswith
(
'towerp'
)
in_main_tower
=
not
batch_mean
.
name
.
startswith
(
'towerp'
)
if
in_main_tower
:
if
in_main_tower
:
ema
=
tf
.
train
.
ExponentialMovingAverage
(
decay
=
decay
,
name
=
emaname
)
with
tf
.
name_scope
(
None
):
# https://github.com/tensorflow/tensorflow/issues/2740
ema_apply_op
=
ema
.
apply
([
batch_mean
,
batch_var
])
ema
=
tf
.
train
.
ExponentialMovingAverage
(
decay
=
decay
,
name
=
emaname
)
ema_mean
,
ema_var
=
ema
.
average
(
batch_mean
),
ema
.
average
(
batch_var
)
ema_apply_op
=
ema
.
apply
([
batch_mean
,
batch_var
])
ema_mean
,
ema_var
=
ema
.
average
(
batch_mean
),
ema
.
average
(
batch_var
)
else
:
else
:
# use training-statistics in prediction
# use training-statistics in prediction
assert
not
use_local_stat
assert
not
use_local_stat
# XXX have to do this again to get actual name. see issue:
with
tf
.
name_scope
(
None
):
# https://github.com/tensorflow/tensorflow/issues/2740
# https://github.com/tensorflow/tensorflow/issues/2740
ema
=
tf
.
train
.
ExponentialMovingAverage
(
decay
=
decay
,
name
=
emaname
)
ema
=
tf
.
train
.
ExponentialMovingAverage
(
decay
=
decay
,
name
=
emaname
)
ema_apply_op
=
ema
.
apply
([
batch_mean
,
batch_var
])
ema_apply_op
=
ema
.
apply
([
batch_mean
,
batch_var
])
ema_mean
,
ema_var
=
ema
.
average
(
batch_mean
),
ema
.
average
(
batch_var
)
ema_mean
,
ema_var
=
ema
.
average
(
batch_mean
),
ema
.
average
(
batch_var
)
G
=
tf
.
get_default_graph
()
G
=
tf
.
get_default_graph
()
# find training statistics in training tower
# find training statistics in training tower
...
...
tensorpack/tfutils/summary.py
View file @
756dbc70
...
@@ -10,7 +10,7 @@ from ..utils import *
...
@@ -10,7 +10,7 @@ from ..utils import *
from
.
import
get_global_step_var
from
.
import
get_global_step_var
__all__
=
[
'create_summary'
,
'add_param_summary'
,
'add_activation_summary'
,
__all__
=
[
'create_summary'
,
'add_param_summary'
,
'add_activation_summary'
,
'summary_moving_average'
]
'
add_moving_summary'
,
'
summary_moving_average'
]
def
create_summary
(
name
,
v
):
def
create_summary
(
name
,
v
):
"""
"""
...
@@ -42,8 +42,8 @@ def add_param_summary(summary_lists):
...
@@ -42,8 +42,8 @@ def add_param_summary(summary_lists):
"""
"""
Add summary for all trainable variables matching the regex
Add summary for all trainable variables matching the regex
:param summary_lists: list of (regex, [list of
action
to perform]).
:param summary_lists: list of (regex, [list of
summary type
to perform]).
Action
can be 'mean', 'scalar', 'histogram', 'sparsity', 'rms'
Type
can be 'mean', 'scalar', 'histogram', 'sparsity', 'rms'
"""
"""
def
perform
(
var
,
action
):
def
perform
(
var
,
action
):
ndim
=
var
.
get_shape
()
.
ndims
ndim
=
var
.
get_shape
()
.
ndims
...
@@ -66,7 +66,7 @@ def add_param_summary(summary_lists):
...
@@ -66,7 +66,7 @@ def add_param_summary(summary_lists):
tf
.
scalar_summary
(
name
+
'/rms'
,
tf
.
scalar_summary
(
name
+
'/rms'
,
tf
.
sqrt
(
tf
.
reduce_mean
(
tf
.
square
(
var
))))
tf
.
sqrt
(
tf
.
reduce_mean
(
tf
.
square
(
var
))))
return
return
raise
RuntimeError
(
"Unknown
action
{}"
.
format
(
action
))
raise
RuntimeError
(
"Unknown
summary type:
{}"
.
format
(
action
))
import
re
import
re
params
=
tf
.
get_collection
(
tf
.
GraphKeys
.
TRAINABLE_VARIABLES
)
params
=
tf
.
get_collection
(
tf
.
GraphKeys
.
TRAINABLE_VARIABLES
)
...
@@ -79,6 +79,9 @@ def add_param_summary(summary_lists):
...
@@ -79,6 +79,9 @@ def add_param_summary(summary_lists):
for
act
in
actions
:
for
act
in
actions
:
perform
(
p
,
act
)
perform
(
p
,
act
)
def
add_moving_summary
(
v
):
tf
.
add_to_collection
(
MOVING_SUMMARY_VARS_KEY
,
v
)
def
summary_moving_average
():
def
summary_moving_average
():
""" Create a MovingAverage op and summary for all variables in
""" Create a MovingAverage op and summary for all variables in
MOVING_SUMMARY_VARS_KEY.
MOVING_SUMMARY_VARS_KEY.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment