Commit bc4c6044
authored Oct 30, 2017 by Yuxin Wu

Initial attempt at keras-style training

parent 88b99f38

Showing 4 changed files with 197 additions and 30 deletions:
examples/mnist-keras-v2.py                 +76  -0
examples/mnist-keras.py                    +5   -18
tensorpack/callbacks/inference_runner.py   +19  -12
tensorpack/contrib/keras.py                +97  -0
examples/mnist-keras-v2.py  (new file, 0 → 100755)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: mnist-keras-v2.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim
import os
import sys
import argparse
import keras
from keras.models import Sequential
import keras.layers as KL
from keras import regularizers

from tensorpack.train import SimpleTrainer
from tensorpack.input_source import QueueInput
from tensorpack.callbacks import *
from tensorpack.dataflow import dataset, BatchData, MapData
from tensorpack.utils import logger
from tensorpack.contrib.keras import setup_keras_trainer

IMAGE_SIZE = 28


def get_data():
    def f(dp):
        im = dp[0][:, :, None]
        onehot = np.zeros(10, dtype='int32')
        onehot[dp[1]] = 1
        return [im, onehot]

    train = BatchData(MapData(dataset.Mnist('train'), f), 128)
    test = BatchData(MapData(dataset.Mnist('test'), f), 256)
    return train, test


if __name__ == '__main__':
    logger.auto_set_dir()

    dataset_train, dataset_test = get_data()

    M = Sequential()
    M.add(KL.Conv2D(32, 3, activation='relu', input_shape=[IMAGE_SIZE, IMAGE_SIZE, 1], padding='same'))
    M.add(KL.MaxPooling2D())
    M.add(KL.Conv2D(32, 3, activation='relu', padding='same'))
    M.add(KL.Conv2D(32, 3, activation='relu', padding='same'))
    M.add(KL.MaxPooling2D())
    M.add(KL.Conv2D(32, 3, padding='same', activation='relu'))
    M.add(KL.Flatten())
    M.add(KL.Dense(512, activation='relu', kernel_regularizer=regularizers.l2(1e-5)))
    M.add(KL.Dropout(0.5))
    M.add(KL.Dense(10, activation=None, kernel_regularizer=regularizers.l2(1e-5)))
    M.add(KL.Activation('softmax'))

    trainer = SimpleTrainer()
    setup_keras_trainer(
        trainer, model=M,
        input=QueueInput(dataset_train),
        optimizer=tf.train.AdamOptimizer(1e-3),
        loss='categorical_crossentropy',
        metrics=['accuracy'])
    trainer.train_with_defaults(
        callbacks=[
            ModelSaver(),
            InferenceRunner(
                dataset_test,
                [ScalarStats(['total_loss', 'accuracy'])]),
        ],
        steps_per_epoch=dataset_train.size(),
    )
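For orientation (not part of the commit): the script above keeps the model definition in plain Keras but hands training over to tensorpack. A rough sketch of the pure-Keras training it stands in for, assuming Keras 2's standard compile()/fit() API and hypothetical in-memory arrays:

    # Sketch only: what setup_keras_trainer + train_with_defaults replace.
    # train_images/train_onehot/test_images/test_onehot are hypothetical
    # numpy arrays, not names from this repo.
    M.compile(optimizer=keras.optimizers.Adam(1e-3),
              loss='categorical_crossentropy',
              metrics=['accuracy'])
    M.fit(train_images, train_onehot, batch_size=128,
          validation_data=(test_images, test_onehot))

tensorpack instead streams batches through QueueInput and evaluates via the InferenceRunner callback.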
examples/mnist-keras.py
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
-# File: mnist-keras.py
+# File: mnist-keras-functional.py
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>

 import numpy as np
...
@@ -12,18 +12,18 @@ import argparse
 import keras
 import keras.layers as KL
-import keras.backend as KB
 from keras.models import Sequential
 from keras import regularizers

 """
-This is an mnist example demonstrating how to use Keras models inside tensorpack.
+This is an mnist example demonstrating how to use Keras symbolic function inside tensorpack.
 This way you can define models in Keras-style, and benefit from the more efficeint trainers in tensorpack.
 """

 from tensorpack import *
 from tensorpack.dataflow import dataset
 from tensorpack.utils.argtools import memoized
+from tensorpack.contrib.keras import KerasPhaseCallback

 IMAGE_SIZE = 28
...
@@ -78,18 +78,6 @@ class Model(ModelDesc):
         return tf.train.AdamOptimizer(lr)


-# Keras needs an extra input if learning_phase is used by the model
-class KerasCallback(Callback):
-    def __init__(self, isTrain):
-        assert isinstance(isTrain, bool), isTrain
-        self._isTrain = isTrain
-        self._learning_phase = KB.learning_phase()
-
-    def _before_run(self, ctx):
-        return tf.train.SessionRunArgs(
-            fetches=[], feed_dict={self._learning_phase: int(self._isTrain)})
-
-
 def get_data():
     train = BatchData(dataset.Mnist('train'), 128)
     test = BatchData(dataset.Mnist('test'), 256, remainder=True)
...
@@ -104,12 +92,11 @@ def get_config():
         model=Model(),
         dataflow=dataset_train,
         callbacks=[
-            KerasCallback(True),   # for Keras training
+            KerasPhaseCallback(True),   # for Keras training
             ModelSaver(),
             InferenceRunner(
                 dataset_test,
-                [ScalarStats('cross_entropy_loss'), ClassificationError('incorrect')],
-                extra_hooks=[CallbackToHook(KerasCallback(False))]),  # for keras inference
+                [ScalarStats('cross_entropy_loss'), ClassificationError('incorrect')]),
         ],
         max_epoch=100,
     )
...
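The hand-rolled KerasCallback deleted above moves into the library as KerasPhaseCallback (see tensorpack/contrib/keras.py below), and the extra_hooks=[CallbackToHook(KerasCallback(False))] plumbing disappears because the new callback registers the inference-phase hook itself. For background, a minimal illustrative sketch of why the feed is needed at all, assuming the TF1/Keras-2 APIs this repo targets (not part of the commit):

    import numpy as np
    import tensorflow as tf
    import keras.layers as KL
    import keras.backend as KB

    x = tf.placeholder(tf.float32, [None, 4])
    y = KL.Dropout(0.5)(x)   # consults the global learning_phase tensor

    with tf.Session() as sess:
        data = np.ones((2, 4), dtype='float32')
        # Outside Keras' own fit/evaluate loops, the phase must be fed by hand:
        sess.run(y, feed_dict={x: data, KB.learning_phase(): 1})  # train: dropout active
        sess.run(y, feed_dict={x: data, KB.learning_phase(): 0})  # test: identity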
tensorpack/callbacks/inference_runner.py
@@ -61,12 +61,11 @@ class InferenceRunnerBase(Callback):
     Also, InferenceRunner assumes that `trainer.model` exists.
     """

-    def __init__(self, input, infs, extra_hooks=None):
+    def __init__(self, input, infs):
         """
         Args:
             input (InputSource): the input to use. Must have ``size()``.
             infs (list[Inferencer]): list of :class:`Inferencer` to run.
-            extra_hooks (list[SessionRunHook]): extra :class:`SessionRunHook` to run with the evaluation.
         """
         self._input_source = input
         if not isinstance(infs, list):
...
@@ -82,12 +81,16 @@ class InferenceRunnerBase(Callback):
             raise ValueError("Input used in InferenceRunner must have a size!")
         logger.info("InferenceRunner will eval on an InputSource of size {}".format(self._size))

-        if extra_hooks is None:
-            extra_hooks = []
-        self._extra_hooks = extra_hooks
+        self._hooks = []
+
+    def register_hook(self, hook):
+        """
+        Args:
+            hook (tf.train.SessionRunHook):
+        """
+        self._hooks.append(hook)

     def _before_train(self):
-        self._hooks.extend(self._extra_hooks)
         self._hooked_sess = HookedSession(self.trainer.sess, self._hooks)
         self._input_callbacks.before_train()
...
@@ -100,7 +103,7 @@ class InferenceRunner(InferenceRunnerBase):
     A callback that runs a list of :class:`Inferencer` on some :class:`InputSource`.
     """

-    def __init__(self, input, infs, tower_name='InferenceTower', device=0, extra_hooks=None):
+    def __init__(self, input, infs, tower_name='InferenceTower', device=0):
         """
         Args:
             input (InputSource or DataFlow): The :class:`InputSource` to run
...
@@ -115,8 +118,7 @@ class InferenceRunner(InferenceRunnerBase):
         assert isinstance(input, InputSource), input
         self._tower_name = tower_name
         self._device = device
-        super(InferenceRunner, self).__init__(
-            input, infs, extra_hooks=extra_hooks)
+        super(InferenceRunner, self).__init__(input, infs)

     def _build_hook(self, inf):
         out_names = inf.get_fetches()
...
@@ -138,11 +140,13 @@ class InferenceRunner(InferenceRunnerBase):
             self._input_source, self.trainer.tower_func)
         self._tower_handle = self.trainer.tower_func.towers[-1]

-        self._hooks = [self._build_hook(inf) for inf in self.infs]
+        for h in [self._build_hook(inf) for inf in self.infs]:
+            self.register_hook(h)
         # trigger_{step,epoch}, {before,after}_epoch is ignored.
         # We assume that InputSource callbacks won't use these methods
         self._input_callbacks = Callbacks(input_callbacks)
-        self._hooks.extend(self._input_callbacks.get_hooks())
+        for h in self._input_callbacks.get_hooks():
+            self.register_hook(h)

         for inf in self.infs:
             inf.setup_graph(self.trainer)
...
@@ -202,7 +206,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
         # setup callbacks and hooks
         self._input_callbacks = Callbacks(input_callbacks)

-        # InputSource might have hooks which break us.
+        # TODO InputSource might have hooks which break us.
         # e.g. hooks from StagingInput will force the consumption
         # of nr_tower datapoints in every run.
         input_hooks = self._input_callbacks.get_hooks()
...
@@ -213,6 +217,9 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
             inf.setup_graph(self.trainer)
         self._input_callbacks.setup_graph(self.trainer)

+    def register_hook(self, h):
+        raise NotImplementedError("DataParallelInferenceRunner doesn't accept extra hooks!")
+

 class InferencerToHookDataParallel(InferencerToHook):
     def __init__(self, inf, fetches, size):
         """
...
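The net effect of this diff: the extra_hooks constructor argument is gone, and hooks are attached to an InferenceRunner after construction through the new register_hook() method (which DataParallelInferenceRunner explicitly rejects). A sketch of the before/after call pattern, using only names from this commit:

    from tensorpack.callbacks import InferenceRunner, CallbackToHook, ScalarStats
    from tensorpack.contrib.keras import KerasPhaseCallback

    # before: InferenceRunner(dataset_test, infs,
    #             extra_hooks=[CallbackToHook(KerasPhaseCallback(False))])
    # after: attach the hook post-construction
    runner = InferenceRunner(dataset_test, [ScalarStats('total_loss')])
    runner.register_hook(CallbackToHook(KerasPhaseCallback(False)))

In the Keras integration even this manual step is unnecessary: KerasPhaseCallback._setup_graph (below) finds every InferenceRunner among the trainer's callbacks and performs the registration automatically.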
tensorpack/contrib/keras.py  (new file, 0 → 100644)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: keras.py

import tensorflow as tf
from six.moves import zip
import keras

from ..graph_builder import InputDesc
from ..tfutils.tower import get_current_tower_context
from ..tfutils.collection import freeze_collection
from ..callbacks import Callback, InferenceRunner, CallbackToHook
from ..tfutils.summary import add_moving_summary


# Keras needs an extra input if learning_phase is used by the model.
# This cb will be used by
# 1. trainer with isTrain=True
# 2. InferenceRunner with isTrain=False, in the form of hooks
class KerasPhaseCallback(Callback):
    def __init__(self, isTrain):
        assert isinstance(isTrain, bool), isTrain
        self._isTrain = isTrain
        self._learning_phase = keras.backend.learning_phase()

    def _setup_graph(self):
        # HACK
        cbs = self.trainer._callbacks.cbs
        for cb in cbs:
            if isinstance(cb, InferenceRunner):
                h = CallbackToHook(KerasPhaseCallback(False))
                cb.register_hook(h)

    def _before_run(self, ctx):
        return tf.train.SessionRunArgs(
            fetches=[], feed_dict={self._learning_phase: int(self._isTrain)})


def setup_keras_trainer(
        trainer, model, input, optimizer, loss, metrics=None):
    """
    Args:
        trainer (SingleCostTrainer):
        model (keras.model.Model):
        input (InputSource):
        optimizer (tf.train.Optimizer):
        loss, metrics: same as in `keras.model.Model.compile()`.
    """
    assert isinstance(optimizer, tf.train.Optimizer), optimizer
    inputs_desc = [InputDesc.from_tensor(t) for t in model.inputs]
    outputs_desc = [InputDesc.from_tensor(t) for t in model.outputs]
    nr_inputs = len(inputs_desc)

    # clear the collection
    del tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)[:]

    def get_cost(*inputs):
        ctx = get_current_tower_context()
        assert ctx.is_main_training_tower or not ctx.has_own_variables
        input_tensors = list(inputs[:nr_inputs])
        target_tensors = list(inputs[nr_inputs:])
        # Keras checks and does weird things if target is a placeholder.
        # Use tf.identity so it's not a placeholder.
        target_tensors = [tf.identity(t) for t in target_tensors]
        input_keras_tensors = [keras.layers.Input(tensor=t) for t in input_tensors]
        outputs = model(input_keras_tensors)
        M = keras.models.Model(input_tensors, outputs)

        with freeze_collection([tf.GraphKeys.TRAINABLE_VARIABLES]):
            # Keras optimizer mistakenly creates TRAINABLE_VARIABLES ...
            M.compile(
                optimizer=optimizer, loss=loss,
                target_tensors=target_tensors,
                metrics=metrics)
        add_moving_summary(tf.identity(M.total_loss, name='total_loss'))
        assert len(M.metrics) == len(M.metrics_tensors)
        for name, tensor in zip(M.metrics, M.metrics_tensors):
            add_moving_summary(tf.identity(tensor, name=name))

        # tensorpack requires TRAINABLE_VARIABLES created inside tower
        if ctx.is_main_training_tower:
            for p in M.weights:
                tf.add_to_collection(tf.GraphKeys.TRAINABLE_VARIABLES, p)
        return M.total_loss

    trainer.setup_graph(
        inputs_desc + outputs_desc,
        input,
        get_cost,
        lambda: optimizer)
    if model.uses_learning_phase:
        trainer.register_callback(KerasPhaseCallback(True))
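A note on the TRAINABLE_VARIABLES bookkeeping above: the collection is cleared up front, frozen around M.compile() (whose optimizer wrapper would otherwise add variables to it), and then repopulated from M.weights only in the main training tower, since tensorpack expects trainable variables to be created inside the tower function. An illustrative sketch of what a freeze_collection-style guard does (not tensorpack's actual implementation):

    from contextlib import contextmanager
    import tensorflow as tf

    @contextmanager
    def freeze_collection_sketch(keys):
        # Snapshot each named collection on entry; restore it on exit, so
        # anything added inside the block (e.g. by M.compile) is discarded.
        saved = {k: list(tf.get_collection(k)) for k in keys}
        try:
            yield
        finally:
            for k, vals in saved.items():
                coll = tf.get_collection_ref(k)
                del coll[:]
                coll.extend(vals)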