Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
0c519bdc
Commit
0c519bdc
authored
Apr 30, 2018
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Let Keras Model accept a real "tower func" with positional args (#739)
parent
3ef33a34
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
17 additions
and
14 deletions
+17
-14
examples/ResNet/imagenet-resnet.py
examples/ResNet/imagenet-resnet.py
+1
-1
examples/keras/imagenet-resnet-keras.py
examples/keras/imagenet-resnet-keras.py
+2
-2
examples/keras/mnist-keras-v2.py
examples/keras/mnist-keras-v2.py
+3
-3
tensorpack/contrib/keras.py
tensorpack/contrib/keras.py
+8
-4
tensorpack/dataflow/common.py
tensorpack/dataflow/common.py
+1
-2
tensorpack/predict/config.py
tensorpack/predict/config.py
+2
-2
No files found.
examples/ResNet/imagenet-resnet.py
View file @
0c519bdc
#!/usr/bin/env python
#!/usr/bin/env python
# -*- coding:
UTF
-8 -*-
# -*- coding:
utf
-8 -*-
# File: imagenet-resnet.py
# File: imagenet-resnet.py
import
argparse
import
argparse
...
...
examples/keras/imagenet-resnet-keras.py
View file @
0c519bdc
...
@@ -85,8 +85,8 @@ def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2))
...
@@ -85,8 +85,8 @@ def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2))
return
x
return
x
def
resnet50
(
i
nputs
):
def
resnet50
(
i
mage
):
input
=
tf
.
layers
.
Input
(
tensor
=
inputs
[
0
]
)
input
=
Input
(
tensor
=
image
)
def
image_preprocess
(
image
):
def
image_preprocess
(
image
):
image
=
ImageNetModel
.
image_preprocess
(
image
)
image
=
ImageNetModel
.
image_preprocess
(
image
)
...
...
examples/keras/mnist-keras-v2.py
View file @
0c519bdc
...
@@ -31,16 +31,16 @@ def get_data():
...
@@ -31,16 +31,16 @@ def get_data():
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
logger
.
auto_set_dir
()
logger
.
auto_set_dir
(
'd'
)
def
model_func
(
i
nputs
):
def
model_func
(
i
mage
):
"""
"""
Keras model has to be created inside this function to be used with tensorpack.
Keras model has to be created inside this function to be used with tensorpack.
"""
"""
M
=
keras
.
models
.
Sequential
()
M
=
keras
.
models
.
Sequential
()
# input_tensor have to be used here for tensorpack trainer to function properly.
# input_tensor have to be used here for tensorpack trainer to function properly.
# Just use inputs[1], inputs[2] if you have multiple inputs.
# Just use inputs[1], inputs[2] if you have multiple inputs.
M
.
add
(
KL
.
InputLayer
(
input_tensor
=
i
nputs
[
0
]
))
M
.
add
(
KL
.
InputLayer
(
input_tensor
=
i
mage
))
M
.
add
(
KL
.
Conv2D
(
32
,
3
,
activation
=
'relu'
,
padding
=
'same'
))
M
.
add
(
KL
.
Conv2D
(
32
,
3
,
activation
=
'relu'
,
padding
=
'same'
))
M
.
add
(
KL
.
MaxPooling2D
())
M
.
add
(
KL
.
MaxPooling2D
())
M
.
add
(
KL
.
Conv2D
(
32
,
3
,
activation
=
'relu'
,
padding
=
'same'
))
M
.
add
(
KL
.
Conv2D
(
32
,
3
,
activation
=
'relu'
,
padding
=
'same'
))
...
...
tensorpack/contrib/keras.py
100644 → 100755
View file @
0c519bdc
...
@@ -47,13 +47,15 @@ class KerasModelCaller(object):
...
@@ -47,13 +47,15 @@ class KerasModelCaller(object):
def
__call__
(
self
,
input_tensors
):
def
__call__
(
self
,
input_tensors
):
"""
"""
Args:
input_tensors ([tf.Tensor])
Returns:
Returns:
output tensors of this tower, evaluated with the input tensors.
output tensors of this tower, evaluated with the input tensors.
"""
"""
reuse
=
tf
.
get_variable_scope
()
.
reuse
reuse
=
tf
.
get_variable_scope
()
.
reuse
if
self
.
cached_model
is
None
:
if
self
.
cached_model
is
None
:
assert
not
reuse
assert
not
reuse
self
.
cached_model
=
self
.
get_model
(
input_tensors
)
self
.
cached_model
=
self
.
get_model
(
*
input_tensors
)
return
self
.
cached_model
.
outputs
return
self
.
cached_model
.
outputs
if
reuse
:
if
reuse
:
...
@@ -63,7 +65,7 @@ class KerasModelCaller(object):
...
@@ -63,7 +65,7 @@ class KerasModelCaller(object):
return
self
.
cached_model
.
call
(
input_tensors
)
return
self
.
cached_model
.
call
(
input_tensors
)
else
:
else
:
# create new Keras model if not reuse
# create new Keras model if not reuse
M
=
self
.
get_model
(
input_tensors
)
M
=
self
.
get_model
(
*
input_tensors
)
return
M
.
outputs
return
M
.
outputs
...
@@ -99,7 +101,8 @@ def setup_keras_trainer(
...
@@ -99,7 +101,8 @@ def setup_keras_trainer(
"""
"""
Args:
Args:
trainer (SingleCostTrainer):
trainer (SingleCostTrainer):
get_model ( -> keras.model.Model):
get_model (input1, input2, ... -> keras.model.Model):
Takes tensors and returns a Keras model. Will be part of the tower function.
input (InputSource):
input (InputSource):
optimizer (tf.tarin.Optimizer):
optimizer (tf.tarin.Optimizer):
loss, metrics: list of strings
loss, metrics: list of strings
...
@@ -175,7 +178,8 @@ class KerasModel(object):
...
@@ -175,7 +178,8 @@ class KerasModel(object):
input
,
trainer
=
None
):
input
,
trainer
=
None
):
"""
"""
Args:
Args:
get_model ( -> keras.model.Model):
get_model (input1, input2, ... -> keras.model.Model):
Takes tensors and returns a Keras model. Will be part of the tower function.
inputs_desc ([InputDesc]):
inputs_desc ([InputDesc]):
targets_desc ([InputDesc]):
targets_desc ([InputDesc]):
input (InputSource | DataFlow):
input (InputSource | DataFlow):
...
...
tensorpack/dataflow/common.py
View file @
0c519bdc
# -*- coding:
UTF
-8 -*-
# -*- coding:
utf
-8 -*-
# File: common.py
# File: common.py
from
__future__
import
division
from
__future__
import
division
import
six
import
six
import
numpy
as
np
import
numpy
as
np
...
...
tensorpack/predict/config.py
View file @
0c519bdc
# -*- coding:
UTF
-8 -*-
# -*- coding:
utf
-8 -*-
# File: config.py
# File: config.py
...
@@ -35,7 +35,7 @@ class PredictConfig(object):
...
@@ -35,7 +35,7 @@ class PredictConfig(object):
Args:
Args:
model (ModelDescBase): to be used to obtain inputs_desc and tower_func.
model (ModelDescBase): to be used to obtain inputs_desc and tower_func.
inputs_desc ([InputDesc]):
inputs_desc ([InputDesc]):
tower_func: a callable which takes input tensors and construct a tower.
tower_func: a callable which takes input tensors
(by positional args)
and construct a tower.
input_names (list): a list of input tensor names. Defaults to match inputs_desc.
input_names (list): a list of input tensor names. Defaults to match inputs_desc.
output_names (list): a list of names of the output tensors to predict, the
output_names (list): a list of names of the output tensors to predict, the
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment