Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
b8a50d72
Commit
b8a50d72
authored
Mar 19, 2019
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
InputDesc -> tf.TensorSpec everywhere
parent
ba679ab1
Changes
29
Hide whitespace changes
Inline
Side-by-side
Showing
29 changed files
with
237 additions
and
235 deletions
+237
-235
CHANGES.md
CHANGES.md
+3
-0
docs/conf.py
docs/conf.py
+2
-0
docs/tutorial/extend/trainer.md
docs/tutorial/extend/trainer.md
+1
-1
docs/tutorial/training-interface.md
docs/tutorial/training-interface.md
+2
-2
examples/CaffeModels/load-alexnet.py
examples/CaffeModels/load-alexnet.py
+1
-1
examples/CaffeModels/load-cpm.py
examples/CaffeModels/load-cpm.py
+1
-1
examples/CaffeModels/load-vgg16.py
examples/CaffeModels/load-vgg16.py
+1
-1
examples/CaffeModels/load-vgg19.py
examples/CaffeModels/load-vgg19.py
+1
-1
examples/GAN/GAN.py
examples/GAN/GAN.py
+5
-5
examples/ImageNetModels/shufflenet.py
examples/ImageNetModels/shufflenet.py
+4
-7
examples/basics/cifar-convnet.py
examples/basics/cifar-convnet.py
+2
-2
examples/keras/imagenet-resnet-keras.py
examples/keras/imagenet-resnet-keras.py
+3
-3
examples/keras/mnist-keras-v2.py
examples/keras/mnist-keras-v2.py
+3
-3
tensorpack/callbacks/inference_runner.py
tensorpack/callbacks/inference_runner.py
+2
-2
tensorpack/callbacks/param.py
tensorpack/callbacks/param.py
+2
-2
tensorpack/contrib/keras.py
tensorpack/contrib/keras.py
+19
-13
tensorpack/graph_builder/model_desc.py
tensorpack/graph_builder/model_desc.py
+50
-83
tensorpack/input_source/input_source.py
tensorpack/input_source/input_source.py
+35
-35
tensorpack/input_source/input_source_base.py
tensorpack/input_source/input_source_base.py
+12
-11
tensorpack/predict/base.py
tensorpack/predict/base.py
+1
-1
tensorpack/predict/config.py
tensorpack/predict/config.py
+36
-22
tensorpack/predict/feedfree.py
tensorpack/predict/feedfree.py
+2
-2
tensorpack/predict/multigpu.py
tensorpack/predict/multigpu.py
+4
-5
tensorpack/tfutils/export.py
tensorpack/tfutils/export.py
+3
-3
tensorpack/tfutils/tower.py
tensorpack/tfutils/tower.py
+22
-14
tensorpack/train/interface.py
tensorpack/train/interface.py
+1
-1
tensorpack/train/tower.py
tensorpack/train/tower.py
+15
-10
tensorpack/train/trainers.py
tensorpack/train/trainers.py
+2
-2
tensorpack/utils/concurrency.py
tensorpack/utils/concurrency.py
+2
-2
No files found.
CHANGES.md
View file @
b8a50d72
...
@@ -8,6 +8,9 @@ so you don't need to look at here very often.
...
@@ -8,6 +8,9 @@ so you don't need to look at here very often.
Here are a list of things that were changed, starting from an early version.
Here are a list of things that were changed, starting from an early version.
TensorFlow itself also changes API and those are not listed here.
TensorFlow itself also changes API and those are not listed here.
+
[2019/03/20] The concept of
`InputDesc`
was replaced by its equivalent in TF:
`tf.TensorSpec`
. This may be a breaking change if you have customized
code that relies on internals of
`InputDesc`
.
+
[2018/08/27] msgpack is used again for "serialization to disk", because pyarrow
+
[2018/08/27] msgpack is used again for "serialization to disk", because pyarrow
has no compatibility between versions. To use pyarrow instead,
`export TENSORPACK_COMPATIBLE_SERIALIZE=pyarrow`
.
has no compatibility between versions. To use pyarrow instead,
`export TENSORPACK_COMPATIBLE_SERIALIZE=pyarrow`
.
+
[2018/04/05] msgpack is replaced by pyarrow in favor of its speed. If you want old behavior,
+
[2018/04/05] msgpack is replaced by pyarrow in favor of its speed. If you want old behavior,
...
...
docs/conf.py
View file @
b8a50d72
...
@@ -375,6 +375,8 @@ _DEPRECATED_NAMES = set([
...
@@ -375,6 +375,8 @@ _DEPRECATED_NAMES = set([
'PrefetchOnGPUs'
,
'PrefetchOnGPUs'
,
'DistributedTrainerReplicated'
,
'DistributedTrainerReplicated'
,
'DistributedTrainerParameterServer'
,
'DistributedTrainerParameterServer'
,
'InputDesc'
,
'inputs_desc'
,
# renamed items that should not appear in docs
# renamed items that should not appear in docs
'DumpTensor'
,
'DumpTensor'
,
...
...
docs/tutorial/extend/trainer.md
View file @
b8a50d72
...
@@ -48,7 +48,7 @@ Most neural network training tasks are single-cost optimization.
...
@@ -48,7 +48,7 @@ Most neural network training tasks are single-cost optimization.
Tensorpack provides some trainer implementations for such tasks.
Tensorpack provides some trainer implementations for such tasks.
These trainers will take care of step 1 (define the graph), with the following arguments:
These trainers will take care of step 1 (define the graph), with the following arguments:
1.
Some
`
InputDesc`
, the metadata about
the input.
1.
Some
`
tf.TensorSpec`
, the signature of
the input.
2.
An
`InputSource`
, where the input come from. See
[
Input Pipeline
](
input-source.html
)
.
2.
An
`InputSource`
, where the input come from. See
[
Input Pipeline
](
input-source.html
)
.
3.
A function which takes input tensors and returns the cost.
3.
A function which takes input tensors and returns the cost.
4.
A function which returns an optimizer.
4.
A function which returns an optimizer.
...
...
docs/tutorial/training-interface.md
View file @
b8a50d72
...
@@ -11,7 +11,7 @@ This interface is enough for most types of single-cost tasks.
...
@@ -11,7 +11,7 @@ This interface is enough for most types of single-cost tasks.
A lot of examples are written in this interface.
A lot of examples are written in this interface.
[
SingleCost trainers
](
../modules/train.html#tensorpack.train.SingleCostTrainer
)
[
SingleCost trainers
](
../modules/train.html#tensorpack.train.SingleCostTrainer
)
expects 4 arguments to setup the graph:
`InputDesc`
,
`InputSource`
, get_cost function, and an optimizer.
expects 4 arguments to setup the graph:
input signatures
,
`InputSource`
, get_cost function, and an optimizer.
`ModelDesc`
describes a model by packing 3 of them together into one object:
`ModelDesc`
describes a model by packing 3 of them together into one object:
```
python
```
python
...
@@ -62,7 +62,7 @@ The function `launch_train_with_config(config, trainer)`
...
@@ -62,7 +62,7 @@ The function `launch_train_with_config(config, trainer)`
uses the raw trainer interface under the hood, and is almost equivalent to the following two lines of code:
uses the raw trainer interface under the hood, and is almost equivalent to the following two lines of code:
```
python
```
python
trainer
.
setup_graph
(
trainer
.
setup_graph
(
my_model
.
get_input
s_desc
(),
my_model
.
get_input
_signature
(),
my_input_source
,
# or QueueInput(my_dataflow)
my_input_source
,
# or QueueInput(my_dataflow)
my_model
.
build_graph
,
my_model
.
build_graph
,
my_model
.
get_optimizer
)
my_model
.
get_optimizer
)
...
...
examples/CaffeModels/load-alexnet.py
View file @
b8a50d72
...
@@ -42,7 +42,7 @@ def tower_func(image):
...
@@ -42,7 +42,7 @@ def tower_func(image):
def
run_test
(
path
,
input
):
def
run_test
(
path
,
input
):
param_dict
=
dict
(
np
.
load
(
path
))
param_dict
=
dict
(
np
.
load
(
path
))
predictor
=
OfflinePredictor
(
PredictConfig
(
predictor
=
OfflinePredictor
(
PredictConfig
(
input
s_desc
=
[
InputDesc
(
tf
.
float32
,
(
None
,
227
,
227
,
3
)
,
'input'
)],
input
_signature
=
[
tf
.
TensorSpec
((
None
,
227
,
227
,
3
),
tf
.
float32
,
'input'
)],
tower_func
=
tower_func
,
tower_func
=
tower_func
,
session_init
=
DictRestore
(
param_dict
),
session_init
=
DictRestore
(
param_dict
),
input_names
=
[
'input'
],
input_names
=
[
'input'
],
...
...
examples/CaffeModels/load-cpm.py
View file @
b8a50d72
...
@@ -97,7 +97,7 @@ def CPM(image):
...
@@ -97,7 +97,7 @@ def CPM(image):
def
run_test
(
model_path
,
img_file
):
def
run_test
(
model_path
,
img_file
):
param_dict
=
dict
(
np
.
load
(
model_path
))
param_dict
=
dict
(
np
.
load
(
model_path
))
predict_func
=
OfflinePredictor
(
PredictConfig
(
predict_func
=
OfflinePredictor
(
PredictConfig
(
input
s_desc
=
[
InputDesc
(
tf
.
float32
,
(
None
,
368
,
368
,
3
)
,
'input'
)],
input
_signature
=
[
tf
.
TensorSpec
((
None
,
368
,
368
,
3
),
tf
.
float32
,
'input'
)],
tower_func
=
CPM
,
tower_func
=
CPM
,
session_init
=
DictRestore
(
param_dict
),
session_init
=
DictRestore
(
param_dict
),
input_names
=
[
'input'
],
input_names
=
[
'input'
],
...
...
examples/CaffeModels/load-vgg16.py
View file @
b8a50d72
...
@@ -59,7 +59,7 @@ def run_test(path, input):
...
@@ -59,7 +59,7 @@ def run_test(path, input):
param_dict
=
{
k
.
replace
(
'/W'
,
'/kernel'
)
.
replace
(
'/b'
,
'/bias'
):
v
for
k
,
v
in
six
.
iteritems
(
param_dict
)}
param_dict
=
{
k
.
replace
(
'/W'
,
'/kernel'
)
.
replace
(
'/b'
,
'/bias'
):
v
for
k
,
v
in
six
.
iteritems
(
param_dict
)}
predict_func
=
OfflinePredictor
(
PredictConfig
(
predict_func
=
OfflinePredictor
(
PredictConfig
(
input
s_desc
=
[
InputDesc
(
tf
.
float32
,
(
None
,
224
,
224
,
3
)
,
'input'
)],
input
_signature
=
[
tf
.
TensorSpec
((
None
,
224
,
224
,
3
),
tf
.
float32
,
'input'
)],
tower_func
=
tower_func
,
tower_func
=
tower_func
,
session_init
=
DictRestore
(
param_dict
),
session_init
=
DictRestore
(
param_dict
),
input_names
=
[
'input'
],
input_names
=
[
'input'
],
...
...
examples/CaffeModels/load-vgg19.py
View file @
b8a50d72
...
@@ -62,7 +62,7 @@ def run_test(path, input):
...
@@ -62,7 +62,7 @@ def run_test(path, input):
param_dict
=
{
k
.
replace
(
'/W'
,
'/kernel'
)
.
replace
(
'/b'
,
'/bias'
):
v
for
k
,
v
in
six
.
iteritems
(
param_dict
)}
param_dict
=
{
k
.
replace
(
'/W'
,
'/kernel'
)
.
replace
(
'/b'
,
'/bias'
):
v
for
k
,
v
in
six
.
iteritems
(
param_dict
)}
predict_func
=
OfflinePredictor
(
PredictConfig
(
predict_func
=
OfflinePredictor
(
PredictConfig
(
input
s_desc
=
[
InputDesc
(
tf
.
float32
,
(
None
,
224
,
224
,
3
)
,
'input'
)],
input
_signature
=
[
tf
.
TensorSpec
((
None
,
224
,
224
,
3
),
tf
.
float32
,
'input'
)],
tower_func
=
tower_func
,
tower_func
=
tower_func
,
session_init
=
DictRestore
(
param_dict
),
session_init
=
DictRestore
(
param_dict
),
input_names
=
[
'input'
],
input_names
=
[
'input'
],
...
...
examples/GAN/GAN.py
View file @
b8a50d72
...
@@ -88,7 +88,7 @@ class GANTrainer(TowerTrainer):
...
@@ -88,7 +88,7 @@ class GANTrainer(TowerTrainer):
input
=
StagingInput
(
input
)
input
=
StagingInput
(
input
)
# Setup input
# Setup input
cbs
=
input
.
setup
(
model
.
get_input
s_desc
())
cbs
=
input
.
setup
(
model
.
get_input
_signature
())
self
.
register_callback
(
cbs
)
self
.
register_callback
(
cbs
)
if
num_gpu
<=
1
:
if
num_gpu
<=
1
:
...
@@ -105,7 +105,7 @@ class GANTrainer(TowerTrainer):
...
@@ -105,7 +105,7 @@ class GANTrainer(TowerTrainer):
not needed. Just calling model.build_graph directly is OK.
not needed. Just calling model.build_graph directly is OK.
"""
"""
# Build the graph
# Build the graph
self
.
tower_func
=
TowerFuncWrapper
(
model
.
build_graph
,
model
.
get_input
s_desc
())
self
.
tower_func
=
TowerFuncWrapper
(
model
.
build_graph
,
model
.
get_input
_signature
())
with
TowerContext
(
''
,
is_training
=
True
):
with
TowerContext
(
''
,
is_training
=
True
):
self
.
tower_func
(
*
input
.
get_input_tensors
())
self
.
tower_func
(
*
input
.
get_input_tensors
())
opt
=
model
.
get_optimizer
()
opt
=
model
.
get_optimizer
()
...
@@ -127,7 +127,7 @@ class GANTrainer(TowerTrainer):
...
@@ -127,7 +127,7 @@ class GANTrainer(TowerTrainer):
model
.
build_graph
(
*
inputs
)
model
.
build_graph
(
*
inputs
)
return
[
model
.
d_loss
,
model
.
g_loss
]
return
[
model
.
d_loss
,
model
.
g_loss
]
self
.
tower_func
=
TowerFuncWrapper
(
get_cost
,
model
.
get_input
s_desc
())
self
.
tower_func
=
TowerFuncWrapper
(
get_cost
,
model
.
get_input
_signature
())
devices
=
[
LeastLoadedDeviceSetter
(
d
,
raw_devices
)
for
d
in
raw_devices
]
devices
=
[
LeastLoadedDeviceSetter
(
d
,
raw_devices
)
for
d
in
raw_devices
]
cost_list
=
DataParallelBuilder
.
build_on_towers
(
cost_list
=
DataParallelBuilder
.
build_on_towers
(
list
(
range
(
num_gpu
)),
list
(
range
(
num_gpu
)),
...
@@ -163,11 +163,11 @@ class SeparateGANTrainer(TowerTrainer):
...
@@ -163,11 +163,11 @@ class SeparateGANTrainer(TowerTrainer):
assert
min
(
d_period
,
g_period
)
==
1
assert
min
(
d_period
,
g_period
)
==
1
# Setup input
# Setup input
cbs
=
input
.
setup
(
model
.
get_input
s_desc
())
cbs
=
input
.
setup
(
model
.
get_input
_signature
())
self
.
register_callback
(
cbs
)
self
.
register_callback
(
cbs
)
# Build the graph
# Build the graph
self
.
tower_func
=
TowerFuncWrapper
(
model
.
build_graph
,
model
.
get_input
s_desc
())
self
.
tower_func
=
TowerFuncWrapper
(
model
.
build_graph
,
model
.
get_input
_signature
())
with
TowerContext
(
''
,
is_training
=
True
),
\
with
TowerContext
(
''
,
is_training
=
True
),
\
argscope
(
BatchNorm
,
internal_update
=
True
):
argscope
(
BatchNorm
,
internal_update
=
True
):
# should not hook the updates to both train_op, it will hurt training speed.
# should not hook the updates to both train_op, it will hurt training speed.
...
...
examples/ImageNetModels/shufflenet.py
View file @
b8a50d72
...
@@ -254,14 +254,11 @@ if __name__ == '__main__':
...
@@ -254,14 +254,11 @@ if __name__ == '__main__':
eval_on_ILSVRC12
(
model
,
get_model_loader
(
args
.
load
),
ds
)
eval_on_ILSVRC12
(
model
,
get_model_loader
(
args
.
load
),
ds
)
elif
args
.
flops
:
elif
args
.
flops
:
# manually build the graph with batch=1
# manually build the graph with batch=1
input_desc
=
[
InputDesc
(
tf
.
float32
,
[
1
,
224
,
224
,
3
],
'input'
),
InputDesc
(
tf
.
int32
,
[
1
],
'label'
)
]
input
=
PlaceholderInput
()
input
.
setup
(
input_desc
)
with
TowerContext
(
''
,
is_training
=
False
):
with
TowerContext
(
''
,
is_training
=
False
):
model
.
build_graph
(
*
input
.
get_input_tensors
())
model
.
build_graph
(
tf
.
placeholder
(
tf
.
float32
,
[
1
,
224
,
224
,
3
],
'input'
),
tf
.
placeholder
(
tf
.
int32
,
[
1
],
'label'
)
)
model_utils
.
describe_trainable_vars
()
model_utils
.
describe_trainable_vars
()
tf
.
profiler
.
profile
(
tf
.
profiler
.
profile
(
...
...
examples/basics/cifar-convnet.py
View file @
b8a50d72
...
@@ -64,7 +64,7 @@ class Model(ModelDesc):
...
@@ -64,7 +64,7 @@ class Model(ModelDesc):
cost
=
tf
.
nn
.
sparse_softmax_cross_entropy_with_logits
(
logits
=
logits
,
labels
=
label
)
cost
=
tf
.
nn
.
sparse_softmax_cross_entropy_with_logits
(
logits
=
logits
,
labels
=
label
)
cost
=
tf
.
reduce_mean
(
cost
,
name
=
'cross_entropy_loss'
)
cost
=
tf
.
reduce_mean
(
cost
,
name
=
'cross_entropy_loss'
)
correct
=
tf
.
cast
(
tf
.
nn
.
in_top_k
(
logits
,
label
,
1
),
tf
.
float32
,
name
=
'correct'
)
correct
=
tf
.
cast
(
tf
.
nn
.
in_top_k
(
predictions
=
logits
,
targets
=
label
,
k
=
1
),
tf
.
float32
,
name
=
'correct'
)
# monitor training error
# monitor training error
add_moving_summary
(
tf
.
reduce_mean
(
correct
,
name
=
'accuracy'
))
add_moving_summary
(
tf
.
reduce_mean
(
correct
,
name
=
'accuracy'
))
...
@@ -76,7 +76,7 @@ class Model(ModelDesc):
...
@@ -76,7 +76,7 @@ class Model(ModelDesc):
return
tf
.
add_n
([
cost
,
wd_cost
],
name
=
'cost'
)
return
tf
.
add_n
([
cost
,
wd_cost
],
name
=
'cost'
)
def
optimizer
(
self
):
def
optimizer
(
self
):
lr
=
tf
.
get_variable
(
'learning_rate'
,
initializer
=
1e-2
,
trainable
=
False
)
lr
=
tf
.
Variable
(
1e-2
,
name
=
'learning_rate'
,
trainable
=
False
)
tf
.
summary
.
scalar
(
'lr'
,
lr
)
tf
.
summary
.
scalar
(
'lr'
,
lr
)
return
tf
.
train
.
AdamOptimizer
(
lr
,
epsilon
=
1e-3
)
return
tf
.
train
.
AdamOptimizer
(
lr
,
epsilon
=
1e-3
)
...
...
examples/keras/imagenet-resnet-keras.py
View file @
b8a50d72
...
@@ -9,7 +9,7 @@ import os
...
@@ -9,7 +9,7 @@ import os
import
tensorflow
as
tf
import
tensorflow
as
tf
from
tensorflow.python.keras.layers
import
*
from
tensorflow.python.keras.layers
import
*
from
tensorpack
import
InputDesc
,
SyncMultiGPUTrainerReplicated
from
tensorpack
import
SyncMultiGPUTrainerReplicated
from
tensorpack.callbacks
import
*
from
tensorpack.callbacks
import
*
from
tensorpack.contrib.keras
import
KerasModel
from
tensorpack.contrib.keras
import
KerasModel
from
tensorpack.dataflow
import
FakeData
,
MapDataComponent
from
tensorpack.dataflow
import
FakeData
,
MapDataComponent
...
@@ -166,8 +166,8 @@ if __name__ == '__main__':
...
@@ -166,8 +166,8 @@ if __name__ == '__main__':
M
=
KerasModel
(
M
=
KerasModel
(
resnet50
,
resnet50
,
input
s_desc
=
[
InputDesc
(
tf
.
uint8
,
[
None
,
224
,
224
,
3
]
,
'images'
)],
input
_signature
=
[
tf
.
TensorSpec
([
None
,
224
,
224
,
3
],
tf
.
uint8
,
'images'
)],
target
s_desc
=
[
InputDesc
(
tf
.
float32
,
[
None
,
1000
]
,
'labels'
)],
target
_signature
=
[
tf
.
TensorSpec
([
None
,
1000
],
tf
.
float32
,
'labels'
)],
input
=
df_train
,
input
=
df_train
,
trainer
=
SyncMultiGPUTrainerReplicated
(
num_gpu
))
trainer
=
SyncMultiGPUTrainerReplicated
(
num_gpu
))
...
...
examples/keras/mnist-keras-v2.py
View file @
b8a50d72
...
@@ -7,7 +7,7 @@ import numpy as np
...
@@ -7,7 +7,7 @@ import numpy as np
import
tensorflow
as
tf
import
tensorflow
as
tf
from
tensorflow
import
keras
from
tensorflow
import
keras
from
tensorpack
import
InputDesc
,
QueueInput
from
tensorpack
import
QueueInput
from
tensorpack.callbacks
import
ModelSaver
from
tensorpack.callbacks
import
ModelSaver
from
tensorpack.contrib.keras
import
KerasModel
from
tensorpack.contrib.keras
import
KerasModel
from
tensorpack.dataflow
import
BatchData
,
MapData
,
dataset
from
tensorpack.dataflow
import
BatchData
,
MapData
,
dataset
...
@@ -57,8 +57,8 @@ if __name__ == '__main__':
...
@@ -57,8 +57,8 @@ if __name__ == '__main__':
M
=
KerasModel
(
M
=
KerasModel
(
model_func
,
model_func
,
input
s_desc
=
[
InputDesc
(
tf
.
float32
,
[
None
,
IMAGE_SIZE
,
IMAGE_SIZE
,
1
]
,
'images'
)],
input
_signature
=
[
tf
.
TensorSpec
([
None
,
IMAGE_SIZE
,
IMAGE_SIZE
,
1
],
tf
.
float32
,
'images'
)],
target
s_desc
=
[
InputDesc
(
tf
.
float32
,
[
None
,
10
]
,
'labels'
)],
target
_signature
=
[
tf
.
TensorSpec
([
None
,
10
],
tf
.
float32
,
'labels'
)],
input
=
QueueInput
(
dataset_train
))
input
=
QueueInput
(
dataset_train
))
M
.
compile
(
M
.
compile
(
optimizer
=
tf
.
train
.
AdamOptimizer
(
1e-3
),
optimizer
=
tf
.
train
.
AdamOptimizer
(
1e-3
),
...
...
tensorpack/callbacks/inference_runner.py
View file @
b8a50d72
...
@@ -141,7 +141,7 @@ class InferenceRunner(InferenceRunnerBase):
...
@@ -141,7 +141,7 @@ class InferenceRunner(InferenceRunnerBase):
if
self
.
_tower_func
is
None
:
if
self
.
_tower_func
is
None
:
assert
self
.
trainer
.
tower_func
is
not
None
,
"You must set tower_func of the trainer to use InferenceRunner!"
assert
self
.
trainer
.
tower_func
is
not
None
,
"You must set tower_func of the trainer to use InferenceRunner!"
self
.
_tower_func
=
self
.
trainer
.
tower_func
self
.
_tower_func
=
self
.
trainer
.
tower_func
input_callbacks
=
self
.
_input_source
.
setup
(
self
.
_tower_func
.
input
s_desc
)
input_callbacks
=
self
.
_input_source
.
setup
(
self
.
_tower_func
.
input
_signature
)
vs_name
=
self
.
trainer
.
_vs_name_for_predictor
(
self
.
_device_id
)
vs_name
=
self
.
trainer
.
_vs_name_for_predictor
(
self
.
_device_id
)
logger
.
info
(
"[InferenceRunner] Building tower '{}' on device {} {}..."
.
format
(
logger
.
info
(
"[InferenceRunner] Building tower '{}' on device {} {}..."
.
format
(
...
@@ -223,7 +223,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
...
@@ -223,7 +223,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
assert
self
.
trainer
.
tower_func
is
not
None
,
"You must set tower_func of the trainer to use InferenceRunner!"
assert
self
.
trainer
.
tower_func
is
not
None
,
"You must set tower_func of the trainer to use InferenceRunner!"
self
.
_tower_func
=
self
.
trainer
.
tower_func
self
.
_tower_func
=
self
.
trainer
.
tower_func
input_callbacks
=
self
.
_input_source
.
setup
(
self
.
_tower_func
.
input
s_desc
)
input_callbacks
=
self
.
_input_source
.
setup
(
self
.
_tower_func
.
input
_signature
)
with
tf
.
variable_scope
(
tf
.
get_variable_scope
(),
reuse
=
True
):
with
tf
.
variable_scope
(
tf
.
get_variable_scope
(),
reuse
=
True
):
for
idx
,
dev
in
enumerate
(
self
.
_devices
):
for
idx
,
dev
in
enumerate
(
self
.
_devices
):
vs_name
=
self
.
trainer
.
_vs_name_for_predictor
(
idx
)
vs_name
=
self
.
trainer
.
_vs_name_for_predictor
(
idx
)
...
...
tensorpack/callbacks/param.py
View file @
b8a50d72
...
@@ -8,8 +8,8 @@ import numpy as np
...
@@ -8,8 +8,8 @@ import numpy as np
from
abc
import
ABCMeta
,
abstractmethod
from
abc
import
ABCMeta
,
abstractmethod
from
collections
import
deque
from
collections
import
deque
import
six
import
six
import
tensorflow
as
tf
from
..compat
import
tfv1
from
..tfutils.common
import
get_op_tensor_name
from
..tfutils.common
import
get_op_tensor_name
from
..utils
import
logger
from
..utils
import
logger
from
.base
import
Callback
from
.base
import
Callback
...
@@ -67,7 +67,7 @@ class GraphVarParam(HyperParam):
...
@@ -67,7 +67,7 @@ class GraphVarParam(HyperParam):
def
setup_graph
(
self
):
def
setup_graph
(
self
):
""" Will setup the assign operator for that variable. """
""" Will setup the assign operator for that variable. """
all_vars
=
tf
.
global_variables
()
+
tf
.
local_variables
()
all_vars
=
tf
v1
.
global_variables
()
+
tfv1
.
local_variables
()
for
v
in
all_vars
:
for
v
in
all_vars
:
if
v
.
name
==
self
.
var_name
:
if
v
.
name
==
self
.
var_name
:
self
.
var
=
v
self
.
var
=
v
...
...
tensorpack/contrib/keras.py
View file @
b8a50d72
...
@@ -141,7 +141,7 @@ class KerasPhaseCallback(Callback):
...
@@ -141,7 +141,7 @@ class KerasPhaseCallback(Callback):
def
setup_keras_trainer
(
def
setup_keras_trainer
(
trainer
,
get_model
,
trainer
,
get_model
,
input
s_desc
,
targets_desc
,
input
_signature
,
target_signature
,
input
,
optimizer
,
loss
,
metrics
):
input
,
optimizer
,
loss
,
metrics
):
"""
"""
Args:
Args:
...
@@ -159,7 +159,7 @@ def setup_keras_trainer(
...
@@ -159,7 +159,7 @@ def setup_keras_trainer(
assert
isinstance
(
metrics
,
list
),
metrics
assert
isinstance
(
metrics
,
list
),
metrics
model_caller
=
KerasModelCaller
(
get_model
)
model_caller
=
KerasModelCaller
(
get_model
)
nr_inputs
=
len
(
input
s_desc
)
nr_inputs
=
len
(
input
_signature
)
def
get_cost
(
*
inputs
):
def
get_cost
(
*
inputs
):
ctx
=
get_current_tower_context
()
ctx
=
get_current_tower_context
()
...
@@ -211,7 +211,7 @@ def setup_keras_trainer(
...
@@ -211,7 +211,7 @@ def setup_keras_trainer(
return
total_loss
return
total_loss
trainer
.
setup_graph
(
trainer
.
setup_graph
(
input
s_desc
+
targets_desc
,
input
_signature
+
target_signature
,
input
,
input
,
get_cost
,
get_cost
,
lambda
:
optimizer
)
lambda
:
optimizer
)
...
@@ -221,23 +221,27 @@ def setup_keras_trainer(
...
@@ -221,23 +221,27 @@ def setup_keras_trainer(
class
KerasModel
(
object
):
class
KerasModel
(
object
):
def
__init__
(
self
,
get_model
,
input
s_desc
,
targets_desc
,
def
__init__
(
self
,
get_model
,
input
_signature
=
None
,
target_signature
=
None
,
input
,
trainer
=
None
):
input
=
None
,
trainer
=
None
,
inputs_desc
=
None
,
targets_desc
=
None
):
"""
"""
Args:
Args:
get_model (input1, input2, ... -> keras.Model):
get_model (input1, input2, ... -> keras.Model):
A function which takes tensors, builds and returns a Keras model.
A function which takes tensors, builds and returns a Keras model.
It will be part of the tower function.
It will be part of the tower function.
input
s_desc ([InputDesc]):
input
_signature ([tf.TensorSpec]): required. The signature for inputs.
target
s_desc ([InputDesc]):
target
_signature ([tf.TensorSpec]): required. The signature for the targets tensors.
input (InputSource | DataFlow):
input (InputSource | DataFlow):
the InputSource or DataFlow where the input data comes from.
trainer (Trainer): the default will check the number of available
trainer (Trainer): the default will check the number of available
GPUs and use them all.
GPUs and use them all.
inputs_desc, targets_desc: deprecated names for `input_signature` and `target_signature`
"""
"""
if
inputs_desc
is
not
None
:
input_signature
=
inputs_desc
if
targets_desc
is
not
None
:
target_signature
=
targets_desc
self
.
get_model
=
get_model
self
.
get_model
=
get_model
assert
callable
(
get_model
),
get_model
assert
callable
(
get_model
),
get_model
self
.
input
s_desc
=
inputs_desc
self
.
input
_signature
=
input_signature
self
.
target
s_desc
=
targets_desc
self
.
target
_signature
=
target_signature
if
trainer
is
None
:
if
trainer
is
None
:
nr_gpu
=
get_nr_gpu
()
nr_gpu
=
get_nr_gpu
()
if
nr_gpu
<=
1
:
if
nr_gpu
<=
1
:
...
@@ -248,6 +252,7 @@ class KerasModel(object):
...
@@ -248,6 +252,7 @@ class KerasModel(object):
assert
isinstance
(
trainer
,
Trainer
),
trainer
assert
isinstance
(
trainer
,
Trainer
),
trainer
assert
not
isinstance
(
trainer
,
DistributedTrainerBase
)
assert
not
isinstance
(
trainer
,
DistributedTrainerBase
)
assert
input
is
not
None
,
"Argument 'input' is required!"
self
.
input
=
apply_default_prefetch
(
input
,
trainer
)
self
.
input
=
apply_default_prefetch
(
input
,
trainer
)
self
.
trainer
=
trainer
self
.
trainer
=
trainer
...
@@ -267,7 +272,8 @@ class KerasModel(object):
...
@@ -267,7 +272,8 @@ class KerasModel(object):
self
.
_stats_to_inference
=
loss
+
metrics
+
[
TOTAL_LOSS_NAME
]
self
.
_stats_to_inference
=
loss
+
metrics
+
[
TOTAL_LOSS_NAME
]
setup_keras_trainer
(
setup_keras_trainer
(
self
.
trainer
,
get_model
=
self
.
get_model
,
self
.
trainer
,
get_model
=
self
.
get_model
,
inputs_desc
=
self
.
inputs_desc
,
targets_desc
=
self
.
targets_desc
,
input_signature
=
self
.
input_signature
,
target_signature
=
self
.
target_signature
,
input
=
self
.
input
,
input
=
self
.
input
,
optimizer
=
optimizer
,
optimizer
=
optimizer
,
loss
=
loss
,
loss
=
loss
,
...
...
tensorpack/graph_builder/model_desc.py
View file @
b8a50d72
...
@@ -18,12 +18,41 @@ TensorSpec = backport_tensor_spec()
...
@@ -18,12 +18,41 @@ TensorSpec = backport_tensor_spec()
__all__
=
[
'InputDesc'
,
'ModelDesc'
,
'ModelDescBase'
]
__all__
=
[
'InputDesc'
,
'ModelDesc'
,
'ModelDescBase'
]
def
build_or_reuse_placeholder
(
tensor_spec
):
"""
Build a tf.placeholder from the metadata in the given tensor spec, or return an existing one.
Args:
tensor_spec (tf.TensorSpec):
Returns:
tf.Tensor:
"""
g
=
tfv1
.
get_default_graph
()
name
=
tensor_spec
.
name
try
:
tensor
=
g
.
get_tensor_by_name
(
name
+
':0'
)
assert
"Placeholder"
in
tensor
.
op
.
type
,
"Tensor {} exists but is not a placeholder!"
.
format
(
name
)
assert
tensor_spec
.
is_compatible_with
(
tensor
),
\
"Tensor {} exists but is not compatible with the signature!"
.
format
(
tensor
)
return
tensor
except
KeyError
:
with
tfv1
.
name_scope
(
None
):
# clear any name scope it might get called in
ret
=
tfv1
.
placeholder
(
tensor_spec
.
dtype
,
shape
=
tensor_spec
.
shape
,
name
=
tensor_spec
.
name
)
return
ret
class
InputDesc
(
class
InputDesc
(
namedtuple
(
'InputDescTuple'
,
[
'type'
,
'shape'
,
'name'
])):
namedtuple
(
'InputDescTuple'
,
[
'type'
,
'shape'
,
'name'
])):
"""
"""
Metadata about an input entry point to the graph.
An equivalent of `tf.TensorSpec`.
This metadata can be later used to build placeholders or other types of
input source.
History: this concept is used to represent metadata about the inputs,
which can be later used to build placeholders or other types of input source.
It is introduced much much earlier than the equivalent concept `tf.TensorSpec`
was introduced in TensorFlow.
Therefore, we now switched to use `tf.TensorSpec`, but keep this here for compatibility reasons.
"""
"""
def
__new__
(
cls
,
type
,
shape
,
name
):
def
__new__
(
cls
,
type
,
shape
,
name
):
...
@@ -33,64 +62,9 @@ class InputDesc(
...
@@ -33,64 +62,9 @@ class InputDesc(
shape (tuple):
shape (tuple):
name (str):
name (str):
"""
"""
shape
=
tuple
(
shape
)
# has to be tuple for "self" to be hashable
# TODO mark deprecated
assert
isinstance
(
type
,
tf
.
DType
),
type
assert
isinstance
(
type
,
tf
.
DType
),
type
if
any
(
k
in
name
for
k
in
[
':'
,
'/'
,
' '
]):
return
tf
.
TensorSpec
(
shape
=
shape
,
dtype
=
type
,
name
=
name
)
raise
ValueError
(
"Invalid InputDesc name: '{}'"
.
format
(
name
))
self
=
super
(
InputDesc
,
cls
)
.
__new__
(
cls
,
type
,
shape
,
name
)
self
.
_cached_placeholder
=
{}
return
self
def
_build_placeholder
(
self
):
"""
Build a tf.placeholder from the metadata.
Returns:
tf.Tensor:
"""
with
tfv1
.
name_scope
(
None
):
# clear any name scope it might get called in
ret
=
tfv1
.
placeholder
(
self
.
type
,
shape
=
self
.
shape
,
name
=
self
.
name
)
self
.
_register_cached_placeholder
(
ret
)
return
ret
# cannot memoize here, because InputDesc is hashed by its fields.
def
build_placeholder_reuse
(
self
):
"""
Build a tf.placeholder from the metadata, or return an old one.
Returns:
tf.Tensor:
"""
g
=
tfv1
.
get_default_graph
()
if
g
in
self
.
_cached_placeholder
:
return
self
.
_cached_placeholder
[
g
]
else
:
return
self
.
_build_placeholder
()
def
_register_cached_placeholder
(
self
,
placeholder
):
graph
=
placeholder
.
graph
assert
graph
not
in
self
.
_cached_placeholder
,
\
"Placeholder for this InputDesc had been created before! This is a bug."
self
.
_cached_placeholder
[
graph
]
=
placeholder
@
staticmethod
def
_from_placeholder
(
placeholder
):
name
=
placeholder
.
op
.
name
if
name
.
endswith
(
'_1'
)
or
name
.
endswith
(
'_2'
):
logger
.
error
(
"Creating InputDesc from a placeholder named {}."
.
format
(
name
))
logger
.
error
(
"You might have mistakenly created this placeholder multiple times!"
)
ret
=
InputDesc
(
placeholder
.
dtype
,
tuple
(
placeholder
.
shape
.
as_list
()),
name
)
ret
.
_register_cached_placeholder
(
placeholder
)
return
ret
@
staticmethod
def
_from_tensor_spec
(
spec
):
assert
spec
.
name
is
not
None
,
"TensorSpec should have a name!"
return
InputDesc
(
spec
.
dtype
,
tuple
(
spec
.
shape
.
as_list
()),
spec
.
name
)
class
ModelDescBase
(
object
):
class
ModelDescBase
(
object
):
...
@@ -100,29 +74,22 @@ class ModelDescBase(object):
...
@@ -100,29 +74,22 @@ class ModelDescBase(object):
@
memoized_method
@
memoized_method
def
get_inputs_desc
(
self
):
def
get_inputs_desc
(
self
):
# TODO mark deprecated
return
self
.
get_input_signature
()
@
memoized_method
def
get_input_signature
(
self
):
"""
"""
Returns:
Returns:
A list of :class:`
InputDes
c`, which describes the inputs of this model.
A list of :class:`
tf.TensorSpe
c`, which describes the inputs of this model.
The result is cached for each instance of :class:`ModelDescBase`.
The result is cached for each instance of :class:`ModelDescBase`.
"""
"""
try
:
with
tf
.
Graph
()
.
as_default
()
as
G
:
# create these placeholder in a temporary graph
ret
=
self
.
_get_inputs
()
inputs
=
self
.
inputs
()
log_deprecated
(
if
isinstance
(
inputs
[
0
],
tf
.
Tensor
):
"ModelDescBase._get_inputs() interface"
,
for
p
in
inputs
:
"Use inputs() instead!"
,
assert
p
.
graph
==
G
,
"Placeholders returned by inputs() should be created inside inputs()!"
"2019-03-30"
)
return
[
TensorSpec
(
shape
=
p
.
shape
,
dtype
=
p
.
dtype
,
name
=
p
.
name
)
for
p
in
inputs
]
return
ret
except
NotImplementedError
:
with
tf
.
Graph
()
.
as_default
()
as
G
:
# create these placeholder in a temporary graph
inputs
=
self
.
inputs
()
if
isinstance
(
inputs
[
0
],
tf
.
Tensor
):
for
p
in
inputs
:
assert
p
.
graph
==
G
,
"Placeholders returned by inputs() should be created inside inputs()!"
return
[
InputDesc
.
_from_placeholder
(
p
)
for
p
in
inputs
]
else
:
for
p
in
inputs
:
assert
isinstance
(
p
,
TensorSpec
),
type
(
p
)
return
[
InputDesc
.
_from_tensor_spec
(
p
)
for
p
in
inputs
]
@
property
@
property
def
input_names
(
self
):
def
input_names
(
self
):
...
@@ -130,7 +97,7 @@ class ModelDescBase(object):
...
@@ -130,7 +97,7 @@ class ModelDescBase(object):
Returns:
Returns:
[str]: the names of all the inputs.
[str]: the names of all the inputs.
"""
"""
return
[
k
.
name
for
k
in
self
.
get_input
s_desc
()]
return
[
k
.
name
for
k
in
self
.
get_input
_signature
()]
def
_get_inputs
(
self
):
def
_get_inputs
(
self
):
raise
NotImplementedError
()
raise
NotImplementedError
()
...
@@ -147,7 +114,7 @@ class ModelDescBase(object):
...
@@ -147,7 +114,7 @@ class ModelDescBase(object):
Also, you should never call this method by yourself.
Also, you should never call this method by yourself.
Returns:
Returns:
list[tf.
placeholder] or list[tf.TensorSpec], to be converted to :class:`InputDes
c`.
list[tf.
TensorSpec or tf.placeholder]. To be converted to :class:`tf.TensorSpe
c`.
"""
"""
raise
NotImplementedError
()
raise
NotImplementedError
()
...
@@ -166,9 +133,9 @@ class ModelDescBase(object):
...
@@ -166,9 +133,9 @@ class ModelDescBase(object):
may require it to return necessary information to build the trainer.
may require it to return necessary information to build the trainer.
For example, `SingleCostTrainer` expect this method to return the cost tensor.
For example, `SingleCostTrainer` expect this method to return the cost tensor.
"""
"""
assert
len
(
args
)
==
len
(
self
.
get_input
s_desc
()),
\
assert
len
(
args
)
==
len
(
self
.
get_input
_signature
()),
\
"Number of inputs passed to the graph != number of inputs defined "
\
"Number of inputs passed to the graph != number of inputs defined "
\
"in ModelDesc! ({} != {})"
.
format
(
len
(
args
),
len
(
self
.
get_input
s_desc
()))
"in ModelDesc! ({} != {})"
.
format
(
len
(
args
),
len
(
self
.
get_input
_signature
()))
log_deprecated
(
log_deprecated
(
"ModelDescBase._build_graph() interface"
,
"ModelDescBase._build_graph() interface"
,
"Use build_graph() instead!"
,
"Use build_graph() instead!"
,
...
...
tensorpack/input_source/input_source.py
View file @
b8a50d72
...
@@ -19,6 +19,7 @@ from ..tfutils.tower import get_current_tower_context
...
@@ -19,6 +19,7 @@ from ..tfutils.tower import get_current_tower_context
from
..utils
import
logger
from
..utils
import
logger
from
..utils.concurrency
import
ShareSessionThread
from
..utils.concurrency
import
ShareSessionThread
from
.input_source_base
import
InputSource
from
.input_source_base
import
InputSource
from
..graph_builder.model_desc
import
build_or_reuse_placeholder
try
:
try
:
from
tensorflow.python.ops.data_flow_ops
import
StagingArea
from
tensorflow.python.ops.data_flow_ops
import
StagingArea
...
@@ -59,7 +60,7 @@ class PlaceholderInput(InputSource):
...
@@ -59,7 +60,7 @@ class PlaceholderInput(InputSource):
pass
pass
def
_setup
(
self
,
inputs
):
def
_setup
(
self
,
inputs
):
self
.
_all_placehdrs
=
[
v
.
build_placeholder_reuse
(
)
for
v
in
inputs
]
self
.
_all_placehdrs
=
[
build_or_reuse_placeholder
(
v
)
for
v
in
inputs
]
def
_get_input_tensors
(
self
):
def
_get_input_tensors
(
self
):
return
self
.
_all_placehdrs
return
self
.
_all_placehdrs
...
@@ -110,7 +111,7 @@ class FeedInput(InputSource):
...
@@ -110,7 +111,7 @@ class FeedInput(InputSource):
def
_setup
(
self
,
inputs
):
def
_setup
(
self
,
inputs
):
# placeholders as input are always safe to reuse.
# placeholders as input are always safe to reuse.
self
.
_all_placehdrs
=
[
v
.
build_placeholder_reuse
(
)
for
v
in
inputs
]
self
.
_all_placehdrs
=
[
build_or_reuse_placeholder
(
v
)
for
v
in
inputs
]
self
.
_cb
=
self
.
_FeedCallback
(
self
.
_iter_ds
,
self
.
_all_placehdrs
)
self
.
_cb
=
self
.
_FeedCallback
(
self
.
_iter_ds
,
self
.
_all_placehdrs
)
def
_get_input_tensors
(
self
):
def
_get_input_tensors
(
self
):
...
@@ -196,7 +197,7 @@ class QueueInput(FeedfreeInput):
...
@@ -196,7 +197,7 @@ class QueueInput(FeedfreeInput):
Args:
Args:
ds(DataFlow): the input DataFlow.
ds(DataFlow): the input DataFlow.
queue (tf.QueueBase): A :class:`tf.QueueBase` whose type
queue (tf.QueueBase): A :class:`tf.QueueBase` whose type
should match the corresponding
InputDesc
of the model.
should match the corresponding
input signature
of the model.
Defaults to a FIFO queue of size 50.
Defaults to a FIFO queue of size 50.
"""
"""
if
not
isinstance
(
ds
,
DataFlow
):
if
not
isinstance
(
ds
,
DataFlow
):
...
@@ -210,12 +211,12 @@ class QueueInput(FeedfreeInput):
...
@@ -210,12 +211,12 @@ class QueueInput(FeedfreeInput):
return
len
(
self
.
ds
)
return
len
(
self
.
ds
)
def
_setup
(
self
,
inputs
):
def
_setup
(
self
,
inputs
):
self
.
_input_placehdrs
=
[
v
.
build_placeholder_reuse
(
)
for
v
in
inputs
]
self
.
_input_placehdrs
=
[
build_or_reuse_placeholder
(
v
)
for
v
in
inputs
]
assert
len
(
self
.
_input_placehdrs
)
>
0
,
\
assert
len
(
self
.
_input_placehdrs
)
>
0
,
\
"QueueInput has to be used with some inputs!"
"QueueInput has to be used with some inputs!"
with
self
.
cached_name_scope
():
with
self
.
cached_name_scope
():
if
self
.
queue
is
None
:
if
self
.
queue
is
None
:
self
.
queue
=
tf
.
FIFOQueue
(
self
.
queue
=
tf
v1
.
FIFOQueue
(
50
,
[
x
.
dtype
for
x
in
self
.
_input_placehdrs
],
50
,
[
x
.
dtype
for
x
in
self
.
_input_placehdrs
],
name
=
'input_queue'
)
name
=
'input_queue'
)
logger
.
info
(
"Setting up the queue '{}' for CPU prefetching ..."
.
format
(
self
.
queue
.
name
))
logger
.
info
(
"Setting up the queue '{}' for CPU prefetching ..."
.
format
(
self
.
queue
.
name
))
...
@@ -287,7 +288,7 @@ class BatchQueueInput(QueueInput):
...
@@ -287,7 +288,7 @@ class BatchQueueInput(QueueInput):
ds(DataFlow): the input DataFlow.
ds(DataFlow): the input DataFlow.
batch_size(int): the batch size.
batch_size(int): the batch size.
queue (tf.QueueBase): A :class:`tf.QueueBase` whose type
queue (tf.QueueBase): A :class:`tf.QueueBase` whose type
should match the corresponding
InputDesc
of the model.
should match the corresponding
input signature
of the model.
Defaults to a FIFO queue of size 3000.
Defaults to a FIFO queue of size 3000.
"""
"""
super
(
BatchQueueInput
,
self
)
.
__init__
(
ds
,
queue
)
super
(
BatchQueueInput
,
self
)
.
__init__
(
ds
,
queue
)
...
@@ -298,9 +299,9 @@ class BatchQueueInput(QueueInput):
...
@@ -298,9 +299,9 @@ class BatchQueueInput(QueueInput):
def
_setup
(
self
,
inputs
):
def
_setup
(
self
,
inputs
):
logger
.
info
(
"Setting up the queue for CPU prefetching ..."
)
logger
.
info
(
"Setting up the queue for CPU prefetching ..."
)
self
.
input_placehdrs
=
[
v
.
build_placeholder_reuse
(
)
for
v
in
inputs
]
self
.
input_placehdrs
=
[
build_or_reuse_placeholder
(
v
)
for
v
in
inputs
]
assert
len
(
self
.
input_placehdrs
)
>
0
,
\
assert
len
(
self
.
input_placehdrs
)
>
0
,
\
"BatchQueueInput has to be used with some
InputDesc
!"
"BatchQueueInput has to be used with some
input signature
!"
# prepare placeholders without the first dimension
# prepare placeholders without the first dimension
placehdrs_nobatch
=
[]
placehdrs_nobatch
=
[]
...
@@ -364,8 +365,8 @@ class TensorInput(FeedfreeInput):
...
@@ -364,8 +365,8 @@ class TensorInput(FeedfreeInput):
assert
size
>
0
assert
size
>
0
self
.
_fixed_size
=
size
self
.
_fixed_size
=
size
def
_setup
(
self
,
input
s_desc
):
def
_setup
(
self
,
input
_signature
):
self
.
_
desc
=
inputs_desc
self
.
_
spec
=
input_signature
def
_size
(
self
):
def
_size
(
self
):
if
self
.
_fixed_size
is
None
:
if
self
.
_fixed_size
is
None
:
...
@@ -376,8 +377,8 @@ class TensorInput(FeedfreeInput):
...
@@ -376,8 +377,8 @@ class TensorInput(FeedfreeInput):
with
self
.
cached_name_scope
():
with
self
.
cached_name_scope
():
ret
=
self
.
get_tensor_fn
()
ret
=
self
.
get_tensor_fn
()
assert
isinstance
(
ret
,
(
list
,
tuple
)),
"get_tensor_fn needs to return a list!"
assert
isinstance
(
ret
,
(
list
,
tuple
)),
"get_tensor_fn needs to return a list!"
assert
len
(
ret
)
==
len
(
self
.
_
des
c
),
\
assert
len
(
ret
)
==
len
(
self
.
_
spe
c
),
\
"get_tensor_fn returns {} tensors but there are {} inputs"
.
format
(
len
(
ret
),
len
(
self
.
_
des
c
))
"get_tensor_fn returns {} tensors but there are {} inputs"
.
format
(
len
(
ret
),
len
(
self
.
_
spe
c
))
return
ret
return
ret
...
@@ -399,7 +400,7 @@ class DummyConstantInput(TensorInput):
...
@@ -399,7 +400,7 @@ class DummyConstantInput(TensorInput):
assert
len
(
self
.
shapes
)
==
len
(
self
.
_desc
)
assert
len
(
self
.
shapes
)
==
len
(
self
.
_desc
)
for
idx
,
p
in
enumerate
(
self
.
_desc
):
for
idx
,
p
in
enumerate
(
self
.
_desc
):
tlist
.
append
(
tf
.
constant
(
tlist
.
append
(
tf
.
constant
(
0
,
dtype
=
p
.
type
,
0
,
dtype
=
p
.
d
type
,
name
=
'dummy-{}-{}'
.
format
(
p
.
name
,
ctx
.
index
),
name
=
'dummy-{}-{}'
.
format
(
p
.
name
,
ctx
.
index
),
shape
=
self
.
shapes
[
idx
]))
shape
=
self
.
shapes
[
idx
]))
return
tlist
return
tlist
...
@@ -429,15 +430,14 @@ class ZMQInput(TensorInput):
...
@@ -429,15 +430,14 @@ class ZMQInput(TensorInput):
return
ret
return
ret
super
(
ZMQInput
,
self
)
.
__init__
(
fn
)
super
(
ZMQInput
,
self
)
.
__init__
(
fn
)
def
_setup
(
self
,
inputs_desc
):
def
_setup
(
self
,
input_signature
):
assert
len
(
inputs_desc
)
>
0
,
\
assert
len
(
input_signature
)
>
0
,
\
"ZMQInput has to be used with InputDesc!"
"ZMQInput has to be used with input signature!"
self
.
_desc
=
inputs_desc
import
zmq_ops
import
zmq_ops
self
.
_zmq_pull_socket
=
zmq_ops
.
ZMQPullSocket
(
self
.
_zmq_pull_socket
=
zmq_ops
.
ZMQPullSocket
(
self
.
_end_point
,
self
.
_end_point
,
[
x
.
type
for
x
in
inputs_desc
],
[
x
.
dtype
for
x
in
input_signature
],
hwm
=
self
.
_hwm
,
hwm
=
self
.
_hwm
,
bind
=
self
.
_bind
)
bind
=
self
.
_bind
)
...
@@ -458,23 +458,23 @@ class TFDatasetInput(FeedfreeInput):
...
@@ -458,23 +458,23 @@ class TFDatasetInput(FeedfreeInput):
raise
ValueError
(
"TFDatasetInput takes a tf.data.Dataset! Got {}"
.
format
(
dataset
))
raise
ValueError
(
"TFDatasetInput takes a tf.data.Dataset! Got {}"
.
format
(
dataset
))
self
.
_dataset
=
dataset
self
.
_dataset
=
dataset
def
_setup
(
self
,
input
s_desc
):
def
_setup
(
self
,
input
_signature
):
self
.
_
desc
=
inputs_desc
self
.
_
spec
=
input_signature
types
=
self
.
_dataset
.
output_types
types
=
self
.
_dataset
.
output_types
desc_types
=
tuple
([
k
.
type
for
k
in
inputs_desc
])
spec_types
=
tuple
([
k
.
dtype
for
k
in
input_signature
])
assert
len
(
types
)
==
len
(
des
c_types
),
\
assert
len
(
types
)
==
len
(
spe
c_types
),
\
"Dataset and
InputDesc has
different length! {} != {}"
.
format
(
"Dataset and
input signature have
different length! {} != {}"
.
format
(
len
(
types
),
len
(
des
c_types
))
len
(
types
),
len
(
spe
c_types
))
assert
types
==
des
c_types
,
\
assert
types
==
spe
c_types
,
\
"
Types of dataset and InputDesc
don't match! {} != {}"
.
format
(
"
Data types of dataset and input signature
don't match! {} != {}"
.
format
(
str
(
types
),
str
(
des
c_types
))
str
(
types
),
str
(
spe
c_types
))
shapes
=
self
.
_dataset
.
output_shapes
shapes
=
self
.
_dataset
.
output_shapes
desc_shapes
=
[
k
.
shape
for
k
in
inputs_desc
]
spec_shapes
=
[
k
.
shape
for
k
in
input_signature
]
for
idx
,
(
s1
,
s2
)
in
enumerate
(
zip
(
shapes
,
des
c_shapes
)):
for
idx
,
(
s1
,
s2
)
in
enumerate
(
zip
(
shapes
,
spe
c_shapes
)):
s2
=
tf
.
TensorShape
(
s2
)
s2
=
tf
.
TensorShape
(
s2
)
assert
s2
.
is_compatible_with
(
s1
),
\
assert
s2
.
is_compatible_with
(
s1
),
\
"Input
Desc
'{}' has incompatible shape with dataset! {} vs {}"
.
format
(
"Input
signature
'{}' has incompatible shape with dataset! {} vs {}"
.
format
(
input
s_desc
[
idx
]
.
name
,
s2
,
s1
)
input
_signature
[
idx
]
.
name
,
s2
,
s1
)
self
.
_iterator
=
self
.
_dataset
.
make_initializable_iterator
()
self
.
_iterator
=
self
.
_dataset
.
make_initializable_iterator
()
self
.
_init_op
=
self
.
_iterator
.
initializer
self
.
_init_op
=
self
.
_iterator
.
initializer
...
@@ -482,11 +482,11 @@ class TFDatasetInput(FeedfreeInput):
...
@@ -482,11 +482,11 @@ class TFDatasetInput(FeedfreeInput):
self
.
_init_op
.
run
()
self
.
_init_op
.
run
()
def
_get_input_tensors
(
self
):
def
_get_input_tensors
(
self
):
desc_shapes
=
[
k
.
shape
for
k
in
self
.
_des
c
]
spec_shapes
=
[
k
.
shape
for
k
in
self
.
_spe
c
]
ret
=
self
.
_iterator
.
get_next
()
ret
=
self
.
_iterator
.
get_next
()
assert
len
(
ret
)
==
len
(
des
c_shapes
),
\
assert
len
(
ret
)
==
len
(
spe
c_shapes
),
\
"Dataset returns {} tensors but there are {} inputs!"
.
format
(
len
(
ret
),
len
(
des
c_shapes
))
"Dataset returns {} tensors but there are {} inputs!"
.
format
(
len
(
ret
),
len
(
spe
c_shapes
))
for
t
,
shp
in
zip
(
ret
,
des
c_shapes
):
for
t
,
shp
in
zip
(
ret
,
spe
c_shapes
):
t
.
set_shape
(
shp
)
t
.
set_shape
(
shp
)
return
ret
return
ret
...
...
tensorpack/input_source/input_source_base.py
View file @
b8a50d72
...
@@ -12,6 +12,7 @@ from ..callbacks.base import CallbackFactory
...
@@ -12,6 +12,7 @@ from ..callbacks.base import CallbackFactory
from
..tfutils.common
import
get_op_tensor_name
from
..tfutils.common
import
get_op_tensor_name
from
..utils
import
logger
from
..utils
import
logger
from
..utils.argtools
import
call_only_once
,
memoized_method
from
..utils.argtools
import
call_only_once
,
memoized_method
from
..graph_builder.model_desc
import
build_or_reuse_placeholder
__all__
=
[
'InputSource'
,
'remap_input_source'
]
__all__
=
[
'InputSource'
,
'remap_input_source'
]
...
@@ -86,20 +87,20 @@ class InputSource(object):
...
@@ -86,20 +87,20 @@ class InputSource(object):
pass
pass
@
call_only_once
@
call_only_once
def
setup
(
self
,
input
s_desc
):
def
setup
(
self
,
input
_signature
):
"""
"""
Args:
Args:
input
s_desc (list[InputDesc]): list of input desc
input
_signature (list[tf.TensorSpec]): list of specs for each input tensor
Returns:
Returns:
list[Callback]: extra callbacks needed by this InputSource.
list[Callback]: extra callbacks needed by this InputSource.
callbacks of InputSource cannot use any `trigger*()` method.
callbacks of InputSource cannot use any `trigger*()` method.
"""
"""
self
.
_setup
(
input
s_desc
)
self
.
_setup
(
input
_signature
)
self
.
_setup_done
=
True
self
.
_setup_done
=
True
return
self
.
get_callbacks
()
return
self
.
get_callbacks
()
def
_setup
(
self
,
input
s_desc
):
def
_setup
(
self
,
input
_signature
):
pass
pass
def
setup_done
(
self
):
def
setup_done
(
self
):
...
@@ -190,8 +191,8 @@ class ProxyInputSource(InputSource):
...
@@ -190,8 +191,8 @@ class ProxyInputSource(InputSource):
def
_get_input_tensors
(
self
):
def
_get_input_tensors
(
self
):
return
self
.
_input
.
get_input_tensors
()
return
self
.
_input
.
get_input_tensors
()
def
_setup
(
self
,
input
s_desc
):
def
_setup
(
self
,
input
_signature
):
self
.
_input
.
setup
(
input
s_desc
)
self
.
_input
.
setup
(
input
_signature
)
def
_get_callbacks
(
self
):
def
_get_callbacks
(
self
):
return
self
.
_input
.
get_callbacks
()
return
self
.
_input
.
get_callbacks
()
...
@@ -226,11 +227,11 @@ def remap_input_source(input, names):
...
@@ -226,11 +227,11 @@ def remap_input_source(input, names):
input1 = QueueInput(ds)
input1 = QueueInput(ds)
# assume ds produces 'image' and 'label', but the graph takes more
# assume ds produces 'image' and 'label', but the graph takes more
# inputs for some reasons, or takes inputs of a different order:
# inputs for some reasons, or takes inputs of a different order:
input
s_desc = [InputDesc(tf.float32, (None,10)
, 'score'),
input
_signature = [tf.TensorSpec((None,10), tf.float32
, 'score'),
InputDesc(tf.float32, (None,20,20,3)
, 'label'),
tf.TensorSpec((None,20,20,3), tf.float32
, 'label'),
InputDesc(tf.int32, (None,)
, 'image') ]
tf.TensorSpec((None,), tf.int32
, 'image') ]
input2 = remap_input_source(input1, ['image', 'label'])
input2 = remap_input_source(input1, ['image', 'label'])
input2.setup(input
s_desc
)
input2.setup(input
_signature
)
# now, input2.get_input_tensors() will return a placeholder for 'score',
# now, input2.get_input_tensors() will return a placeholder for 'score',
# plus the tensors returned by input1.get_input_tensors()
# plus the tensors returned by input1.get_input_tensors()
"""
"""
...
@@ -240,7 +241,7 @@ def remap_input_source(input, names):
...
@@ -240,7 +241,7 @@ def remap_input_source(input, names):
self
.
_names
=
tuple
(
names
)
self
.
_names
=
tuple
(
names
)
def
_setup
(
self
,
inputs
):
def
_setup
(
self
,
inputs
):
self
.
_all_placehdrs
=
[
v
.
build_placeholder_reuse
(
)
for
v
in
inputs
]
self
.
_all_placehdrs
=
[
build_or_reuse_placeholder
(
v
)
for
v
in
inputs
]
inputs_subset
=
get_sublist_by_names
(
inputs
,
self
.
_names
)
inputs_subset
=
get_sublist_by_names
(
inputs
,
self
.
_names
)
self
.
_input
.
setup
(
inputs_subset
)
self
.
_input
.
setup
(
inputs_subset
)
...
...
tensorpack/predict/base.py
View file @
b8a50d72
...
@@ -155,7 +155,7 @@ class OfflinePredictor(OnlinePredictor):
...
@@ -155,7 +155,7 @@ class OfflinePredictor(OnlinePredictor):
self
.
graph
=
config
.
_maybe_create_graph
()
self
.
graph
=
config
.
_maybe_create_graph
()
with
self
.
graph
.
as_default
():
with
self
.
graph
.
as_default
():
input
=
PlaceholderInput
()
input
=
PlaceholderInput
()
input
.
setup
(
config
.
input
s_desc
)
input
.
setup
(
config
.
input
_signature
)
with
PredictTowerContext
(
''
):
with
PredictTowerContext
(
''
):
config
.
tower_func
(
*
input
.
get_input_tensors
())
config
.
tower_func
(
*
input
.
get_input_tensors
())
...
...
tensorpack/predict/config.py
View file @
b8a50d72
...
@@ -18,7 +18,7 @@ class PredictConfig(object):
...
@@ -18,7 +18,7 @@ class PredictConfig(object):
def
__init__
(
self
,
def
__init__
(
self
,
model
=
None
,
model
=
None
,
tower_func
=
None
,
tower_func
=
None
,
input
s_desc
=
None
,
input
_signature
=
None
,
input_names
=
None
,
input_names
=
None
,
output_names
=
None
,
output_names
=
None
,
...
@@ -27,11 +27,18 @@ class PredictConfig(object):
...
@@ -27,11 +27,18 @@ class PredictConfig(object):
session_init
=
None
,
session_init
=
None
,
return_input
=
False
,
return_input
=
False
,
create_graph
=
True
,
create_graph
=
True
,
inputs_desc
=
None
):
):
"""
"""
You need to set either `model`, or `inputs_desc` plus `tower_func`.
Users need to provide enough arguments to create a tower function,
They are needed to construct the graph.
which will be used to construct the graph.
You'll also have to set `output_names` as it does not have a default.
This can be provided in the following ways:
1. `model`: a :class:`ModelDesc` instance. It will contain a tower function by itself.
2. `tower_func`: a :class:`tfutils.TowerFuncWrapper` instance.
Provide a tower function instance directly.
3. `tower_func`: a symbolic function and `input_signature`: the signature of the function.
Provide both a function and its signature.
Example:
Example:
...
@@ -42,15 +49,14 @@ class PredictConfig(object):
...
@@ -42,15 +49,14 @@ class PredictConfig(object):
output_names=['linear/output', 'prediction'])
output_names=['linear/output', 'prediction'])
Args:
Args:
model (ModelDescBase): to be used to
obtain inputs_desc and tower_func
.
model (ModelDescBase): to be used to
construct a tower function
.
tower_func: a callable which takes input tensors (by positional args) and construct a tower.
tower_func: a callable which takes input tensors (by positional args) and construct a tower.
or a :class:`tfutils.TowerFuncWrapper` instance
, which packs both `inputs_desc` and function together
.
or a :class:`tfutils.TowerFuncWrapper` instance.
input
s_desc ([InputDesc]): if tower_func is a plain function (instead of a TowerFuncWrapper), this describes
input
_signature ([tf.TensorSpec]): if tower_func is a plain function (instead of a TowerFuncWrapper),
the list of inputs it takes.
th
is describes th
e list of inputs it takes.
input_names (list): a list of input tensor names. Defaults to match inputs_desc.
input_names (list): a list of input tensor names. Defaults to match input_signature.
The name can be either the name of a tensor, or the name of one input defined
The name can be either the name of a tensor, or the name of one input of the tower.
by `inputs_desc` or by `model`.
output_names (list): a list of names of the output tensors to predict, the
output_names (list): a list of names of the output tensors to predict, the
tensors can be any tensor in the graph that's computable from the tensors correponding to `input_names`.
tensors can be any tensor in the graph that's computable from the tensors correponding to `input_names`.
...
@@ -62,23 +68,29 @@ class PredictConfig(object):
...
@@ -62,23 +68,29 @@ class PredictConfig(object):
return_input (bool): same as in :attr:`PredictorBase.return_input`.
return_input (bool): same as in :attr:`PredictorBase.return_input`.
create_graph (bool): create a new graph, or use the default graph
create_graph (bool): create a new graph, or use the default graph
when predictor is first initialized.
when predictor is first initialized.
inputs_desc (list[tf.TensorSpec]): old (deprecated) name for `input_signature`.
"""
"""
def
assert_type
(
v
,
tp
,
name
):
def
assert_type
(
v
,
tp
,
name
):
assert
isinstance
(
v
,
tp
),
\
assert
isinstance
(
v
,
tp
),
\
"
{}
has to be type '{}', but an object of type '{}' found."
.
format
(
"
Argument '{}'
has to be type '{}', but an object of type '{}' found."
.
format
(
name
,
tp
.
__name__
,
v
.
__class__
.
__name__
)
name
,
tp
.
__name__
,
v
.
__class__
.
__name__
)
if
inputs_desc
is
not
None
:
# TODO warn deprecated or not?
assert
input_signature
is
None
,
"Cannot set both inputs_desc and input_signature!"
input_signature
=
inputs_desc
if
model
is
not
None
:
if
model
is
not
None
:
assert_type
(
model
,
ModelDescBase
,
'model'
)
assert_type
(
model
,
ModelDescBase
,
'model'
)
assert
input
s_desc
is
None
and
tower_func
is
None
assert
input
_signature
is
None
and
tower_func
is
None
self
.
input
s_desc
=
model
.
get_inputs_desc
()
self
.
input
_signature
=
model
.
get_input_signature
()
self
.
tower_func
=
TowerFuncWrapper
(
model
.
build_graph
,
self
.
input
s_desc
)
self
.
tower_func
=
TowerFuncWrapper
(
model
.
build_graph
,
self
.
input
_signature
)
else
:
else
:
if
isinstance
(
tower_func
,
TowerFuncWrapper
):
if
isinstance
(
tower_func
,
TowerFuncWrapper
):
input
s_desc
=
tower_func
.
inputs_desc
input
_signature
=
tower_func
.
input_signature
assert
input
s_desc
is
not
None
and
tower_func
is
not
None
assert
input
_signature
is
not
None
and
tower_func
is
not
None
self
.
input
s_desc
=
inputs_desc
self
.
input
_signature
=
input_signature
self
.
tower_func
=
TowerFuncWrapper
(
tower_func
,
input
s_desc
)
self
.
tower_func
=
TowerFuncWrapper
(
tower_func
,
input
_signature
)
if
session_init
is
None
:
if
session_init
is
None
:
session_init
=
JustCurrentSession
()
session_init
=
JustCurrentSession
()
...
@@ -93,20 +105,22 @@ class PredictConfig(object):
...
@@ -93,20 +105,22 @@ class PredictConfig(object):
# inputs & outputs
# inputs & outputs
self
.
input_names
=
input_names
self
.
input_names
=
input_names
if
self
.
input_names
is
None
:
if
self
.
input_names
is
None
:
self
.
input_names
=
[
k
.
name
for
k
in
self
.
inputs_desc
]
self
.
input_names
=
[
k
.
name
for
k
in
self
.
input_signature
]
assert
output_names
is
not
None
,
"Argument 'output_names' is not provided!"
self
.
output_names
=
output_names
self
.
output_names
=
output_names
assert_type
(
self
.
output_names
,
list
,
'output_names'
)
assert_type
(
self
.
output_names
,
list
,
'output_names'
)
assert_type
(
self
.
input_names
,
list
,
'input_names'
)
assert_type
(
self
.
input_names
,
list
,
'input_names'
)
if
len
(
self
.
input_names
)
==
0
:
if
len
(
self
.
input_names
)
==
0
:
logger
.
warn
(
'PredictConfig receives empty "input_names".'
)
logger
.
warn
(
'PredictConfig receives empty "input_names".'
)
# assert len(self.input_names), self.input_names
for
v
in
self
.
input_names
:
for
v
in
self
.
input_names
:
assert_type
(
v
,
six
.
string_types
,
'Each item in input_names'
)
assert_type
(
v
,
six
.
string_types
,
'Each item in input_names'
)
assert
len
(
self
.
output_names
),
self
.
output_names
assert
len
(
self
.
output_names
),
"Argument 'output_names' cannot be empty!"
self
.
return_input
=
bool
(
return_input
)
self
.
return_input
=
bool
(
return_input
)
self
.
create_graph
=
bool
(
create_graph
)
self
.
create_graph
=
bool
(
create_graph
)
self
.
inputs_desc
=
input_signature
# TODO a little bit of compatibility
def
_maybe_create_graph
(
self
):
def
_maybe_create_graph
(
self
):
if
self
.
create_graph
:
if
self
.
create_graph
:
return
tf
.
Graph
()
return
tf
.
Graph
()
...
...
tensorpack/predict/feedfree.py
View file @
b8a50d72
...
@@ -21,7 +21,7 @@ class FeedfreePredictor(PredictorBase):
...
@@ -21,7 +21,7 @@ class FeedfreePredictor(PredictorBase):
Args:
Args:
config (PredictConfig): the config to use.
config (PredictConfig): the config to use.
input_source (InputSource): the feedfree InputSource to use.
input_source (InputSource): the feedfree InputSource to use.
Must match the
inputs_desc
in config.
Must match the
signature of the tower function
in config.
"""
"""
self
.
_config
=
config
self
.
_config
=
config
self
.
_input_source
=
input_source
self
.
_input_source
=
input_source
...
@@ -33,7 +33,7 @@ class FeedfreePredictor(PredictorBase):
...
@@ -33,7 +33,7 @@ class FeedfreePredictor(PredictorBase):
self
.
graph
=
config
.
_maybe_create_graph
()
self
.
graph
=
config
.
_maybe_create_graph
()
with
self
.
graph
.
as_default
():
with
self
.
graph
.
as_default
():
self
.
_input_callbacks
=
Callbacks
(
self
.
_input_callbacks
=
Callbacks
(
self
.
_input_source
.
setup
(
config
.
input
s_desc
))
self
.
_input_source
.
setup
(
config
.
input
_signature
))
with
PredictTowerContext
(
''
):
with
PredictTowerContext
(
''
):
self
.
_input_tensors
=
self
.
_input_source
.
get_input_tensors
()
self
.
_input_tensors
=
self
.
_input_source
.
get_input_tensors
()
config
.
tower_func
(
*
self
.
_input_tensors
)
config
.
tower_func
(
*
self
.
_input_tensors
)
...
...
tensorpack/predict/multigpu.py
View file @
b8a50d72
...
@@ -4,7 +4,6 @@
...
@@ -4,7 +4,6 @@
import
tensorflow
as
tf
import
tensorflow
as
tf
from
..graph_builder.model_desc
import
InputDesc
from
..input_source
import
PlaceholderInput
from
..input_source
import
PlaceholderInput
from
..tfutils.tower
import
PredictTowerContext
from
..tfutils.tower
import
PredictTowerContext
from
..utils
import
logger
from
..utils
import
logger
...
@@ -33,7 +32,7 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
...
@@ -33,7 +32,7 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
handles
=
[]
handles
=
[]
input
=
PlaceholderInput
()
input
=
PlaceholderInput
()
input
.
setup
(
config
.
input
s_desc
)
input
.
setup
(
config
.
input
_signature
)
for
idx
,
t
in
enumerate
(
towers
):
for
idx
,
t
in
enumerate
(
towers
):
tower_name
=
'tower'
+
str
(
t
)
tower_name
=
'tower'
+
str
(
t
)
...
@@ -102,10 +101,10 @@ class DataParallelOfflinePredictor(OnlinePredictor):
...
@@ -102,10 +101,10 @@ class DataParallelOfflinePredictor(OnlinePredictor):
for
idx
,
t
in
enumerate
(
towers
):
for
idx
,
t
in
enumerate
(
towers
):
tower_name
=
'tower'
+
str
(
t
)
tower_name
=
'tower'
+
str
(
t
)
inputs_desc
=
[
InputDesc
(
desc
.
type
,
desc
.
shape
,
tower_name
+
'_'
+
desc
.
name
)
new_sig
=
[
tf
.
TensorSpec
(
dtype
=
p
.
dtype
,
shape
=
p
.
shape
,
name
=
tower_name
+
'_'
+
p
.
name
)
for
desc
in
config
.
inputs_desc
]
for
p
in
config
.
input_signature
]
input
=
PlaceholderInput
()
input
=
PlaceholderInput
()
input
.
setup
(
inputs_desc
)
input
.
setup
(
new_sig
)
with
tf
.
variable_scope
(
tf
.
get_variable_scope
(),
reuse
=
idx
>
0
),
\
with
tf
.
variable_scope
(
tf
.
get_variable_scope
(),
reuse
=
idx
>
0
),
\
tf
.
device
(
'/gpu:{}'
.
format
(
t
)),
\
tf
.
device
(
'/gpu:{}'
.
format
(
t
)),
\
...
...
tensorpack/tfutils/export.py
View file @
b8a50d72
...
@@ -28,7 +28,7 @@ class ModelExporter(object):
...
@@ -28,7 +28,7 @@ class ModelExporter(object):
Args:
Args:
config (PredictConfig): the config to use.
config (PredictConfig): the config to use.
The graph will be built with
`config.tower_func` and `config.inputs_desc
`.
The graph will be built with
the tower function defined by this `PredictConfig
`.
Then the input / output names will be used to export models for inference.
Then the input / output names will be used to export models for inference.
"""
"""
super
(
ModelExporter
,
self
)
.
__init__
()
super
(
ModelExporter
,
self
)
.
__init__
()
...
@@ -51,7 +51,7 @@ class ModelExporter(object):
...
@@ -51,7 +51,7 @@ class ModelExporter(object):
self
.
graph
=
self
.
config
.
_maybe_create_graph
()
self
.
graph
=
self
.
config
.
_maybe_create_graph
()
with
self
.
graph
.
as_default
():
with
self
.
graph
.
as_default
():
input
=
PlaceholderInput
()
input
=
PlaceholderInput
()
input
.
setup
(
self
.
config
.
input
s_desc
)
input
.
setup
(
self
.
config
.
input
_signature
)
with
PredictTowerContext
(
''
):
with
PredictTowerContext
(
''
):
self
.
config
.
tower_func
(
*
input
.
get_input_tensors
())
self
.
config
.
tower_func
(
*
input
.
get_input_tensors
())
...
@@ -116,7 +116,7 @@ class ModelExporter(object):
...
@@ -116,7 +116,7 @@ class ModelExporter(object):
self
.
graph
=
self
.
config
.
_maybe_create_graph
()
self
.
graph
=
self
.
config
.
_maybe_create_graph
()
with
self
.
graph
.
as_default
():
with
self
.
graph
.
as_default
():
input
=
PlaceholderInput
()
input
=
PlaceholderInput
()
input
.
setup
(
self
.
config
.
input
s_desc
)
input
.
setup
(
self
.
config
.
input
_signature
)
with
PredictTowerContext
(
''
):
with
PredictTowerContext
(
''
):
self
.
config
.
tower_func
(
*
input
.
get_input_tensors
())
self
.
config
.
tower_func
(
*
input
.
get_input_tensors
())
...
...
tensorpack/tfutils/tower.py
View file @
b8a50d72
...
@@ -257,24 +257,27 @@ class TowerFuncWrapper(object):
...
@@ -257,24 +257,27 @@ class TowerFuncWrapper(object):
Conceptually, this class is roughly equivalent to `tf.function` with input signature, introduced in TF 2.0.
Conceptually, this class is roughly equivalent to `tf.function` with input signature, introduced in TF 2.0.
"""
"""
def
__init__
(
self
,
tower_fn
,
input
s_desc
):
def
__init__
(
self
,
tower_fn
,
input
_signature
):
"""
"""
Args:
Args:
tower_func: a function which builds one tower in the graph.
tower_func: a function which builds one tower in the graph.
It takes several input tensors and could return anything.
It takes several input tensors and could return anything.
input
s_desc ([InputDesc]): list of :class:`InputDes
c`.
input
_signature ([TensorSpec]): list of :class:`tf.TensorSpe
c`.
They are used to figure out the names for the input tensors.
They are used to figure out the names for the input tensors.
"""
"""
assert
callable
(
tower_fn
),
tower_fn
assert
callable
(
tower_fn
),
tower_fn
self
.
_inputs_desc_names
=
[
k
.
name
for
k
in
inputs_desc
]
self
.
_inputs_names
=
[
k
.
name
for
k
in
input_signature
]
assert
len
(
set
(
self
.
_inputs_desc_names
))
==
len
(
self
.
_inputs_desc_names
),
\
assert
len
(
set
(
self
.
_inputs_names
))
==
len
(
self
.
_inputs_names
),
\
"Duplicated names in inputs_desc! "
+
str
(
self
.
_inputs_desc_names
)
"Duplicated names in input_signature! "
+
str
(
self
.
_inputs_names
)
for
name
in
self
.
_inputs_names
:
if
any
(
k
in
name
for
k
in
[
':'
,
'/'
,
' '
]):
raise
ValueError
(
"Invalid input name: '{}'"
.
format
(
name
))
self
.
_tower_fn
=
tower_fn
self
.
_tower_fn
=
tower_fn
self
.
_input
s_desc
=
inputs_desc
self
.
_input
_signature
=
input_signature
self
.
_handles
=
[]
self
.
_handles
=
[]
def
__new__
(
cls
,
tower_fn
,
inputs_desc
):
def
__new__
(
cls
,
tower_fn
,
_
):
# to avoid double-wrapping a function
# to avoid double-wrapping a function
if
isinstance
(
tower_fn
,
TowerFuncWrapper
):
if
isinstance
(
tower_fn
,
TowerFuncWrapper
):
return
tower_fn
return
tower_fn
...
@@ -285,7 +288,7 @@ class TowerFuncWrapper(object):
...
@@ -285,7 +288,7 @@ class TowerFuncWrapper(object):
ctx
=
get_current_tower_context
()
ctx
=
get_current_tower_context
()
assert
ctx
is
not
None
,
"Function must be called under TowerContext!"
assert
ctx
is
not
None
,
"Function must be called under TowerContext!"
output
=
self
.
_tower_fn
(
*
args
)
output
=
self
.
_tower_fn
(
*
args
)
handle
=
TowerTensorHandle
(
ctx
,
args
,
output
,
self
.
_input
s_desc
)
handle
=
TowerTensorHandle
(
ctx
,
args
,
output
,
self
.
_input
_signature
)
self
.
_handles
.
append
(
handle
)
self
.
_handles
.
append
(
handle
)
return
output
return
output
...
@@ -298,9 +301,14 @@ class TowerFuncWrapper(object):
...
@@ -298,9 +301,14 @@ class TowerFuncWrapper(object):
"""
"""
return
TowerTensorHandles
(
self
.
_handles
)
return
TowerTensorHandles
(
self
.
_handles
)
@
property
def
input_signature
(
self
):
return
self
.
_input_signature
@
property
@
property
def
inputs_desc
(
self
):
def
inputs_desc
(
self
):
return
self
.
_inputs_desc
# TODO mark deprecated
return
self
.
_input_signature
class
TowerTensorHandles
(
object
):
class
TowerTensorHandles
(
object
):
...
@@ -354,14 +362,14 @@ class TowerTensorHandle(object):
...
@@ -354,14 +362,14 @@ class TowerTensorHandle(object):
"""
"""
@
HIDE_DOC
@
HIDE_DOC
def
__init__
(
self
,
ctx
,
input
,
output
,
input
s_desc
=
None
):
def
__init__
(
self
,
ctx
,
input
,
output
,
input
_signature
=
None
):
self
.
_ctx
=
ctx
self
.
_ctx
=
ctx
self
.
_extra_tensor_names
=
{}
self
.
_extra_tensor_names
=
{}
if
input
s_desc
is
not
None
:
if
input
_signature
is
not
None
:
assert
len
(
input
s_desc
)
==
len
(
input
)
assert
len
(
input
_signature
)
==
len
(
input
)
self
.
_extra_tensor_names
=
{
self
.
_extra_tensor_names
=
{
get_op_tensor_name
(
x
.
name
)[
1
]:
y
for
x
,
y
in
zip
(
input
s_desc
,
input
)}
get_op_tensor_name
(
x
.
name
)[
1
]:
y
for
x
,
y
in
zip
(
input
_signature
,
input
)}
self
.
_input
=
input
self
.
_input
=
input
self
.
_output
=
output
self
.
_output
=
output
...
@@ -379,7 +387,7 @@ class TowerTensorHandle(object):
...
@@ -379,7 +387,7 @@ class TowerTensorHandle(object):
1. The name of the tensor without any tower prefix.
1. The name of the tensor without any tower prefix.
2.
The name of an :class:`InputDesc`
, if it is used when building the tower.
2.
A name in the input signature
, if it is used when building the tower.
In the second case, this method will return the tensor that's used as the corresponding
In the second case, this method will return the tensor that's used as the corresponding
input to the tower. Note that this tensor may have a different name (e.g. may be an output of a queue).
input to the tower. Note that this tensor may have a different name (e.g. may be an output of a queue).
...
...
tensorpack/train/interface.py
View file @
b8a50d72
...
@@ -87,7 +87,7 @@ def launch_train_with_config(config, trainer):
...
@@ -87,7 +87,7 @@ def launch_train_with_config(config, trainer):
# We should gradually stay away from this unuseful abstraction.
# We should gradually stay away from this unuseful abstraction.
# TowerFuncWrapper is a better abstraction (similar to tf.defun in the future)
# TowerFuncWrapper is a better abstraction (similar to tf.defun in the future)
trainer
.
setup_graph
(
trainer
.
setup_graph
(
model
.
get_input
s_desc
(),
input
,
model
.
get_input
_signature
(),
input
,
model
.
_build_graph_get_cost
,
model
.
get_optimizer
)
model
.
_build_graph_get_cost
,
model
.
get_optimizer
)
_check_unused_regularization
()
_check_unused_regularization
()
trainer
.
train_with_defaults
(
trainer
.
train_with_defaults
(
...
...
tensorpack/train/tower.py
View file @
b8a50d72
...
@@ -56,11 +56,16 @@ class TowerTrainer(Trainer):
...
@@ -56,11 +56,16 @@ class TowerTrainer(Trainer):
@
property
@
property
def
inputs_desc
(
self
):
def
inputs_desc
(
self
):
# TODO mark deprecated
return
self
.
input_signature
@
property
def
input_signature
(
self
):
"""
"""
Returns:
Returns:
list[
InputDes
c]: metainfo about the inputs to the tower.
list[
tf.TensorSpe
c]: metainfo about the inputs to the tower.
"""
"""
return
self
.
tower_func
.
input
s_desc
return
self
.
tower_func
.
input
_signature
@
property
@
property
def
towers
(
self
):
def
towers
(
self
):
...
@@ -124,7 +129,7 @@ class TowerTrainer(Trainer):
...
@@ -124,7 +129,7 @@ class TowerTrainer(Trainer):
if
tower
is
None
:
if
tower
is
None
:
input
=
PlaceholderInput
()
input
=
PlaceholderInput
()
input
.
setup
(
self
.
input
s_desc
)
input
.
setup
(
self
.
input
_signature
)
vs_name
=
self
.
_vs_name_for_predictor
(
device_id
)
vs_name
=
self
.
_vs_name_for_predictor
(
device_id
)
with
tfv1
.
variable_scope
(
tfv1
.
get_variable_scope
(),
reuse
=
True
),
\
with
tfv1
.
variable_scope
(
tfv1
.
get_variable_scope
(),
reuse
=
True
),
\
...
@@ -164,7 +169,7 @@ class SingleCostTrainer(TowerTrainer):
...
@@ -164,7 +169,7 @@ class SingleCostTrainer(TowerTrainer):
Base class for single-cost trainer.
Base class for single-cost trainer.
Single-cost trainer has a :meth:`setup_graph` method which takes
Single-cost trainer has a :meth:`setup_graph` method which takes
(input
s_desc
, input, get_cost_fn, get_opt_fn), and build the training graph from them.
(input
_signature
, input, get_cost_fn, get_opt_fn), and build the training graph from them.
To use a :class:`SingleCostTrainer` object, call `trainer.setup_graph(...); trainer.train(...)`.
To use a :class:`SingleCostTrainer` object, call `trainer.setup_graph(...); trainer.train(...)`.
"""
"""
...
@@ -194,12 +199,12 @@ class SingleCostTrainer(TowerTrainer):
...
@@ -194,12 +199,12 @@ class SingleCostTrainer(TowerTrainer):
"""
"""
@
call_only_once
@
call_only_once
def
setup_graph
(
self
,
input
s_desc
,
input
,
get_cost_fn
,
get_opt_fn
):
def
setup_graph
(
self
,
input
_signature
,
input
,
get_cost_fn
,
get_opt_fn
):
"""
"""
Responsible for building the main training graph for single-cost training.
Responsible for building the main training graph for single-cost training.
Args:
Args:
input
s_desc ([InputDesc]):
input
_signature ([TensorSpec]): list of TensorSpec that describe the inputs
input (InputSource):
input (InputSource):
get_cost_fn ([tf.Tensor] -> tf.Tensor): callable, takes some input tensors and return a cost tensor.
get_cost_fn ([tf.Tensor] -> tf.Tensor): callable, takes some input tensors and return a cost tensor.
get_opt_fn (-> tf.train.Optimizer): callable which returns an
get_opt_fn (-> tf.train.Optimizer): callable which returns an
...
@@ -210,12 +215,12 @@ class SingleCostTrainer(TowerTrainer):
...
@@ -210,12 +215,12 @@ class SingleCostTrainer(TowerTrainer):
It must follows the `rules of tower function.
It must follows the `rules of tower function.
<http://tensorpack.readthedocs.io/tutorial/trainer.html#tower-trainer>`_.
<http://tensorpack.readthedocs.io/tutorial/trainer.html#tower-trainer>`_.
"""
"""
get_cost_fn
=
TowerFuncWrapper
(
get_cost_fn
,
input
s_desc
)
get_cost_fn
=
TowerFuncWrapper
(
get_cost_fn
,
input
_signature
)
get_opt_fn
=
memoized
(
get_opt_fn
)
get_opt_fn
=
memoized
(
get_opt_fn
)
self
.
tower_func
=
get_cost_fn
self
.
tower_func
=
get_cost_fn
# TODO setup may want to register monitor as well??
# TODO setup may want to register monitor as well??
input_callbacks
=
self
.
_setup_input
(
input
s_desc
,
input
)
input_callbacks
=
self
.
_setup_input
(
input
_signature
,
input
)
train_callbacks
=
self
.
_setup_graph
(
input
,
get_cost_fn
,
get_opt_fn
)
train_callbacks
=
self
.
_setup_graph
(
input
,
get_cost_fn
,
get_opt_fn
)
self
.
register_callback
(
input_callbacks
+
train_callbacks
)
self
.
register_callback
(
input_callbacks
+
train_callbacks
)
...
@@ -229,9 +234,9 @@ class SingleCostTrainer(TowerTrainer):
...
@@ -229,9 +234,9 @@ class SingleCostTrainer(TowerTrainer):
[Callback]: list of callbacks needed
[Callback]: list of callbacks needed
"""
"""
def
_setup_input
(
self
,
input
s_desc
,
input
):
def
_setup_input
(
self
,
input
_signature
,
input
):
assert
not
input
.
setup_done
()
assert
not
input
.
setup_done
()
return
input
.
setup
(
input
s_desc
)
return
input
.
setup
(
input
_signature
)
def
_make_get_grad_fn
(
self
,
input
,
get_cost_fn
,
get_opt_fn
):
def
_make_get_grad_fn
(
self
,
input
,
get_cost_fn
,
get_opt_fn
):
"""
"""
...
...
tensorpack/train/trainers.py
View file @
b8a50d72
...
@@ -272,7 +272,7 @@ class DistributedTrainerReplicated(DistributedTrainerBase):
...
@@ -272,7 +272,7 @@ class DistributedTrainerReplicated(DistributedTrainerBase):
self
.
_builder
=
DistributedReplicatedBuilder
(
gpus
,
server
)
self
.
_builder
=
DistributedReplicatedBuilder
(
gpus
,
server
)
self
.
is_chief
=
self
.
_builder
.
is_chief
self
.
is_chief
=
self
.
_builder
.
is_chief
def
_setup_input
(
self
,
input
s_desc
,
input
):
def
_setup_input
(
self
,
input
_signature
,
input
):
with
override_to_local_variable
():
with
override_to_local_variable
():
get_global_step_var
()
# gs should be local
get_global_step_var
()
# gs should be local
# input source may create variable (queue size summary)
# input source may create variable (queue size summary)
...
@@ -280,7 +280,7 @@ class DistributedTrainerReplicated(DistributedTrainerBase):
...
@@ -280,7 +280,7 @@ class DistributedTrainerReplicated(DistributedTrainerBase):
# whether something should be global or local. We now assume
# whether something should be global or local. We now assume
# they should be local.
# they should be local.
assert
not
input
.
setup_done
()
assert
not
input
.
setup_done
()
return
input
.
setup
(
input
s_desc
)
return
input
.
setup
(
input
_signature
)
def
_setup_graph
(
self
,
input
,
get_cost_fn
,
get_opt_fn
):
def
_setup_graph
(
self
,
input
,
get_cost_fn
,
get_opt_fn
):
assert
isinstance
(
input
,
FeedfreeInput
),
input
assert
isinstance
(
input
,
FeedfreeInput
),
input
...
...
tensorpack/utils/concurrency.py
View file @
b8a50d72
...
@@ -132,8 +132,8 @@ class ShareSessionThread(threading.Thread):
...
@@ -132,8 +132,8 @@ class ShareSessionThread(threading.Thread):
yield
None
yield
None
def
start
(
self
):
def
start
(
self
):
import
tensorflow
as
tf
from
..compat
import
tfv1
self
.
_sess
=
tf
.
get_default_session
()
self
.
_sess
=
tf
v1
.
get_default_session
()
super
(
ShareSessionThread
,
self
)
.
start
()
super
(
ShareSessionThread
,
self
)
.
start
()
def
run
(
self
):
def
run
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment