Shashank Suhas / seminar-breakout / Commits

Commit e8674dca, authored Jan 02, 2018 by Yuxin Wu
[Keras] Use get_model function instead of letting users create the model
directly (#160)
parent 365c56d2

Showing 3 changed files with 130 additions and 78 deletions:
examples/mnist-keras-v2.py               +28  -14
tensorpack/contrib/keras.py             +102  -59
tensorpack/graph_builder/model_desc.py    +0   -5
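In short, the user-facing change (condensed from the example diff below): instead of constructing a Keras model up front and handing the instance to tensorpack, the caller now hands over a function that builds the model from the input tensors it is given, so tensorpack can decide when and where to instantiate it.

    # before this commit: pass the constructed Sequential instance
    M = KerasModel(M, QueueInput(dataset_train))

    # after this commit: pass a function that builds and returns the model
    M = KerasModel(model_func, QueueInput(dataset_train))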
examples/mnist-keras-v2.py
@@ -14,6 +14,7 @@ from tensorpack.input_source import QueueInput
 from tensorpack.dataflow import dataset, BatchData, MapData
 from tensorpack.utils import logger
 from tensorpack.contrib.keras import KerasModel
+from tensorpack.callbacks import ModelSaver

 IMAGE_SIZE = 28
@@ -32,28 +33,41 @@ def get_data():
 if __name__ == '__main__':
     logger.auto_set_dir()

+    def model_func(input_tensors):
         M = keras.models.Sequential()
-        M.add(KL.Conv2D(32, 3, activation='relu', input_shape=[IMAGE_SIZE, IMAGE_SIZE, 1], padding='same'))
+        M.add(KL.InputLayer(input_shape=[IMAGE_SIZE, IMAGE_SIZE, 1], input_tensor=input_tensors[0]))
+        M.add(KL.Conv2D(32, 3, activation='relu', padding='same'))
         M.add(KL.MaxPooling2D())
         M.add(KL.Conv2D(32, 3, activation='relu', padding='same'))
         M.add(KL.Conv2D(32, 3, activation='relu', padding='same'))
         M.add(KL.MaxPooling2D())
         M.add(KL.Conv2D(32, 3, padding='same', activation='relu'))
         M.add(KL.Flatten())
         M.add(KL.Dense(512, activation='relu', kernel_regularizer=keras.regularizers.l2(1e-5)))
         M.add(KL.Dropout(0.5))
         M.add(KL.Dense(10, activation=None, kernel_regularizer=keras.regularizers.l2(1e-5)))
         M.add(KL.Activation('softmax'))
+        return M

     dataset_train, dataset_test = get_data()

-    M = KerasModel(M, QueueInput(dataset_train))
+    # from tensorpack import *
+    # trainer = SyncMultiGPUTrainerReplicated(2)
+    M = KerasModel(model_func, QueueInput(dataset_train))
     M.compile(
         optimizer=tf.train.AdamOptimizer(1e-3),
         loss='categorical_crossentropy',
-        metrics=['accuracy']
+        metrics=['categorical_accuracy']
     )
     M.fit(
         validation_data=dataset_test,
         steps_per_epoch=dataset_train.size(),
+        callbacks=[ModelSaver()]
     )
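Reading note: this hunk appears to have been captured with whitespace-only changes hidden, so the model-building lines are shown once even though they moved from module level into the body of model_func (only their indentation changed). The substantive changes are small: the first Conv2D no longer declares input_shape; a KL.InputLayer binds the tensor supplied by tensorpack (input_tensors[0]) into the network; the function returns the model rather than using it directly; and the script now saves checkpoints itself via callbacks=[ModelSaver()], since the library-side default is removed in tensorpack/contrib/keras.py below.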
tensorpack/contrib/keras.py
@@ -3,15 +3,18 @@
 # File: keras.py

 import tensorflow as tf
 from six.moves import zip
+import six

 from tensorflow import keras
+from tensorflow.python.keras import metrics as metrics_module

+from ..models.regularize import regularize_cost_from_collection
 from ..graph_builder import InputDesc
 from ..tfutils.tower import get_current_tower_context
-from ..tfutils.collection import freeze_collection
+# from ..tfutils.collection import freeze_collection  # TODO freeze UPDATE_OPS in replicated
 from ..callbacks import (
-    Callback, InferenceRunner, CallbackToHook, ScalarStats, ModelSaver)
+    Callback, InferenceRunner, CallbackToHook, ScalarStats)
 from ..tfutils.summary import add_moving_summary
 from ..utils.gpu import get_nr_gpu
 from ..train import Trainer, SimpleTrainer, SyncMultiGPUTrainerParameterServer
@@ -20,6 +23,30 @@ from ..train import Trainer, SimpleTrainer, SyncMultiGPUTrainerParameterServer
 __all__ = ['KerasPhaseCallback', 'setup_keras_trainer', 'KerasModel']


+class KerasModelCaller(object):
+    """
+    Keras model doesn't support variable scope reuse.
+    This is a hack to mimic reuse.
+    """
+    def __init__(self, get_model):
+        self.get_model = get_model
+        self.cached_model = None
+
+    def __call__(self, input_tensors):
+        reuse = tf.get_variable_scope().reuse
+        if self.cached_model is None:
+            assert not reuse
+            self.cached_model = self.get_model(input_tensors)
+            return self.cached_model.outputs
+        if reuse:
+            return self.cached_model.call(input_tensors)
+        else:
+            M = self.get_model(input_tensors)
+            return M.outputs
+
+
 # Keras needs an extra input if learning_phase is used by the model
 # This cb will be used by
 # 1. trainer with isTrain=True
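To see what the caching in KerasModelCaller buys, here is a minimal sketch of the intended call pattern (toy_model, the shapes, and the placeholders are hypothetical; it assumes TF1-era TensorFlow, where tf.keras layers register their variables with the surrounding variable scope, which is what lets scope reuse share weights across training towers):

    import tensorflow as tf
    from tensorflow import keras

    KL = keras.layers

    def toy_model(input_tensors):  # hypothetical get_model
        M = keras.models.Sequential()
        M.add(KL.InputLayer(input_shape=[4], input_tensor=input_tensors[0]))
        M.add(KL.Dense(2))
        return M

    caller = KerasModelCaller(toy_model)

    x0 = tf.placeholder(tf.float32, [None, 4])
    y0 = caller([x0])   # first call: builds the model once and caches it

    with tf.variable_scope(tf.get_variable_scope(), reuse=True):
        x1 = tf.placeholder(tf.float32, [None, 4])
        y1 = caller([x1])   # reuse=True: replays cached_model.call() on the
                            # new tensors, sharing the weights created above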
@@ -44,81 +71,96 @@ class KerasPhaseCallback(Callback):


 def setup_keras_trainer(
-        trainer, model, input,
+        trainer, get_model, input,
         optimizer, loss, metrics=None):
     """
     Args:
         trainer (SingleCostTrainer):
-        model (keras.model.Model):
+        get_model ( -> keras.model.Model):
         input (InputSource):
         optimizer (tf.train.Optimizer):
-        loss, metrics: same as in `keras.model.Model.compile()`.
+        loss, metrics: list of strings
     """
     assert isinstance(optimizer, tf.train.Optimizer), optimizer

-    inputs_desc = [InputDesc.from_tensor(t) for t in model.inputs]
-    outputs_desc = [InputDesc.from_tensor(t) for t in model.outputs]
+    G_tmp = tf.Graph()  # we need the model instance to know metadata about inputs/outputs
+    with G_tmp.as_default():
+        M_tmp = get_model([None])  # TODO use a proxy with Nones
+        inputs_desc = [InputDesc(t.dtype, t.shape.as_list(), 'input{}'.format(i))
+                       for i, t in enumerate(M_tmp.inputs)]
+        outputs_desc = [InputDesc(t.dtype, t.shape.as_list(), 'output{}'.format(i))
+                        for i, t in enumerate(M_tmp.outputs)]
     nr_inputs = len(inputs_desc)
+    del G_tmp, M_tmp
+    # clear the collection
+    del tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)[:]
+
+    model_caller = KerasModelCaller(get_model)

     def get_cost(*inputs):
         assert len(inputs) == len(inputs_desc) + len(outputs_desc), \
             "Input source size {} != {} + {}".format(
                 len(inputs), len(inputs_desc), len(outputs_desc))
         ctx = get_current_tower_context()
-        assert ctx.is_main_training_tower or not ctx.has_own_variables
         input_tensors = list(inputs[:nr_inputs])
         target_tensors = list(inputs[nr_inputs:])
-        # Keras check and do weird things if target is a placeholder..
-        # Use tf.identity so it's not a placeholder.
-        target_tensors = [tf.identity(t) for t in target_tensors]
-        input_keras_tensors = [keras.layers.Input(tensor=t) for t in input_tensors]
-        outputs = model(input_keras_tensors)
-        M = keras.models.Model(input_tensors, outputs)
-
-        with freeze_collection([tf.GraphKeys.TRAINABLE_VARIABLES]):
-            # Keras optimizer mistakenly creates TRAINABLE_VARIABLES ...
-            M.compile(
-                optimizer=optimizer, loss=loss,
-                target_tensors=target_tensors,
-                metrics=metrics)
-
-        # BN updates
-        if ctx.is_training:
-            for u in M.updates:
-                tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, u)
-
-        add_moving_summary(tf.identity(M.total_loss, name='total_loss'))
-        assert len(M.metrics) == len(M.metrics_tensors)
-        for name, tensor in zip(M.metrics, M.metrics_tensors):
-            add_moving_summary(tf.identity(tensor, name=name))
-        # tensorpack requires TRAINABLE_VARIABLES created inside tower
-        if ctx.is_main_training_tower:
-            for p in M.weights:
-                tf.add_to_collection(tf.GraphKeys.TRAINABLE_VARIABLES, p)
-        return M.total_loss
+        # TODO mapping between target tensors & output tensors
+        outputs = model_caller(input_tensors)
+        if isinstance(outputs, tf.Tensor):
+            outputs = [outputs]
+        assert len(outputs) == len(target_tensors), \
+            "len({}) != len({})".format(str(outputs), str(target_tensors))
+        assert len(outputs) == len(loss), \
+            "len({}) != len({})".format(str(outputs), str(loss))
+        # TODO more losses
+        with tf.name_scope('keras_loss'):
+            loss_fn = keras.losses.get(loss[0])
+            loss_opt = loss_fn(target_tensors[0], outputs[0])
+            loss_opt = tf.reduce_mean(loss_opt, name=loss[0])
+
+        loss_reg = regularize_cost_from_collection()
+        if loss_reg is not None:
+            total_loss = tf.add(loss_opt, loss_reg, name='total_loss')
+            add_moving_summary(loss_opt, loss_reg, total_loss)
+        else:
+            add_moving_summary(loss_opt)
+            total_loss = tf.identity(loss_opt, name='total_loss')
+
+        if metrics and (ctx.is_main_training_tower or not ctx.is_training):
+            # for list: one metric for each output
+            metric_tensors = []
+            for oid, metric_name in enumerate(metrics):
+                output_tensor = outputs[oid]
+                target_tensor = target_tensors[oid]  # TODO may not have the same mapping?
+                with tf.name_scope('keras_metric'):  # TODO ns reuse
+                    metric_fn = metrics_module.get(metric_name)
+                    metric_tensor = metric_fn(target_tensor, output_tensor)
+                    metric_tensor = tf.reduce_mean(metric_tensor, name=metric_name)
+                # check name conflict here
+                metric_tensors.append(metric_tensor)
+            add_moving_summary(*metric_tensors)
+
+        return total_loss

     trainer.setup_graph(
         inputs_desc + outputs_desc,
         input,
         get_cost,
         lambda: optimizer)
-    if model.uses_learning_phase:
+    if model_caller.cached_model.uses_learning_phase:
         trainer.register_callback(KerasPhaseCallback(True))
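One step above deserves a note: get_model([None]) works in the scratch graph because KL.InputLayer treats input_tensor=None as "create a fresh placeholder from input_shape" (hence the TODO about a proxy of Nones for multi-input models). A compressed restatement of the trick, assuming a single-input model_func as in the example file (a sketch, not tensorpack API):

    import tensorflow as tf

    G_tmp = tf.Graph()   # scratch graph; the real training graph stays clean
    with G_tmp.as_default():
        M_tmp = model_func([None])   # InputLayer(input_tensor=None) -> placeholder
        meta = [(t.dtype, t.shape.as_list())
                for t in M_tmp.inputs + M_tmp.outputs]
    del G_tmp, M_tmp     # nothing from the scratch graph leaks into training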
 class KerasModel(object):
-    def __init__(self, model, input, trainer=None):
+    def __init__(self, get_model, input, trainer=None):
         """
         Args:
-            model (keras.model.Model):
+            get_model ( -> keras.model.Model):
             input (InputSource):
             trainer (Trainer): the default will check the number of available
                 GPUs and use them all.
         """
-        self.model = model
+        self.get_model = get_model
         if trainer is None:
             nr_gpu = get_nr_gpu()
             if nr_gpu <= 1:
@@ -130,15 +172,22 @@ class KerasModel(object):
         self.input = input
         self.trainer = trainer

-    def compile(self, optimizer, loss, metrics):
+    def compile(self, optimizer, loss, metrics=None):
         """
         Args:
             optimizer (tf.train.Optimizer):
-            loss, metrics: same as in `keras.model.Model.compile()`.
+            loss, metrics: string or list of strings
         """
-        self._metrics = metrics
+        if isinstance(loss, six.string_types):
+            loss = [loss]
+        if metrics is None:
+            metrics = []
+        if isinstance(metrics, six.string_types):
+            metrics = [metrics]
+
+        self._stats_to_inference = loss + metrics
         setup_keras_trainer(
-            self.trainer, model=self.model,
+            self.trainer, get_model=self.get_model,
             input=self.input,
             optimizer=optimizer,
             loss=loss,
@@ -151,14 +200,8 @@ class KerasModel(object):
             kwargs: same as `self.trainer.train_with_defaults`.
         """
         callbacks = kwargs.pop('callbacks', [])
-        callbacks.extend(self.get_default_callbacks())
         if validation_data is not None:
             callbacks.append(InferenceRunner(
-                validation_data, ScalarStats(self._metrics + ['total_loss'])))
+                validation_data, ScalarStats(self._stats_to_inference + ['total_loss'])))
         self.trainer.train_with_defaults(callbacks=callbacks, **kwargs)
-
-    def get_default_callbacks(self):
-        return [
-            ModelSaver(keep_checkpoint_every_n_hours=0.2)
-        ]
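The compile()/fit() changes above normalize loss and metrics so a bare string is accepted, and remember the combined list for validation-time ScalarStats. A standalone restatement of just that bookkeeping (plain Python, independent of tensorpack):

    import six

    def normalize(loss, metrics=None):
        # mirrors KerasModel.compile: promote bare strings to lists
        if isinstance(loss, six.string_types):
            loss = [loss]
        if metrics is None:
            metrics = []
        if isinstance(metrics, six.string_types):
            metrics = [metrics]
        # loss + metrics is what fit() feeds to ScalarStats as _stats_to_inference
        return loss, metrics, loss + metrics

    print(normalize('categorical_crossentropy', 'categorical_accuracy'))
    # (['categorical_crossentropy'], ['categorical_accuracy'],
    #  ['categorical_crossentropy', 'categorical_accuracy'])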
tensorpack/graph_builder/model_desc.py
@@ -65,11 +65,6 @@ class InputDesc(
             return self._cached_placeholder
         return self.build_placeholder()

-    @staticmethod
-    def from_tensor(t):
-        return InputDesc(t.dtype, t.shape.as_list(), t.name[:-2])
-

 @six.add_metaclass(ABCMeta)
 class ModelDescBase(object):
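The deleted InputDesc.from_tensor (note the t.name[:-2], which strips the ':0' output-index suffix from a tensor name) appears to have existed only for the old setup_keras_trainer path, which derived the descs from an already-built model's tensors; with the descs now constructed explicitly under synthetic 'input{i}'/'output{i}' names, the helper is left without callers and is removed.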