Shashank Suhas / seminar-breakout / Commits / 0937a01f
Commit 0937a01f, authored Jun 04, 2018 by Yuxin Wu

    small fix on keras examples

Parent: 0430c07c

Showing 3 changed files with 8 additions and 2 deletions:

- examples/DeepQNetwork/DQNModel.py (+1, −2)
- examples/keras/imagenet-resnet-keras.py (+1, −0)
- tensorpack/contrib/keras.py (+6, −0)
examples/DeepQNetwork/DQNModel.py

```diff
@@ -83,8 +83,7 @@ class Model(ModelDesc):
     def optimizer(self):
         lr = tf.get_variable('learning_rate', initializer=self.learning_rate, trainable=False)
         opt = tf.train.AdamOptimizer(lr, epsilon=1e-3)
-        return optimizer.apply_grad_processors(
-            opt, [gradproc.GlobalNormClip(10), gradproc.SummaryGradient()])
+        return optimizer.apply_grad_processors(opt, [gradproc.SummaryGradient()])

     @staticmethod
     def update_target_param():
```
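The removed `gradproc.GlobalNormClip(10)` processor rescales all gradients whenever their combined L2 norm exceeds 10 (tensorpack's version wraps `tf.clip_by_global_norm`). As a minimal pure-Python sketch of the idea, treating each gradient as a flat list of floats (the function name and data layout here are illustrative, not tensorpack's API):

```python
import math

def global_norm_clip(grads, clip_norm):
    # Global L2 norm over every value of every gradient tensor.
    norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    if norm <= clip_norm:
        return grads
    # Scale all gradients uniformly so the global norm equals clip_norm,
    # preserving the direction of the overall update.
    scale = clip_norm / norm
    return [[g * scale for g in grad] for grad in grads]
```

A gradient with global norm 5 passes through unchanged under `clip_norm=10`, but is halved under `clip_norm=2.5`.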
examples/keras/imagenet-resnet-keras.py

```diff
@@ -4,6 +4,7 @@
 # Author: Yuxin Wu
 import numpy as np
+import os
 import tensorflow as tf
 import argparse
```
tensorpack/contrib/keras.py

```diff
@@ -56,6 +56,7 @@ class KerasModelCaller(object):
         old_trainable_names = set([x.name for x in tf.trainable_variables()])
         trainable_backup = backup_collection([tf.GraphKeys.TRAINABLE_VARIABLES])
+        update_ops_backup = backup_collection([tf.GraphKeys.UPDATE_OPS])

         def post_process_model(model):
             added_trainable_names = set([x.name for x in tf.trainable_variables()])
@@ -73,6 +74,11 @@ class KerasModelCaller(object):
                 logger.warn("Keras created trainable variable '{}' which is actually not trainable. "
                             "This was automatically corrected by tensorpack.".format(n))

+            # Keras models might not use this collection at all (in some versions).
+            restore_collection(update_ops_backup)
+            for op in model.updates:
+                tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, op)
+
         if self.cached_model is None:
             assert not reuse
             model = self.cached_model = self.get_model(*input_tensors)
```
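The keras.py change snapshots the `UPDATE_OPS` graph collection before building the Keras model, restores the snapshot afterwards, and then re-adds `model.updates` explicitly — so the collection ends up containing exactly the model's own update ops regardless of what the Keras version did. A minimal pure-Python sketch of that backup/restore pattern, modelling TF1 graph collections as a plain dict (`backup_collection` and `restore_collection` here are stand-ins for tensorpack's graph helpers, not their real signatures):

```python
def backup_collection(collections, keys):
    # Snapshot the named collections so later construction work
    # (e.g. building a Keras model) cannot pollute them.
    return {k: list(collections.get(k, [])) for k in keys}

def restore_collection(collections, backup):
    # Overwrite the live collections with the snapshot,
    # discarding anything added in between.
    for k, v in backup.items():
        collections[k] = list(v)
```

With this pattern, anything a third-party library pushes into the collection between backup and restore is dropped, and only ops the caller re-adds deliberately survive.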