Commit ac02c62f authored Jan 02, 2018 by Yuxin Wu
[Keras] use inputs_desc/targets_desc explicitly, to avoid hacks (#160)
parent f1ee1833
Showing 2 changed files with 37 additions and 41 deletions.
examples/mnist-keras-v2.py     +15 -12
tensorpack/contrib/keras.py    +22 -29
examples/mnist-keras-v2.py

@@ -10,7 +10,7 @@ from tensorflow import keras
 KL = keras.layers
-from tensorpack.input_source import QueueInput
+from tensorpack import InputDesc, QueueInput
 from tensorpack.dataflow import dataset, BatchData, MapData
 from tensorpack.utils import logger
 from tensorpack.contrib.keras import KerasModel
@@ -22,8 +22,7 @@ IMAGE_SIZE = 28
 def get_data():
     def f(dp):
         im = dp[0][:, :, None]
-        onehot = np.zeros(10, dtype='int32')
-        onehot[dp[1]] = 1
+        onehot = np.eye(10)[dp[1]]
         return [im, onehot]
     train = BatchData(MapData(dataset.Mnist('train'), f), 128)
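The rewritten one-hot line above trades two statements for a single indexing expression. A quick standalone sanity check of the equivalence (plain numpy, not part of the diff; note that np.eye returns float64 while the old vector was int32):

    import numpy as np

    label = 3
    # Old: allocate zeros, then set one position.
    old = np.zeros(10, dtype='int32')
    old[label] = 1
    # New: row `label` of the 10x10 identity matrix is the same one-hot vector.
    new = np.eye(10)[label]
    assert (old == new).all()       # same values
    assert new.dtype == np.float64  # but a wider dtype

The same expression also vectorizes: np.eye(10)[array_of_labels] yields one one-hot row per label.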
@@ -34,11 +33,14 @@ def get_data():
 if __name__ == '__main__':
     logger.auto_set_dir()

-    def model_func(input_tensors):
+    def model_func(inputs):
         """
         Keras model has to be created inside this function to be used with tensorpack.
         """
         M = keras.models.Sequential()
-        M.add(KL.InputLayer(input_shape=[IMAGE_SIZE, IMAGE_SIZE, 1], input_tensor=input_tensors[0]))
+        # input_tensor has to be used here for the tensorpack trainer to function properly.
+        # Just use inputs[1], inputs[2] if you have multiple inputs.
+        M.add(KL.InputLayer(input_tensor=inputs[0]))
         M.add(KL.Conv2D(32, 3, activation='relu', padding='same'))
         M.add(KL.MaxPooling2D())
         M.add(KL.Conv2D(32, 3, activation='relu', padding='same'))
@@ -51,18 +53,19 @@ if __name__ == '__main__':
         M.add(KL.Dropout(0.5))
         M.add(KL.Dense(10, activation=None, kernel_regularizer=keras.regularizers.l2(1e-5)))
         M.add(KL.Activation('softmax'))
         return M

     dataset_train, dataset_test = get_data()

-    # from tensorpack import *
-    # trainer = SyncMultiGPUTrainerReplicated(2)
-    M = KerasModel(model_func, QueueInput(dataset_train))
+    M = KerasModel(
+        model_func,
+        inputs_desc=[InputDesc(tf.float32, [None, IMAGE_SIZE, IMAGE_SIZE, 1], 'images')],
+        targets_desc=[InputDesc(tf.float32, [None, 10], 'labels')],
+        input=QueueInput(dataset_train))
     M.compile(
         optimizer=tf.train.AdamOptimizer(1e-3),
         loss='categorical_crossentropy',
-        metrics=['categorical_accuracy']
+        metrics='categorical_accuracy'
     )
     M.fit(
         validation_data=dataset_test,
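The descriptors passed to KerasModel must line up with what the dataflow yields: f returns a 28x28x1 image and a length-10 one-hot vector, and BatchData(..., 128) prepends the batch axis that appears as None. The mapping, spelled out (implied by the diff rather than stated in it):

    import tensorflow as tf
    from tensorpack import InputDesc

    # Each batch from BatchData(MapData(dataset.Mnist('train'), f), 128) carries
    # batch[0] of shape (128, 28, 28, 1) and batch[1] of shape (128, 10);
    # the leading batch axis becomes None in the matching descriptors:
    inputs_desc = [InputDesc(tf.float32, [None, 28, 28, 1], 'images')]
    targets_desc = [InputDesc(tf.float32, [None, 10], 'labels')]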
tensorpack/contrib/keras.py

@@ -8,8 +8,8 @@ from tensorflow import keras
 from tensorflow.python.keras import metrics as metrics_module
 from ..models.regularize import regularize_cost_from_collection
-from ..graph_builder import InputDesc
-from ..train import Trainer, SimpleTrainer, SyncMultiGPUTrainerParameterServer, DistributedTrainerBase
+from ..train import Trainer, SimpleTrainer, SyncMultiGPUTrainerParameterServer
+from ..train.trainers import DistributedTrainerBase
 from ..callbacks import (
     Callback, InferenceRunner, CallbackToHook, ScalarStats)
@@ -45,6 +45,10 @@ class KerasModelCaller(object):
         self.cached_model = None

     def __call__(self, input_tensors):
         """
         Returns:
             output tensors of this tower, evaluated with the input tensors.
         """
+        reuse = tf.get_variable_scope().reuse
+        if self.cached_model is None:
+            assert not reuse
@@ -52,26 +56,13 @@ class KerasModelCaller(object):
             return self.cached_model.outputs

         if reuse:
             # use the cached Keras model to mimic reuse
             return self.cached_model.call(input_tensors)
         else:
             # create new Keras model if not reuse
             M = self.get_model(input_tensors)
             return M.outputs

-    def call_virtual(self):
-        class NoneTensorProxy(object):
-            def __getitem__(self, index):
-                return None
-
-            def __len__(self):
-                raise NotImplementedError(
-                    "Do not call `len(inputs)` because it's only a virtual object "
-                    "for the moment! Use `inputs[index]` directly!")
-
-        G_tmp = tf.Graph()  # we need a model instance to know metadata about inputs/outputs
-        with G_tmp.as_default():
-            return self.get_model(NoneTensorProxy())

 # Keras needs an extra input if learning_phase is used by the model
 # This cb will be used by
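The deleted call_virtual was the hack named in the commit title: it built a throwaway model inside a temporary graph, feeding a NoneTensorProxy whose __getitem__ always returns None, purely to read off input/output metadata. With the descriptors now supplied explicitly, only the reuse logic in __call__ survives. A condensed sketch of its post-commit behavior, assembled from the fragments visible in the two hunks above (the exact line layout may differ):

    def __call__(self, input_tensors):
        reuse = tf.get_variable_scope().reuse
        if self.cached_model is None:
            # First call (first training tower): build and cache the model.
            assert not reuse
            self.cached_model = self.get_model(input_tensors)
            return self.cached_model.outputs
        if reuse:
            # Under variable reuse, re-apply the cached model to the new
            # tensors so that towers share weights.
            return self.cached_model.call(input_tensors)
        else:
            # A fresh tower without reuse (e.g. inference): build a new model.
            return self.get_model(input_tensors).outputs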
@@ -97,8 +88,9 @@ class KerasPhaseCallback(Callback):
 def setup_keras_trainer(
-        trainer, get_model, input,
-        optimizer, loss, metrics):
+        trainer, get_model,
+        inputs_desc, targets_desc,
+        input, optimizer, loss, metrics):
     """
     Args:
         trainer (SingleCostTrainer):
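For code that drives setup_keras_trainer directly instead of going through KerasModel, the descriptors now travel next to the InputSource. A hedged sketch of a matching call site (names reuse the MNIST example; that loss is a list here is inferred from the loss + metrics concatenation later in this file):

    from tensorpack import QueueInput, SimpleTrainer

    trainer = SimpleTrainer()
    setup_keras_trainer(
        trainer, get_model=model_func,
        inputs_desc=inputs_desc, targets_desc=targets_desc,
        input=QueueInput(dataset_train),
        optimizer=tf.train.AdamOptimizer(1e-3),
        loss=['categorical_crossentropy'],
        metrics=['categorical_accuracy'])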
@@ -113,17 +105,11 @@ def setup_keras_trainer(
     assert isinstance(metrics, list), metrics
     model_caller = KerasModelCaller(get_model)

-    M_tmp = model_caller.call_virtual()
-    inputs_desc = [InputDesc(t.dtype, t.shape.as_list(), 'input{}'.format(i))
-                   for i, t in enumerate(M_tmp.inputs)]
-    outputs_desc = [InputDesc(t.dtype, t.shape.as_list(), 'output{}'.format(i))
-                    for i, t in enumerate(M_tmp.outputs)]
     nr_inputs = len(inputs_desc)

     def get_cost(*inputs):
-        assert len(inputs) == len(inputs_desc) + len(outputs_desc), \
-            "Input source size {} != {} + {}".format(len(inputs), len(inputs_desc), len(outputs_desc))
+        assert len(inputs) == len(inputs_desc) + len(targets_desc), \
+            "Input source size {} != {} + {}".format(len(inputs), len(inputs_desc), len(targets_desc))
         ctx = get_current_tower_context()
         input_tensors = list(inputs[:nr_inputs])
         target_tensors = list(inputs[nr_inputs:])
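get_cost receives every tensor produced by the InputSource as one flat argument tuple, so the two descriptor lists double as a schema for splitting it: the first len(inputs_desc) tensors feed the model, the remainder are training targets. A toy illustration of that slicing (split_inputs is a hypothetical helper written only to mirror the assertion and slicing above):

    def split_inputs(inputs, inputs_desc, targets_desc):
        # Mirrors get_cost: validate the flat tuple, then cut at nr_inputs.
        assert len(inputs) == len(inputs_desc) + len(targets_desc), \
            "Input source size {} != {} + {}".format(
                len(inputs), len(inputs_desc), len(targets_desc))
        nr_inputs = len(inputs_desc)
        return list(inputs[:nr_inputs]), list(inputs[nr_inputs:])

    # With the MNIST descriptors (one image input, one label target):
    model_inputs, targets = split_inputs(('images', 'labels'), ['img_desc'], ['lbl_desc'])
    assert model_inputs == ['images'] and targets == ['labels']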
@@ -173,7 +159,7 @@ def setup_keras_trainer(
         return total_loss

     trainer.setup_graph(
-        inputs_desc + outputs_desc,
+        inputs_desc + targets_desc,
         input,
         get_cost,
         lambda: optimizer)
@@ -182,20 +168,26 @@ def setup_keras_trainer(
 class KerasModel(object):
-    def __init__(self, get_model, input, trainer=None):
+    def __init__(self, get_model, inputs_desc, targets_desc, input, trainer=None):
         """
         Args:
             get_model ( -> keras.model.Model):
+            inputs_desc ([InputDesc]):
+            targets_desc ([InputDesc]):
             input (InputSource):
             trainer (Trainer): the default will check the number of available
                 GPUs and use them all.
         """
         self.get_model = get_model
+        self.inputs_desc = inputs_desc
+        self.targets_desc = targets_desc
         if trainer is None:
             nr_gpu = get_nr_gpu()
             if nr_gpu <= 1:
                 trainer = SimpleTrainer()
             else:
                 # the default multigpu trainer
                 trainer = SyncMultiGPUTrainerParameterServer(nr_gpu)
         assert isinstance(trainer, Trainer), trainer
         assert not isinstance(trainer, DistributedTrainerBase)
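The default-trainer branch is untouched by this commit but explains the trainer=None default: at most one visible GPU gets SimpleTrainer, more get SyncMultiGPUTrainerParameterServer, and distributed trainers are rejected outright. Overriding the default might look like the sketch below; SyncMultiGPUTrainerReplicated is the trainer the example's deleted comment hinted at, and the remaining argument names come from the MNIST example above:

    from tensorpack import QueueInput, SyncMultiGPUTrainerReplicated
    from tensorpack.contrib.keras import KerasModel

    M = KerasModel(
        model_func,
        inputs_desc=inputs_desc,
        targets_desc=targets_desc,
        input=QueueInput(dataset_train),
        trainer=SyncMultiGPUTrainerReplicated(2))  # bypasses the get_nr_gpu() default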
@@ -219,6 +211,7 @@ class KerasModel(object):
         self._stats_to_inference = loss + metrics + [TOTAL_LOSS_NAME]
         setup_keras_trainer(
             self.trainer, get_model=self.get_model,
+            inputs_desc=self.inputs_desc, targets_desc=self.targets_desc,
             input=self.input,
             optimizer=optimizer, loss=loss,