Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
712fd299
Commit
712fd299
authored
May 05, 2018
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Not all Keras variables are marked trainable (#748)
parent
ac72ab73
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
2 additions
and
2 deletions
+2
-2
examples/keras/imagenet-resnet-keras.py
examples/keras/imagenet-resnet-keras.py
+1
-1
tensorpack/contrib/keras.py
tensorpack/contrib/keras.py
+1
-1
No files found.
examples/keras/imagenet-resnet-keras.py
View file @
712fd299
...
...
@@ -136,7 +136,7 @@ if __name__ == '__main__':
parser
.
add_argument
(
'--data'
,
help
=
'ILSVRC dataset dir'
)
parser
.
add_argument
(
'--fake'
,
help
=
'use fakedata to test or benchmark this model'
,
action
=
'store_true'
)
args
=
parser
.
parse_args
()
logger
.
set_logger_dir
(
"train_log/imagenet-resnet-keras"
)
logger
.
set_logger_dir
(
os
.
path
.
join
(
"train_log"
,
"imagenet-resnet-keras"
)
)
tf
.
keras
.
backend
.
set_image_data_format
(
'channels_first'
)
...
...
tensorpack/contrib/keras.py
View file @
712fd299
...
...
@@ -78,7 +78,7 @@ class KerasModelCaller(object):
for
v
in
M
.
weights
:
# In Keras, the collection is not respected and could contain non-trainable vars.
# We put M.weights into the collection instead.
if
v
.
name
not
in
old_trainable_names
:
if
v
.
name
not
in
old_trainable_names
and
v
.
name
in
added_trainable_names
:
tf
.
add_to_collection
(
tf
.
GraphKeys
.
TRAINABLE_VARIABLES
,
v
)
new_trainable_names
=
set
([
x
.
name
for
x
in
tf
.
trainable_variables
()])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment