Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
72385a85
Commit
72385a85
authored
Nov 05, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
add `KerasModel` wrapper (#160)
parent
42f10617
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
46 additions
and
9 deletions
+46
-9
docs/tutorial/extend/dataflow.md
docs/tutorial/extend/dataflow.md
+1
-0
examples/mnist-keras-v2.py
examples/mnist-keras-v2.py
+5
-9
tensorpack/contrib/keras.py
tensorpack/contrib/keras.py
+40
-0
No files found.
docs/tutorial/extend/dataflow.md
View file @
72385a85
...
@@ -24,6 +24,7 @@ Optionally, you can implement the following two methods:
...
@@ -24,6 +24,7 @@ Optionally, you can implement the following two methods:
A typical situation is when your DataFlow uses random number generator (RNG). Then you would need to reset the RNG here.
A typical situation is when your DataFlow uses random number generator (RNG). Then you would need to reset the RNG here.
Otherwise, child processes will have the same random seed. The `RNGDataFlow` base class does this for you.
Otherwise, child processes will have the same random seed. The `RNGDataFlow` base class does this for you.
You can subclass `RNGDataFlow` to access `self.rng` whose seed has been taken care of.
With a "low-level" DataFlow defined like above, you can then compose it with existing modules (e.g. batching, prefetching, ...).
With a "low-level" DataFlow defined like above, you can then compose it with existing modules (e.g. batching, prefetching, ...).
...
...
examples/mnist-keras-v2.py
View file @
72385a85
...
@@ -16,7 +16,7 @@ from tensorpack.input_source import QueueInput
...
@@ -16,7 +16,7 @@ from tensorpack.input_source import QueueInput
from
tensorpack.callbacks
import
ModelSaver
,
InferenceRunner
,
ScalarStats
from
tensorpack.callbacks
import
ModelSaver
,
InferenceRunner
,
ScalarStats
from
tensorpack.dataflow
import
dataset
,
BatchData
,
MapData
from
tensorpack.dataflow
import
dataset
,
BatchData
,
MapData
from
tensorpack.utils
import
logger
from
tensorpack.utils
import
logger
from
tensorpack.contrib.keras
import
setup_keras_trainer
from
tensorpack.contrib.keras
import
KerasModel
IMAGE_SIZE
=
28
IMAGE_SIZE
=
28
...
@@ -35,8 +35,6 @@ def get_data():
...
@@ -35,8 +35,6 @@ def get_data():
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
logger
.
auto_set_dir
()
logger
.
auto_set_dir
()
dataset_train
,
dataset_test
=
get_data
()
M
=
Sequential
()
M
=
Sequential
()
M
.
add
(
KL
.
Conv2D
(
32
,
3
,
activation
=
'relu'
,
input_shape
=
[
IMAGE_SIZE
,
IMAGE_SIZE
,
1
],
padding
=
'same'
))
M
.
add
(
KL
.
Conv2D
(
32
,
3
,
activation
=
'relu'
,
input_shape
=
[
IMAGE_SIZE
,
IMAGE_SIZE
,
1
],
padding
=
'same'
))
M
.
add
(
KL
.
MaxPooling2D
())
M
.
add
(
KL
.
MaxPooling2D
())
...
@@ -50,17 +48,15 @@ if __name__ == '__main__':
...
@@ -50,17 +48,15 @@ if __name__ == '__main__':
M
.
add
(
KL
.
Dense
(
10
,
activation
=
None
,
kernel_regularizer
=
regularizers
.
l2
(
1e-5
)))
M
.
add
(
KL
.
Dense
(
10
,
activation
=
None
,
kernel_regularizer
=
regularizers
.
l2
(
1e-5
)))
M
.
add
(
KL
.
Activation
(
'softmax'
))
M
.
add
(
KL
.
Activation
(
'softmax'
))
trainer
=
SimpleTrainer
()
dataset_train
,
dataset_test
=
get_data
()
setup_keras_trainer
(
M
=
KerasModel
(
M
,
QueueInput
(
dataset_train
))
trainer
,
M
.
compile
(
model
=
M
,
input
=
QueueInput
(
dataset_train
),
optimizer
=
tf
.
train
.
AdamOptimizer
(
1e-3
),
optimizer
=
tf
.
train
.
AdamOptimizer
(
1e-3
),
loss
=
'categorical_crossentropy'
,
loss
=
'categorical_crossentropy'
,
metrics
=
[
'accuracy'
]
metrics
=
[
'accuracy'
]
)
)
trainer
.
train_with_defaults
(
M
.
fit
(
callbacks
=
[
callbacks
=
[
ModelSaver
(),
ModelSaver
(),
InferenceRunner
(
InferenceRunner
(
...
...
tensorpack/contrib/keras.py
View file @
72385a85
...
@@ -11,6 +11,11 @@ from ..tfutils.tower import get_current_tower_context
...
@@ -11,6 +11,11 @@ from ..tfutils.tower import get_current_tower_context
from
..tfutils.collection
import
freeze_collection
from
..tfutils.collection
import
freeze_collection
from
..callbacks
import
Callback
,
InferenceRunner
,
CallbackToHook
from
..callbacks
import
Callback
,
InferenceRunner
,
CallbackToHook
from
..tfutils.summary
import
add_moving_summary
from
..tfutils.summary
import
add_moving_summary
from
..utils.gpu
import
get_nr_gpu
from
..train
import
Trainer
,
SimpleTrainer
,
SyncMultiGPUTrainerParameterServer
__all__
=
[
'KerasPhaseCallback'
,
'setup_keras_trainer'
,
'KerasModel'
]
# Keras needs an extra input if learning_phase is used by the model
# Keras needs an extra input if learning_phase is used by the model
...
@@ -95,3 +100,38 @@ def setup_keras_trainer(
...
@@ -95,3 +100,38 @@ def setup_keras_trainer(
lambda
:
optimizer
)
lambda
:
optimizer
)
if
model
.
uses_learning_phase
:
if
model
.
uses_learning_phase
:
trainer
.
register_callback
(
KerasPhaseCallback
(
True
))
trainer
.
register_callback
(
KerasPhaseCallback
(
True
))
class
KerasModel
(
object
):
def
__init__
(
self
,
model
,
input
,
trainer
=
None
):
"""
Args:
model (keras.model.Model):
"""
self
.
model
=
model
if
trainer
is
None
:
nr_gpu
=
get_nr_gpu
()
if
nr_gpu
<=
1
:
trainer
=
SimpleTrainer
()
else
:
trainer
=
SyncMultiGPUTrainerParameterServer
(
nr_gpu
)
assert
isinstance
(
trainer
,
Trainer
),
trainer
self
.
trainer
=
trainer
self
.
input
=
input
def
compile
(
self
,
optimizer
,
loss
,
metrics
):
setup_keras_trainer
(
self
.
trainer
,
model
=
self
.
model
,
input
=
self
.
input
,
optimizer
=
optimizer
,
loss
=
loss
,
metrics
=
metrics
)
def
fit
(
self
,
**
kwargs
):
callbacks
=
kwargs
.
pop
(
'callbacks'
,
[])
callbacks
.
extend
(
self
.
get_default_callbacks
())
self
.
trainer
.
train_with_defaults
(
**
kwargs
)
def
get_default_callbacks
(
self
):
return
[]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment