Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
62d54f68
Commit
62d54f68
authored
Sep 21, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
update keras example
parent
ccd67e86
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
32 additions
and
46 deletions
+32
-46
examples/mnist-keras.py
examples/mnist-keras.py
+32
-46
No files found.
examples/mnist-keras.py
View file @
62d54f68
...
...
@@ -10,6 +10,7 @@ import os
import
sys
import
argparse
import
keras
import
keras.layers
as
KL
import
keras.backend
as
KB
from
keras.models
import
Sequential
...
...
@@ -27,14 +28,8 @@ from tensorpack.utils.argtools import memoized
IMAGE_SIZE
=
28
class
Model
(
ModelDesc
):
def
_get_inputs
(
self
):
return
[
InputDesc
(
tf
.
float32
,
(
None
,
IMAGE_SIZE
,
IMAGE_SIZE
),
'input'
),
InputDesc
(
tf
.
int32
,
(
None
,),
'label'
),
]
@
memoized
# this is necessary for sonnet/Keras to work under tensorpack
def
_build_keras_model
(
self
):
@
memoized
# this is necessary for sonnet/Keras to work under tensorpack
def
get_keras_model
():
M
=
Sequential
()
M
.
add
(
KL
.
Conv2D
(
32
,
3
,
activation
=
'relu'
,
input_shape
=
[
IMAGE_SIZE
,
IMAGE_SIZE
,
1
],
padding
=
'same'
))
M
.
add
(
KL
.
MaxPooling2D
())
...
...
@@ -48,31 +43,30 @@ class Model(ModelDesc):
M
.
add
(
KL
.
Dense
(
10
,
activation
=
None
,
kernel_regularizer
=
regularizers
.
l2
(
1e-5
)))
return
M
class
Model
(
ModelDesc
):
def
_get_inputs
(
self
):
return
[
InputDesc
(
tf
.
float32
,
(
None
,
IMAGE_SIZE
,
IMAGE_SIZE
),
'input'
),
InputDesc
(
tf
.
int32
,
(
None
,),
'label'
)]
def
_build_graph
(
self
,
inputs
):
image
,
label
=
inputs
image
=
tf
.
expand_dims
(
image
,
3
)
image
=
image
*
2
-
1
# center the pixels values at zero
image
=
tf
.
expand_dims
(
image
,
3
)
*
2
-
1
with
argscope
(
Conv2D
,
kernel_shape
=
3
,
nl
=
tf
.
nn
.
relu
,
out_channel
=
32
):
M
=
self
.
_build_keras_model
()
M
=
get_keras_model
()
logits
=
M
(
image
)
prob
=
tf
.
nn
.
softmax
(
logits
,
name
=
'prob'
)
# a Bx10 with probabilities
# a vector of length B with loss of each sample
# build cost function by tensorflow
cost
=
tf
.
nn
.
sparse_softmax_cross_entropy_with_logits
(
logits
=
logits
,
labels
=
label
)
cost
=
tf
.
reduce_mean
(
cost
,
name
=
'cross_entropy_loss'
)
# the average cross-entropy loss
# for tensorpack validation
wrong
=
symbolic_functions
.
prediction_incorrect
(
logits
,
label
,
name
=
'incorrect'
)
train_error
=
tf
.
reduce_mean
(
wrong
,
name
=
'train_error'
)
summary
.
add_moving_summary
(
train_error
)
wd_cost
=
tf
.
add_n
(
M
.
losses
,
name
=
'regularize_loss'
)
# this is how Keras manage regularizers
self
.
cost
=
tf
.
add_n
([
wd_cost
,
cost
],
name
=
'total_cost'
)
summary
.
add_moving_summary
(
cost
,
wd_cost
,
self
.
cost
)
# this is the keras naming
summary
.
add_param_summary
((
'conv2d.*/kernel'
,
[
'histogram'
,
'rms'
]))
summary
.
add_moving_summary
(
self
.
cost
)
def
_get_optimizer
(
self
):
lr
=
tf
.
train
.
exponential_decay
(
...
...
@@ -84,7 +78,7 @@ class Model(ModelDesc):
return
tf
.
train
.
AdamOptimizer
(
lr
)
# Keras needs an extra input
# Keras needs an extra input
if learning_phase is needed
class
KerasCallback
(
Callback
):
def
__init__
(
self
,
isTrain
):
self
.
_isTrain
=
isTrain
...
...
@@ -106,31 +100,23 @@ def get_config():
dataset_train
,
dataset_test
=
get_data
()
return
TrainConfig
(
model
=
Model
(
),
model
=
KerasModel
(
get_keras_model
()
),
dataflow
=
dataset_train
,
callbacks
=
[
KerasCallback
(
1
),
# for Keras training
KerasCallback
(
True
),
# for Keras training
ModelSaver
(),
InferenceRunner
(
dataset_test
,
[
ScalarStats
(
'cross_entropy_loss'
),
ClassificationError
(
'incorrect'
)],
extra_hooks
=
[
CallbackToHook
(
KerasCallback
(
0
))]),
# for keras inference
extra_hooks
=
[
CallbackToHook
(
KerasCallback
(
False
))]),
# for keras inference
],
max_epoch
=
100
,
)
if
__name__
==
'__main__'
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--gpu'
,
help
=
'comma separated list of GPU(s) to use.'
)
args
=
parser
.
parse_args
()
if
args
.
gpu
:
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
args
.
gpu
config
=
get_config
()
if
args
.
gpu
:
config
.
nr_tower
=
len
(
args
.
gpu
.
split
(
','
))
if
config
.
nr_tower
>
1
:
SyncMultiGPUTrainer
(
config
)
.
train
()
else
:
QueueInputTrainer
(
config
)
.
train
()
# for multigpu training:
# config.nr_tower = 2
# SyncMultiGPUTrainer(config).train()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment