Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
63b8fb00
Commit
63b8fb00
authored
Jan 08, 2019
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
make contrib.keras docs build
parent
ce3782ad
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
11 additions
and
12 deletions
+11
-12
tensorpack/contrib/keras.py
tensorpack/contrib/keras.py
+11
-12
No files found.
tensorpack/contrib/keras.py
View file @
63b8fb00
...
@@ -4,9 +4,7 @@
...
@@ -4,9 +4,7 @@
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
import
six
import
six
import
tensorflow
as
tf
import
tensorflow
as
tf
import
tensorflow.keras.backend
as
K
from
tensorflow
import
keras
from
tensorflow
import
keras
from
tensorflow.python.keras
import
metrics
as
metrics_module
from
..callbacks
import
Callback
,
CallbackToHook
,
InferenceRunner
,
InferenceRunnerBase
,
ScalarStats
from
..callbacks
import
Callback
,
CallbackToHook
,
InferenceRunner
,
InferenceRunnerBase
,
ScalarStats
from
..models.regularize
import
regularize_cost_from_collection
from
..models.regularize
import
regularize_cost_from_collection
...
@@ -36,11 +34,10 @@ def _check_name(tensor, name):
...
@@ -36,11 +34,10 @@ def _check_name(tensor, name):
class
KerasModelCaller
(
object
):
class
KerasModelCaller
(
object
):
"""
"""
Keras model doesn't support variable scope reuse.
Keras model doesn't support variable scope reuse.
This is a
hack
to mimic reuse.
This is a
wrapper around keras model
to mimic reuse.
"""
"""
def
__init__
(
self
,
get_model
):
def
__init__
(
self
,
get_model
):
self
.
get_model
=
get_model
self
.
get_model
=
get_model
self
.
cached_model
=
None
self
.
cached_model
=
None
def
__call__
(
self
,
input_tensors
):
def
__call__
(
self
,
input_tensors
):
...
@@ -70,7 +67,7 @@ class KerasModelCaller(object):
...
@@ -70,7 +67,7 @@ class KerasModelCaller(object):
for
n
in
added_trainable_names
:
for
n
in
added_trainable_names
:
if
n
not
in
new_trainable_names
:
if
n
not
in
new_trainable_names
:
logger
.
warn
(
"Keras created trainable variable '{}' which is actually not trainable. "
logger
.
warn
(
"Keras created trainable variable '{}' which is actually not trainable. "
"This was automatically corrected
by tensorpack
."
.
format
(
n
))
"This was automatically corrected."
.
format
(
n
))
# Keras models might not use this collection at all (in some versions).
# Keras models might not use this collection at all (in some versions).
# This is a BC-breaking change of tf.keras: https://github.com/tensorflow/tensorflow/issues/19643
# This is a BC-breaking change of tf.keras: https://github.com/tensorflow/tensorflow/issues/19643
...
@@ -93,7 +90,7 @@ class KerasModelCaller(object):
...
@@ -93,7 +90,7 @@ class KerasModelCaller(object):
with
clear_tower0_name_scope
():
with
clear_tower0_name_scope
():
model
=
self
.
cached_model
=
self
.
get_model
(
*
input_tensors
)
model
=
self
.
cached_model
=
self
.
get_model
(
*
input_tensors
)
assert
isinstance
(
model
,
tf
.
keras
.
Model
),
\
assert
isinstance
(
model
,
keras
.
Model
),
\
"Your get_model function should return a `tf.keras.Model`!"
"Your get_model function should return a `tf.keras.Model`!"
outputs
=
model
.
outputs
outputs
=
model
.
outputs
elif
reuse
:
elif
reuse
:
...
@@ -125,7 +122,7 @@ class KerasPhaseCallback(Callback):
...
@@ -125,7 +122,7 @@ class KerasPhaseCallback(Callback):
def
__init__
(
self
,
isTrain
):
def
__init__
(
self
,
isTrain
):
assert
isinstance
(
isTrain
,
bool
),
isTrain
assert
isinstance
(
isTrain
,
bool
),
isTrain
self
.
_isTrain
=
isTrain
self
.
_isTrain
=
isTrain
self
.
_learning_phase
=
K
.
learning_phase
()
self
.
_learning_phase
=
keras
.
backend
.
learning_phase
()
def
_setup_graph
(
self
):
def
_setup_graph
(
self
):
logger
.
info
(
"Using Keras learning phase {} in the graph!"
.
format
(
logger
.
info
(
"Using Keras learning phase {} in the graph!"
.
format
(
...
@@ -149,8 +146,9 @@ def setup_keras_trainer(
...
@@ -149,8 +146,9 @@ def setup_keras_trainer(
"""
"""
Args:
Args:
trainer (SingleCostTrainer):
trainer (SingleCostTrainer):
get_model (input1, input2, ... -> keras.model.Model):
get_model (input1, input2, ... -> tf.keras.Model):
Takes tensors and returns a Keras model. Will be part of the tower function.
A function which takes tensors, builds and returns a Keras model.
It will be part of the tower function.
input (InputSource):
input (InputSource):
optimizer (tf.train.Optimizer):
optimizer (tf.train.Optimizer):
loss, metrics: list of strings
loss, metrics: list of strings
...
@@ -202,7 +200,7 @@ def setup_keras_trainer(
...
@@ -202,7 +200,7 @@ def setup_keras_trainer(
output_tensor
=
outputs
[
oid
]
output_tensor
=
outputs
[
oid
]
target_tensor
=
target_tensors
[
oid
]
# TODO may not have the same mapping?
target_tensor
=
target_tensors
[
oid
]
# TODO may not have the same mapping?
with
cached_name_scope
(
'keras_metric'
,
top_level
=
False
):
with
cached_name_scope
(
'keras_metric'
,
top_level
=
False
):
metric_fn
=
metrics_module
.
get
(
metric_name
)
metric_fn
=
keras
.
metrics
.
get
(
metric_name
)
metric_tensor
=
metric_fn
(
target_tensor
,
output_tensor
)
metric_tensor
=
metric_fn
(
target_tensor
,
output_tensor
)
metric_tensor
=
tf
.
reduce_mean
(
metric_tensor
,
name
=
metric_name
)
metric_tensor
=
tf
.
reduce_mean
(
metric_tensor
,
name
=
metric_name
)
_check_name
(
metric_tensor
,
metric_name
)
_check_name
(
metric_tensor
,
metric_name
)
...
@@ -217,7 +215,7 @@ def setup_keras_trainer(
...
@@ -217,7 +215,7 @@ def setup_keras_trainer(
input
,
input
,
get_cost
,
get_cost
,
lambda
:
optimizer
)
lambda
:
optimizer
)
if
len
(
K
.
learning_phase
()
.
consumers
())
>
0
:
if
len
(
keras
.
backend
.
learning_phase
()
.
consumers
())
>
0
:
# check if learning_phase is used in this model
# check if learning_phase is used in this model
trainer
.
register_callback
(
KerasPhaseCallback
(
True
))
trainer
.
register_callback
(
KerasPhaseCallback
(
True
))
...
@@ -228,7 +226,8 @@ class KerasModel(object):
...
@@ -228,7 +226,8 @@ class KerasModel(object):
"""
"""
Args:
Args:
get_model (input1, input2, ... -> keras.Model):
get_model (input1, input2, ... -> keras.Model):
A function which takes tensors and returns a Keras model. Will be part of the tower function.
A function which takes tensors, builds and returns a Keras model.
It will be part of the tower function.
inputs_desc ([InputDesc]):
inputs_desc ([InputDesc]):
targets_desc ([InputDesc]):
targets_desc ([InputDesc]):
input (InputSource | DataFlow):
input (InputSource | DataFlow):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment