Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
6f6787db
Commit
6f6787db
authored
Jan 08, 2019
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
add documentation about KerasModel (#1036)
parent
f5d1714a
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
14 additions
and
6 deletions
+14
-6
docs/modules/index.rst
docs/modules/index.rst
+1
-1
tensorpack/contrib/keras.py
tensorpack/contrib/keras.py
+13
-5
No files found.
docs/modules/index.rst
View file @
6f6787db
...
@@ -16,4 +16,4 @@ API Documentation
...
@@ -16,4 +16,4 @@ API Documentation
predict
predict
tfutils
tfutils
utils
utils
contrib
tensorpack/contrib/keras.py
View file @
6f6787db
...
@@ -93,6 +93,8 @@ class KerasModelCaller(object):
...
@@ -93,6 +93,8 @@ class KerasModelCaller(object):
with
clear_tower0_name_scope
():
with
clear_tower0_name_scope
():
model
=
self
.
cached_model
=
self
.
get_model
(
*
input_tensors
)
model
=
self
.
cached_model
=
self
.
get_model
(
*
input_tensors
)
assert
isinstance
(
model
,
tf
.
keras
.
Model
),
\
"Your get_model function should return a `tf.keras.Model`!"
outputs
=
model
.
outputs
outputs
=
model
.
outputs
elif
reuse
:
elif
reuse
:
# use the cached Keras model to mimic reuse
# use the cached Keras model to mimic reuse
...
@@ -110,11 +112,16 @@ class KerasModelCaller(object):
...
@@ -110,11 +112,16 @@ class KerasModelCaller(object):
return
outputs
return
outputs
# Keras needs an extra input if learning_phase is used by the model
# This cb will be used by
# 1. trainer with isTrain=True
# 2. InferenceRunner with isTrain=False, in the form of hooks
class
KerasPhaseCallback
(
Callback
):
class
KerasPhaseCallback
(
Callback
):
"""
Keras needs an extra input if learning_phase is used by the model
This callback will be used:
1. By the trainer with isTrain=True
2. By InferenceRunner with isTrain=False, in the form of hooks
If you use :class:`KerasModel` or :func:`setup_keras_trainer`,
this callback will be automatically added when needed.
"""
def
__init__
(
self
,
isTrain
):
def
__init__
(
self
,
isTrain
):
assert
isinstance
(
isTrain
,
bool
),
isTrain
assert
isinstance
(
isTrain
,
bool
),
isTrain
self
.
_isTrain
=
isTrain
self
.
_isTrain
=
isTrain
...
@@ -221,7 +228,7 @@ class KerasModel(object):
...
@@ -221,7 +228,7 @@ class KerasModel(object):
"""
"""
Args:
Args:
get_model (input1, input2, ... -> keras.Model):
get_model (input1, input2, ... -> keras.Model):
T
akes tensors and returns a Keras model. Will be part of the tower function.
A function which t
akes tensors and returns a Keras model. Will be part of the tower function.
inputs_desc ([InputDesc]):
inputs_desc ([InputDesc]):
targets_desc ([InputDesc]):
targets_desc ([InputDesc]):
input (InputSource | DataFlow):
input (InputSource | DataFlow):
...
@@ -229,6 +236,7 @@ class KerasModel(object):
...
@@ -229,6 +236,7 @@ class KerasModel(object):
GPUs and use them all.
GPUs and use them all.
"""
"""
self
.
get_model
=
get_model
self
.
get_model
=
get_model
assert
callable
(
get_model
),
get_model
self
.
inputs_desc
=
inputs_desc
self
.
inputs_desc
=
inputs_desc
self
.
targets_desc
=
targets_desc
self
.
targets_desc
=
targets_desc
if
trainer
is
None
:
if
trainer
is
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment