Commit f1ee1833 authored Jan 02, 2018 by Yuxin Wu
[Keras] minor improvements (#160)
parent e8674dca
Showing 2 changed files with 68 additions and 35 deletions
tensorpack/contrib/keras.py  +60 -30
tensorpack/tfutils/scope_utils.py  +8 -5
tensorpack/contrib/keras.py  (view file @ f1ee1833)
...
@@ -7,22 +7,33 @@ import six
from tensorflow import keras
from tensorflow.python.keras import metrics as metrics_module
from ..models.regularize import regularize_cost_from_collection
from ..graph_builder import InputDesc
from ..tfutils.tower import get_current_tower_context
# from ..tfutils.collection import freeze_collection  # TODO freeze UPDATE_OPS in replicated
from ..train import Trainer, SimpleTrainer, SyncMultiGPUTrainerParameterServer, DistributedTrainerBase
from ..callbacks import (
    Callback, InferenceRunner, CallbackToHook, ScalarStats)
from ..tfutils.common import get_op_tensor_name
from ..tfutils.tower import get_current_tower_context
from ..tfutils.scope_utils import cached_name_scope
# from ..tfutils.collection import freeze_collection  # TODO freeze UPDATE_OPS in replicated
from ..tfutils.summary import add_moving_summary
from ..utils.gpu import get_nr_gpu
from ..train import Trainer, SimpleTrainer, SyncMultiGPUTrainerParameterServer

__all__ = ['KerasPhaseCallback', 'setup_keras_trainer', 'KerasModel']

TOTAL_LOSS_NAME = 'total_loss'


def _check_name(tensor, name):
    tensorname = get_op_tensor_name(tensor.name)[0]
    assert tensorname.split('/')[-1] == name, \
        "{} does not match {}, you may have name conflict somewhere!".format(tensor.name, name)


class KerasModelCaller(object):
    """
    Keras model doesn't support vs reuse.
...
@@ -46,6 +57,21 @@ class KerasModelCaller(object):
        M = self.get_model(input_tensors)
        return M.outputs

    def call_virtual(self):
        class NoneTensorProxy(object):
            def __getitem__(self, index):
                return None

            def __len__(self):
                raise NotImplementedError(
                    "Do not call `len(inputs)` because it's only a virtual object "
                    "for the moment! Use `inputs[index]` directly!")

        G_tmp = tf.Graph()  # we need a model instance to know metadata about inputs/outputs
        with G_tmp.as_default():
            return self.get_model(NoneTensorProxy())


# Keras needs an extra input if learning_phase is used by the model
# This cb will be used by
...
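A hedged sketch (not part of this commit) of how a user-supplied model function cooperates with the new call_virtual: because it only ever indexes into inputs, it works both with real tensors and with the None-returning proxy. The 28x28x1 input and 10-way output are illustrative assumptions; only KerasModelCaller and call_virtual come from this file.

from tensorflow import keras
from tensorpack.contrib.keras import KerasModelCaller


def get_model(inputs):
    # `inputs` is either the real input tensors or call_virtual's proxy,
    # which returns None for every index.
    tensor = inputs[0]
    if tensor is None:
        x = keras.layers.Input(shape=(28, 28, 1))   # metadata-only pass
    else:
        x = keras.layers.Input(tensor=tensor)       # real graph, fed by tensorpack
    y = keras.layers.Flatten()(x)
    y = keras.layers.Dense(10, activation='softmax')(y)
    return keras.models.Model(inputs=x, outputs=y)


caller = KerasModelCaller(get_model)
M_meta = caller.call_virtual()   # built on a throwaway tf.Graph, never trained
print([t.shape.as_list() for t in M_meta.inputs],
      [t.shape.as_list() for t in M_meta.outputs])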
@@ -58,9 +84,9 @@ class KerasPhaseCallback(Callback):
        self._learning_phase = keras.backend.learning_phase()

    def _setup_graph(self):
        # HACK
        cbs = self.trainer._callbacks.cbs
        for cb in cbs:
            # XXX HACK
            if isinstance(cb, InferenceRunner):
                h = CallbackToHook(KerasPhaseCallback(False))
                cb.register_hook(h)
...
@@ -72,7 +98,7 @@ class KerasPhaseCallback(Callback):
def setup_keras_trainer(
        trainer, get_model, input,
        optimizer, loss, metrics=None):
        optimizer, loss, metrics):
    """
    Args:
        trainer (SingleCostTrainer):
...
@@ -82,18 +108,18 @@ def setup_keras_trainer(
        loss, metrics: list of strings
    """
    assert isinstance(optimizer, tf.train.Optimizer), optimizer
    assert isinstance(loss, list), loss
    assert len(loss) >= 1, "No loss was given!"
    assert isinstance(metrics, list), metrics
    model_caller = KerasModelCaller(get_model)
    M_tmp = model_caller.call_virtual()
    G_tmp = tf.Graph()  # we need the model instance to know metadata about inputs/outputs
    with G_tmp.as_default():
        M_tmp = get_model([None])  # TODO use a proxy with Nones
    inputs_desc = [InputDesc(t.dtype, t.shape.as_list(), 'input{}'.format(i))
                   for i, t in enumerate(M_tmp.inputs)]
    outputs_desc = [InputDesc(t.dtype, t.shape.as_list(), 'output{}'.format(i))
                    for i, t in enumerate(M_tmp.outputs)]
    nr_inputs = len(inputs_desc)
    del G_tmp, M_tmp
    model_caller = KerasModelCaller(get_model)

    def get_cost(*inputs):
        assert len(inputs) == len(inputs_desc) + len(outputs_desc), \
...
@@ -112,19 +138,22 @@ def setup_keras_trainer(
        assert len(outputs) == len(loss), \
            "len({}) != len({})".format(str(outputs), str(loss))

        # TODO more losses
        with tf.name_scope('keras_loss'):
            loss_fn = keras.losses.get(loss[0])
            loss_opt = loss_fn(target_tensors[0], outputs[0])
        loss_opt = tf.reduce_mean(loss_opt, name=loss[0])
        loss_tensors = []
        for idx, loss_name in enumerate(loss):
            with cached_name_scope('keras_loss', top_level=False):
                loss_fn = keras.losses.get(loss_name)
                curr_loss = loss_fn(target_tensors[idx], outputs[idx])
            curr_loss = tf.reduce_mean(curr_loss, name=loss_name)
            _check_name(curr_loss, loss_name)
            loss_tensors.append(curr_loss)

        loss_reg = regularize_cost_from_collection()
        if loss_reg is not None:
            total_loss = tf.add(loss_opt, loss_reg, name='total_loss')
            add_moving_summary(loss_opt, loss_reg, total_loss)
            total_loss = tf.add_n(loss_tensors + [loss_reg], name=TOTAL_LOSS_NAME)
            add_moving_summary(loss_reg, total_loss, *loss_tensors)
        else:
            add_moving_summary(loss_opt)
            total_loss = tf.identity(loss_opt, name='total_loss')
            add_moving_summary(*loss_tensors)
            total_loss = tf.add_n(loss_tensors, name=TOTAL_LOSS_NAME)

        if metrics and (ctx.is_main_training_tower or not ctx.is_training):
            # for list: one metric for each output
...
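For reference, a minimal TF1-style sketch of the multi-loss combination introduced above, outside of tensorpack: each loss name is resolved with keras.losses.get, reduced to a scalar named after it (so _check_name and the moving summaries see predictable names), and the scalars are summed into 'total_loss'. The placeholders and loss names are illustrative.

import tensorflow as tf
from tensorflow import keras

labels = tf.placeholder(tf.float32, [None, 10])
preds = [tf.placeholder(tf.float32, [None, 10]) for _ in range(2)]

loss_names = ['categorical_crossentropy', 'mean_squared_error']
loss_tensors = []
for idx, loss_name in enumerate(loss_names):
    loss_fn = keras.losses.get(loss_name)        # string -> loss callable
    curr = loss_fn(labels, preds[idx])           # per-example loss
    curr = tf.reduce_mean(curr, name=loss_name)  # scalar, named after the loss
    loss_tensors.append(curr)

total_loss = tf.add_n(loss_tensors, name='total_loss')  # TOTAL_LOSS_NAME above
print([t.name for t in loss_tensors], total_loss.name)
# ['categorical_crossentropy:0', 'mean_squared_error:0'] total_loss:0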
@@ -132,10 +161,11 @@ def setup_keras_trainer(
            for oid, metric_name in enumerate(metrics):
                output_tensor = outputs[oid]
                target_tensor = target_tensors[oid]  # TODO may not have the same mapping?
                with tf.name_scope('keras_metric'):  # TODO ns reuse
                with cached_name_scope('keras_metric', top_level=False):
                    metric_fn = metrics_module.get(metric_name)
                    metric_tensor = metric_fn(target_tensor, output_tensor)
                metric_tensor = tf.reduce_mean(metric_tensor, name=metric_name)
                _check_name(metric_tensor, metric_name)  # check name conflict here
                metric_tensors.append(metric_tensor)
            add_moving_summary(*metric_tensors)
...
@@ -168,6 +198,7 @@ class KerasModel(object):
        else:
            trainer = SyncMultiGPUTrainerParameterServer(nr_gpu)
        assert isinstance(trainer, Trainer), trainer
        assert not isinstance(trainer, DistributedTrainerBase)

        self.input = input
        self.trainer = trainer
...
@@ -185,7 +216,7 @@ class KerasModel(object):
        if isinstance(metrics, six.string_types):
            metrics = [metrics]
        self._stats_to_inference = loss + metrics
        self._stats_to_inference = loss + metrics + [TOTAL_LOSS_NAME]
        setup_keras_trainer(
            self.trainer, get_model=self.get_model,
            input=self.input,
...
@@ -201,7 +232,6 @@ class KerasModel(object):
        """
        callbacks = kwargs.pop('callbacks', [])
        if validation_data is not None:
            callbacks.append(InferenceRunner(
                validation_data, ScalarStats(self._stats_to_inference + ['total_loss'])))
            callbacks.append(InferenceRunner(
                validation_data, ScalarStats(self._stats_to_inference)))
        self.trainer.train_with_defaults(callbacks=callbacks, **kwargs)
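An end-to-end usage sketch of the KerasModel path touched by this commit (compile forwarding loss and metrics lists to setup_keras_trainer, fit adding the inference callback). It is hedged: the constructor signature KerasModel(get_model, input, trainer=None), the fit keyword arguments, and the FakeData stand-in dataset are assumptions based on this file and tensorpack examples of the same era, not on this diff alone.

import tensorflow as tf
from tensorflow import keras
from tensorpack import QueueInput
from tensorpack.dataflow import FakeData
from tensorpack.contrib.keras import KerasModel


def get_model(inputs):
    # inputs[0] is the image tensor fed by the InputSource, or None during the
    # metadata-only pass (see the call_virtual sketch above).
    tensor = inputs[0]
    if tensor is None:
        x = keras.layers.Input(shape=(28, 28, 1))
    else:
        x = keras.layers.Input(tensor=tensor)
    y = keras.layers.Flatten()(x)
    y = keras.layers.Dense(10, activation='softmax')(y)
    return keras.models.Model(inputs=x, outputs=y)


# Random arrays standing in for a real DataFlow: [image, one-hot label] batches.
train_df = FakeData([[32, 28, 28, 1], [32, 10]], size=100)
val_df = FakeData([[32, 28, 28, 1], [32, 10]], size=10)

M = KerasModel(get_model, QueueInput(train_df))   # assumed constructor signature
M.compile(
    optimizer=tf.train.AdamOptimizer(1e-3),
    loss=['categorical_crossentropy'],            # one loss per model output
    metrics=['categorical_accuracy'])             # forwarded as a list to setup_keras_trainer
M.fit(validation_data=val_df, steps_per_epoch=100, max_epoch=2)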
tensorpack/tfutils/scope_utils.py  (view file @ f1ee1833)
...
@@ -85,15 +85,18 @@ def _get_cached_ns(name):
@contextmanager
def cached_name_scope(name):
def cached_name_scope(name, top_level=True):
    """
    Return a context which either opens and caches a new top-level name scope,
    Return a context which either opens and caches a new name scope,
    or reenter an existing one.

    Note:
        The name scope will always be top-level. It will not be nested under
        any existing name scope of the caller.
    Args:
        top_level(bool): if True, the name scope will always be top-level.
            It will not be nested under any existing name scope of the caller.
    """
    if not top_level:
        current_ns = tf.get_default_graph().get_name_scope()
        name = current_ns + '/' + name
    ns = _get_cached_ns(name)
    with tf.name_scope(ns):
        yield ns
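A hedged usage sketch of the new top_level flag (op names in the comments are what TF1 name scoping would produce; nothing below besides cached_name_scope is from this repo). Re-entering the context reuses the cached scope instead of creating keras_loss_1, and top_level=False nests the cached scope under the caller's current scope, which is how keras.py now groups per-loss and per-metric ops inside each tower.

import tensorflow as tf
from tensorpack.tfutils.scope_utils import cached_name_scope

with cached_name_scope('keras_loss'):
    a = tf.constant(1, name='a')       # keras_loss/a
with cached_name_scope('keras_loss'):
    b = tf.constant(2, name='b')       # same cached scope again: keras_loss/b

with tf.name_scope('tower0'):
    with cached_name_scope('keras_loss', top_level=False):
        c = tf.constant(3, name='c')   # nested under the caller: tower0/keras_loss/c

print(a.name, b.name, c.name)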