seminar-breakout (Shashank Suhas) · Commits · b8a50d72

Commit b8a50d72, authored Mar 19, 2019 by Yuxin Wu
Parent: ba679ab1

    InputDesc -> tf.TensorSpec everywhere

Showing 29 changed files with 237 additions and 235 deletions.
CHANGES.md                                       +3   -0
docs/conf.py                                     +2   -0
docs/tutorial/extend/trainer.md                  +1   -1
docs/tutorial/training-interface.md              +2   -2
examples/CaffeModels/load-alexnet.py             +1   -1
examples/CaffeModels/load-cpm.py                 +1   -1
examples/CaffeModels/load-vgg16.py               +1   -1
examples/CaffeModels/load-vgg19.py               +1   -1
examples/GAN/GAN.py                              +5   -5
examples/ImageNetModels/shufflenet.py            +4   -7
examples/basics/cifar-convnet.py                 +2   -2
examples/keras/imagenet-resnet-keras.py          +3   -3
examples/keras/mnist-keras-v2.py                 +3   -3
tensorpack/callbacks/inference_runner.py         +2   -2
tensorpack/callbacks/param.py                    +2   -2
tensorpack/contrib/keras.py                      +19  -13
tensorpack/graph_builder/model_desc.py           +50  -83
tensorpack/input_source/input_source.py          +35  -35
tensorpack/input_source/input_source_base.py     +12  -11
tensorpack/predict/base.py                       +1   -1
tensorpack/predict/config.py                     +36  -22
tensorpack/predict/feedfree.py                   +2   -2
tensorpack/predict/multigpu.py                   +4   -5
tensorpack/tfutils/export.py                     +3   -3
tensorpack/tfutils/tower.py                      +22  -14
tensorpack/train/interface.py                    +1   -1
tensorpack/train/tower.py                        +15  -10
tensorpack/train/trainers.py                     +2   -2
tensorpack/utils/concurrency.py                  +2   -2
CHANGES.md

@@ -8,6 +8,9 @@ so you don't need to look at here very often.
 Here are a list of things that were changed, starting from an early version.
 TensorFlow itself also changes API and those are not listed here.

+[2019/03/20] The concept of `InputDesc` was replaced by its equivalent in TF:
+`tf.TensorSpec`. This may be a breaking change if you have customized
+code that relies on internals of `InputDesc`.
 [2018/08/27] msgpack is used again for "serialization to disk", because pyarrow
 has no compatibility between versions. To use pyarrow instead,
 `export TENSORPACK_COMPATIBLE_SERIALIZE=pyarrow`.
 [2018/04/05] msgpack is replaced by pyarrow in favor of its speed. If you want old behavior,
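For most downstream code the migration is a one-line change per input; note that `InputDesc` took `(dtype, shape, name)` while `tf.TensorSpec` takes `(shape, dtype, name)`. A minimal before/after sketch, mirroring the example diffs below:

```python
import tensorflow as tf
# from tensorpack import InputDesc   # old API, removed by this commit

# Old: dtype first, then shape.
#   sig = [InputDesc(tf.float32, (None, 224, 224, 3), 'input')]
# New: shape first, then dtype -- the tf.TensorSpec argument order.
sig = [tf.TensorSpec((None, 224, 224, 3), tf.float32, 'input')]
```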
docs/conf.py

@@ -375,6 +375,8 @@ _DEPRECATED_NAMES = set([
     'PrefetchOnGPUs',
     'DistributedTrainerReplicated',
     'DistributedTrainerParameterServer',
+    'InputDesc',
+    'inputs_desc',

     # renamed items that should not appear in docs
     'DumpTensor',
docs/tutorial/extend/trainer.md

@@ -48,7 +48,7 @@ Most neural network training tasks are single-cost optimization.
 Tensorpack provides some trainer implementations for such tasks.
 These trainers will take care of step 1 (define the graph), with the following arguments:

-1. Some `InputDesc`, the metadata about the input.
+1. Some `tf.TensorSpec`, the signature of the input.
 2. An `InputSource`, where the input come from. See [Input Pipeline](input-source.html).
 3. A function which takes input tensors and returns the cost.
 4. A function which returns an optimizer.
docs/tutorial/training-interface.md

@@ -11,7 +11,7 @@ This interface is enough for most types of single-cost tasks.
 A lot of examples are written in this interface.

 [SingleCost trainers](../modules/train.html#tensorpack.train.SingleCostTrainer)
-expects 4 arguments to setup the graph: `InputDesc`, `InputSource`, get_cost function, and an optimizer.
+expects 4 arguments to setup the graph: input signatures, `InputSource`, get_cost function, and an optimizer.
 `ModelDesc` describes a model by packing 3 of them together into one object:
 ```python

@@ -62,7 +62,7 @@ The function `launch_train_with_config(config, trainer)`
 uses the raw trainer interface under the hood, and is almost equivalent to the following two lines of code:
 ```python
 trainer.setup_graph(
-    my_model.get_inputs_desc(),
+    my_model.get_input_signature(),
     my_input_source,  # or QueueInput(my_dataflow)
     my_model.build_graph,
     my_model.get_optimizer)
examples/CaffeModels/load-alexnet.py

@@ -42,7 +42,7 @@ def tower_func(image):
 def run_test(path, input):
     param_dict = dict(np.load(path))
     predictor = OfflinePredictor(PredictConfig(
-        inputs_desc=[InputDesc(tf.float32, (None, 227, 227, 3), 'input')],
+        input_signature=[tf.TensorSpec((None, 227, 227, 3), tf.float32, 'input')],
         tower_func=tower_func,
         session_init=DictRestore(param_dict),
         input_names=['input'],
examples/CaffeModels/load-cpm.py

@@ -97,7 +97,7 @@ def CPM(image):
 def run_test(model_path, img_file):
     param_dict = dict(np.load(model_path))
     predict_func = OfflinePredictor(PredictConfig(
-        inputs_desc=[InputDesc(tf.float32, (None, 368, 368, 3), 'input')],
+        input_signature=[tf.TensorSpec((None, 368, 368, 3), tf.float32, 'input')],
         tower_func=CPM,
         session_init=DictRestore(param_dict),
         input_names=['input'],
examples/CaffeModels/load-vgg16.py

@@ -59,7 +59,7 @@ def run_test(path, input):
     param_dict = {k.replace('/W', '/kernel').replace('/b', '/bias'): v for k, v in six.iteritems(param_dict)}
     predict_func = OfflinePredictor(PredictConfig(
-        inputs_desc=[InputDesc(tf.float32, (None, 224, 224, 3), 'input')],
+        input_signature=[tf.TensorSpec((None, 224, 224, 3), tf.float32, 'input')],
         tower_func=tower_func,
         session_init=DictRestore(param_dict),
         input_names=['input'],
examples/CaffeModels/load-vgg19.py

@@ -62,7 +62,7 @@ def run_test(path, input):
     param_dict = {k.replace('/W', '/kernel').replace('/b', '/bias'): v for k, v in six.iteritems(param_dict)}
     predict_func = OfflinePredictor(PredictConfig(
-        inputs_desc=[InputDesc(tf.float32, (None, 224, 224, 3), 'input')],
+        input_signature=[tf.TensorSpec((None, 224, 224, 3), tf.float32, 'input')],
         tower_func=tower_func,
         session_init=DictRestore(param_dict),
         input_names=['input'],
examples/GAN/GAN.py

@@ -88,7 +88,7 @@ class GANTrainer(TowerTrainer):
             input = StagingInput(input)

         # Setup input
-        cbs = input.setup(model.get_inputs_desc())
+        cbs = input.setup(model.get_input_signature())
         self.register_callback(cbs)

         if num_gpu <= 1:

@@ -105,7 +105,7 @@ class GANTrainer(TowerTrainer):
         not needed. Just calling model.build_graph directly is OK.
         """
         # Build the graph
-        self.tower_func = TowerFuncWrapper(model.build_graph, model.get_inputs_desc())
+        self.tower_func = TowerFuncWrapper(model.build_graph, model.get_input_signature())
         with TowerContext('', is_training=True):
             self.tower_func(*input.get_input_tensors())
         opt = model.get_optimizer()

@@ -127,7 +127,7 @@ class GANTrainer(TowerTrainer):
                 model.build_graph(*inputs)
                 return [model.d_loss, model.g_loss]
-        self.tower_func = TowerFuncWrapper(get_cost, model.get_inputs_desc())
+        self.tower_func = TowerFuncWrapper(get_cost, model.get_input_signature())
         devices = [LeastLoadedDeviceSetter(d, raw_devices) for d in raw_devices]
         cost_list = DataParallelBuilder.build_on_towers(
             list(range(num_gpu)),

@@ -163,11 +163,11 @@ class SeparateGANTrainer(TowerTrainer):
         assert min(d_period, g_period) == 1

         # Setup input
-        cbs = input.setup(model.get_inputs_desc())
+        cbs = input.setup(model.get_input_signature())
         self.register_callback(cbs)

         # Build the graph
-        self.tower_func = TowerFuncWrapper(model.build_graph, model.get_inputs_desc())
+        self.tower_func = TowerFuncWrapper(model.build_graph, model.get_input_signature())
         with TowerContext('', is_training=True), \
                 argscope(BatchNorm, internal_update=True):
             # should not hook the updates to both train_op, it will hurt training speed.
examples/ImageNetModels/shufflenet.py

@@ -254,14 +254,11 @@ if __name__ == '__main__':
         eval_on_ILSVRC12(model, get_model_loader(args.load), ds)
     elif args.flops:
         # manually build the graph with batch=1
-        input_desc = [
-            InputDesc(tf.float32, [1, 224, 224, 3], 'input'),
-            InputDesc(tf.int32, [1], 'label')
-        ]
-        input = PlaceholderInput()
-        input.setup(input_desc)
         with TowerContext('', is_training=False):
-            model.build_graph(*input.get_input_tensors())
+            model.build_graph(
+                tf.placeholder(tf.float32, [1, 224, 224, 3], 'input'),
+                tf.placeholder(tf.int32, [1], 'label'))
         model_utils.describe_trainable_vars()

         tf.profiler.profile(
examples/basics/cifar-convnet.py

@@ -64,7 +64,7 @@ class Model(ModelDesc):
         cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label)
         cost = tf.reduce_mean(cost, name='cross_entropy_loss')

-        correct = tf.cast(tf.nn.in_top_k(logits, label, 1), tf.float32, name='correct')
+        correct = tf.cast(tf.nn.in_top_k(predictions=logits, targets=label, k=1), tf.float32, name='correct')
         # monitor training error
         add_moving_summary(tf.reduce_mean(correct, name='accuracy'))

@@ -76,7 +76,7 @@ class Model(ModelDesc):
         return tf.add_n([cost, wd_cost], name='cost')

     def optimizer(self):
-        lr = tf.get_variable('learning_rate', initializer=1e-2, trainable=False)
+        lr = tf.Variable(1e-2, name='learning_rate', trainable=False)
         tf.summary.scalar('lr', lr)
         return tf.train.AdamOptimizer(lr, epsilon=1e-3)
examples/keras/imagenet-resnet-keras.py

@@ -9,7 +9,7 @@ import os
 import tensorflow as tf
 from tensorflow.python.keras.layers import *

-from tensorpack import InputDesc, SyncMultiGPUTrainerReplicated
+from tensorpack import SyncMultiGPUTrainerReplicated
 from tensorpack.callbacks import *
 from tensorpack.contrib.keras import KerasModel
 from tensorpack.dataflow import FakeData, MapDataComponent

@@ -166,8 +166,8 @@ if __name__ == '__main__':
     M = KerasModel(
         resnet50,
-        inputs_desc=[InputDesc(tf.uint8, [None, 224, 224, 3], 'images')],
-        targets_desc=[InputDesc(tf.float32, [None, 1000], 'labels')],
+        input_signature=[tf.TensorSpec([None, 224, 224, 3], tf.uint8, 'images')],
+        target_signature=[tf.TensorSpec([None, 1000], tf.float32, 'labels')],
         input=df_train,
         trainer=SyncMultiGPUTrainerReplicated(num_gpu))
examples/keras/mnist-keras-v2.py

@@ -7,7 +7,7 @@ import numpy as np
 import tensorflow as tf
 from tensorflow import keras

-from tensorpack import InputDesc, QueueInput
+from tensorpack import QueueInput
 from tensorpack.callbacks import ModelSaver
 from tensorpack.contrib.keras import KerasModel
 from tensorpack.dataflow import BatchData, MapData, dataset

@@ -57,8 +57,8 @@ if __name__ == '__main__':
     M = KerasModel(
         model_func,
-        inputs_desc=[InputDesc(tf.float32, [None, IMAGE_SIZE, IMAGE_SIZE, 1], 'images')],
-        targets_desc=[InputDesc(tf.float32, [None, 10], 'labels')],
+        input_signature=[tf.TensorSpec([None, IMAGE_SIZE, IMAGE_SIZE, 1], tf.float32, 'images')],
+        target_signature=[tf.TensorSpec([None, 10], tf.float32, 'labels')],
         input=QueueInput(dataset_train))
     M.compile(
         optimizer=tf.train.AdamOptimizer(1e-3),
tensorpack/callbacks/inference_runner.py

@@ -141,7 +141,7 @@ class InferenceRunner(InferenceRunnerBase):
         if self._tower_func is None:
             assert self.trainer.tower_func is not None, "You must set tower_func of the trainer to use InferenceRunner!"
             self._tower_func = self.trainer.tower_func
-        input_callbacks = self._input_source.setup(self._tower_func.inputs_desc)
+        input_callbacks = self._input_source.setup(self._tower_func.input_signature)

         vs_name = self.trainer._vs_name_for_predictor(self._device_id)
         logger.info("[InferenceRunner] Building tower '{}' on device {} {}...".format(

@@ -223,7 +223,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
             assert self.trainer.tower_func is not None, "You must set tower_func of the trainer to use InferenceRunner!"
             self._tower_func = self.trainer.tower_func
-        input_callbacks = self._input_source.setup(self._tower_func.inputs_desc)
+        input_callbacks = self._input_source.setup(self._tower_func.input_signature)

         with tf.variable_scope(tf.get_variable_scope(), reuse=True):
             for idx, dev in enumerate(self._devices):
                 vs_name = self.trainer._vs_name_for_predictor(idx)
tensorpack/callbacks/param.py

@@ -8,8 +8,8 @@ import numpy as np
 from abc import ABCMeta, abstractmethod
 from collections import deque
 import six
-import tensorflow as tf

+from ..compat import tfv1
 from ..tfutils.common import get_op_tensor_name
 from ..utils import logger
 from .base import Callback

@@ -67,7 +67,7 @@ class GraphVarParam(HyperParam):
     def setup_graph(self):
         """ Will setup the assign operator for that variable. """
-        all_vars = tf.global_variables() + tf.local_variables()
+        all_vars = tfv1.global_variables() + tfv1.local_variables()
         for v in all_vars:
             if v.name == self.var_name:
                 self.var = v
tensorpack/contrib/keras.py

@@ -141,7 +141,7 @@ class KerasPhaseCallback(Callback):
 def setup_keras_trainer(
-        trainer, get_model, inputs_desc, targets_desc,
+        trainer, get_model, input_signature, target_signature,
         input, optimizer, loss, metrics):
     """
     Args:

@@ -159,7 +159,7 @@ def setup_keras_trainer(
     assert isinstance(metrics, list), metrics
     model_caller = KerasModelCaller(get_model)
-    nr_inputs = len(inputs_desc)
+    nr_inputs = len(input_signature)

     def get_cost(*inputs):
         ctx = get_current_tower_context()

@@ -211,7 +211,7 @@ def setup_keras_trainer(
         return total_loss

     trainer.setup_graph(
-        inputs_desc + targets_desc,
+        input_signature + target_signature,
         input,
         get_cost,
         lambda: optimizer)

@@ -221,23 +221,27 @@ def setup_keras_trainer(
 class KerasModel(object):
-    def __init__(self, get_model, inputs_desc, targets_desc, input, trainer=None):
+    def __init__(self, get_model, input_signature=None, target_signature=None,
+                 input=None, trainer=None, inputs_desc=None, targets_desc=None):
         """
         Args:
             get_model (input1, input2, ... -> keras.Model):
                 A function which takes tensors, builds and returns a Keras model.
                 It will be part of the tower function.
-            inputs_desc ([InputDesc]):
-            targets_desc ([InputDesc]):
-            input (InputSource | DataFlow):
+            input_signature ([tf.TensorSpec]): required. The signature for inputs.
+            target_signature ([tf.TensorSpec]): required. The signature for the targets tensors.
+            input (InputSource | DataFlow): the InputSource or DataFlow where the input data comes from.
             trainer (Trainer): the default will check the number of available
                 GPUs and use them all.
+            inputs_desc, targets_desc: deprecated names for `input_signature` and `target_signature`
         """
+        if inputs_desc is not None:
+            input_signature = inputs_desc
+        if targets_desc is not None:
+            target_signature = targets_desc
         self.get_model = get_model
         assert callable(get_model), get_model
-        self.inputs_desc = inputs_desc
-        self.targets_desc = targets_desc
+        self.input_signature = input_signature
+        self.target_signature = target_signature
         if trainer is None:
             nr_gpu = get_nr_gpu()
             if nr_gpu <= 1:

@@ -248,6 +252,7 @@ class KerasModel(object):
             assert isinstance(trainer, Trainer), trainer
             assert not isinstance(trainer, DistributedTrainerBase)

+        assert input is not None, "Argument 'input' is required!"
         self.input = apply_default_prefetch(input, trainer)
         self.trainer = trainer

@@ -267,7 +272,8 @@ class KerasModel(object):
         self._stats_to_inference = loss + metrics + [TOTAL_LOSS_NAME]
         setup_keras_trainer(
             self.trainer, get_model=self.get_model,
-            inputs_desc=self.inputs_desc,
-            targets_desc=self.targets_desc,
+            input_signature=self.input_signature,
+            target_signature=self.target_signature,
             input=self.input,
             optimizer=optimizer,
             loss=loss,
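The keyword rename in `KerasModel` keeps the old spellings working by forwarding them onto the new ones. A minimal runnable sketch of that forwarding pattern (the helper below is a hypothetical stand-in, not tensorpack API):

```python
import tensorflow as tf

def resolve_signatures(input_signature=None, target_signature=None,
                       inputs_desc=None, targets_desc=None):
    # Hypothetical stand-in mirroring KerasModel.__init__'s compatibility
    # logic: deprecated keyword names are forwarded onto the new ones.
    if inputs_desc is not None:
        input_signature = inputs_desc
    if targets_desc is not None:
        target_signature = targets_desc
    assert input_signature is not None and target_signature is not None
    return input_signature, target_signature

sig = [tf.TensorSpec([None, 10], tf.float32, 'x')]
tgt = [tf.TensorSpec([None, 1], tf.float32, 'y')]
# Both spellings resolve to the same signatures:
assert resolve_signatures(input_signature=sig, target_signature=tgt) == \
    resolve_signatures(inputs_desc=sig, targets_desc=tgt)
```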
tensorpack/graph_builder/model_desc.py

@@ -18,12 +18,41 @@ TensorSpec = backport_tensor_spec()

 __all__ = ['InputDesc', 'ModelDesc', 'ModelDescBase']


+def build_or_reuse_placeholder(tensor_spec):
+    """
+    Build a tf.placeholder from the metadata in the given tensor spec, or return an existing one.
+
+    Args:
+        tensor_spec (tf.TensorSpec):
+
+    Returns:
+        tf.Tensor:
+    """
+    g = tfv1.get_default_graph()
+    name = tensor_spec.name
+    try:
+        tensor = g.get_tensor_by_name(name + ':0')
+        assert "Placeholder" in tensor.op.type, "Tensor {} exists but is not a placeholder!".format(name)
+        assert tensor_spec.is_compatible_with(tensor), \
+            "Tensor {} exists but is not compatible with the signature!".format(tensor)
+        return tensor
+    except KeyError:
+        with tfv1.name_scope(None):   # clear any name scope it might get called in
+            ret = tfv1.placeholder(
+                tensor_spec.dtype, shape=tensor_spec.shape, name=tensor_spec.name)
+        return ret


 class InputDesc(
         namedtuple('InputDescTuple', ['type', 'shape', 'name'])):
     """
-    Metadata about an input entry point to the graph.
-    This metadata can be later used to build placeholders or other types of input source.
+    An equivalent of `tf.TensorSpec`.
+
+    History: this concept is used to represent metadata about the inputs,
+    which can be later used to build placeholders or other types of input source.
+    It is introduced much much earlier than the equivalent concept `tf.TensorSpec`
+    was introduced in TensorFlow.
+    Therefore, we now switched to use `tf.TensorSpec`, but keep this here for compatibility reasons.
     """

     def __new__(cls, type, shape, name):

@@ -33,64 +62,9 @@ class InputDesc(
             shape (tuple):
             name (str):
         """
-        shape = tuple(shape)    # has to be tuple for "self" to be hashable
+        # TODO mark deprecated
         assert isinstance(type, tf.DType), type
-        if any(k in name for k in [':', '/', ' ']):
-            raise ValueError("Invalid InputDesc name: '{}'".format(name))
-        self = super(InputDesc, cls).__new__(cls, type, shape, name)
-        self._cached_placeholder = {}
-        return self
-
-    def _build_placeholder(self):
-        """
-        Build a tf.placeholder from the metadata.
-
-        Returns:
-            tf.Tensor:
-        """
-        with tfv1.name_scope(None):   # clear any name scope it might get called in
-            ret = tfv1.placeholder(
-                self.type, shape=self.shape, name=self.name)
-        self._register_cached_placeholder(ret)
-        return ret
-
-    # cannot memoize here, because InputDesc is hashed by its fields.
-    def build_placeholder_reuse(self):
-        """
-        Build a tf.placeholder from the metadata, or return an old one.
-
-        Returns:
-            tf.Tensor:
-        """
-        g = tfv1.get_default_graph()
-        if g in self._cached_placeholder:
-            return self._cached_placeholder[g]
-        else:
-            return self._build_placeholder()
-
-    def _register_cached_placeholder(self, placeholder):
-        graph = placeholder.graph
-        assert graph not in self._cached_placeholder, \
-            "Placeholder for this InputDesc had been created before! This is a bug."
-        self._cached_placeholder[graph] = placeholder
-
-    @staticmethod
-    def _from_placeholder(placeholder):
-        name = placeholder.op.name
-        if name.endswith('_1') or name.endswith('_2'):
-            logger.error("Creating InputDesc from a placeholder named {}.".format(name))
-            logger.error("You might have mistakenly created this placeholder multiple times!")
-        ret = InputDesc(
-            placeholder.dtype,
-            tuple(placeholder.shape.as_list()),
-            name)
-        ret._register_cached_placeholder(placeholder)
-        return ret
-
-    @staticmethod
-    def _from_tensor_spec(spec):
-        assert spec.name is not None, "TensorSpec should have a name!"
-        return InputDesc(spec.dtype, tuple(spec.shape.as_list()), spec.name)
+        return tf.TensorSpec(shape=shape, dtype=type, name=name)


 class ModelDescBase(object):

@@ -100,29 +74,22 @@ class ModelDescBase(object):
     @memoized_method
     def get_inputs_desc(self):
+        # TODO mark deprecated
+        return self.get_input_signature()
+
+    @memoized_method
+    def get_input_signature(self):
         """
         Returns:
-            A list of :class:`InputDesc`, which describes the inputs of this model.
+            A list of :class:`tf.TensorSpec`, which describes the inputs of this model.
             The result is cached for each instance of :class:`ModelDescBase`.
         """
-        try:
-            ret = self._get_inputs()
-            log_deprecated(
-                "ModelDescBase._get_inputs() interface",
-                "Use inputs() instead!", "2019-03-30")
-            return ret
-        except NotImplementedError:
-            with tf.Graph().as_default() as G:   # create these placeholder in a temporary graph
-                inputs = self.inputs()
-                if isinstance(inputs[0], tf.Tensor):
-                    for p in inputs:
-                        assert p.graph == G, "Placeholders returned by inputs() should be created inside inputs()!"
-                    return [InputDesc._from_placeholder(p) for p in inputs]
-                else:
-                    for p in inputs:
-                        assert isinstance(p, TensorSpec), type(p)
-                    return [InputDesc._from_tensor_spec(p) for p in inputs]
+        with tf.Graph().as_default() as G:   # create these placeholder in a temporary graph
+            inputs = self.inputs()
+            if isinstance(inputs[0], tf.Tensor):
+                for p in inputs:
+                    assert p.graph == G, "Placeholders returned by inputs() should be created inside inputs()!"
+            return [TensorSpec(shape=p.shape, dtype=p.dtype, name=p.name) for p in inputs]

     @property
     def input_names(self):

@@ -130,7 +97,7 @@ class ModelDescBase(object):
         Returns:
             [str]: the names of all the inputs.
         """
-        return [k.name for k in self.get_inputs_desc()]
+        return [k.name for k in self.get_input_signature()]

     def _get_inputs(self):
         raise NotImplementedError()

@@ -147,7 +114,7 @@ class ModelDescBase(object):
         Also, you should never call this method by yourself.

         Returns:
-            list[tf.placeholder] or list[tf.TensorSpec], to be converted to :class:`InputDesc`.
+            list[tf.TensorSpec or tf.placeholder]. To be converted to :class:`tf.TensorSpec`.
         """
         raise NotImplementedError()

@@ -166,9 +133,9 @@ class ModelDescBase(object):
         may require it to return necessary information to build the trainer.
         For example, `SingleCostTrainer` expect this method to return the cost tensor.
         """
-        assert len(args) == len(self.get_inputs_desc()), \
+        assert len(args) == len(self.get_input_signature()), \
             "Number of inputs passed to the graph != number of inputs defined " \
-            "in ModelDesc! ({} != {})".format(len(args), len(self.get_inputs_desc()))
+            "in ModelDesc! ({} != {})".format(len(args), len(self.get_input_signature()))
         log_deprecated(
             "ModelDescBase._build_graph() interface",
             "Use build_graph() instead!",
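With `get_input_signature()` in place, a `ModelDesc.inputs()` can return `tf.TensorSpec`s directly (placeholders created inside `inputs()` are still accepted and converted). A minimal sketch, with a hypothetical model and tensor names:

```python
import tensorflow as tf
from tensorpack import ModelDesc

class MyModel(ModelDesc):
    """Hypothetical minimal model illustrating the new signature style."""

    def inputs(self):
        # Preferred: return specs instead of creating placeholders.
        return [tf.TensorSpec((None, 28, 28, 1), tf.float32, 'image'),
                tf.TensorSpec((None,), tf.int32, 'label')]

    def build_graph(self, image, label):
        ...  # build the tower; return the cost tensor for single-cost trainers

    def optimizer(self):
        return tf.train.AdamOptimizer(1e-3)

# MyModel().get_input_signature() returns the two tf.TensorSpec entries,
# and MyModel().input_names == ['image', 'label'].
```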
tensorpack/input_source/input_source.py

@@ -19,6 +19,7 @@ from ..tfutils.tower import get_current_tower_context
 from ..utils import logger
 from ..utils.concurrency import ShareSessionThread
 from .input_source_base import InputSource
+from ..graph_builder.model_desc import build_or_reuse_placeholder

 try:
     from tensorflow.python.ops.data_flow_ops import StagingArea

@@ -59,7 +60,7 @@ class PlaceholderInput(InputSource):
         pass

     def _setup(self, inputs):
-        self._all_placehdrs = [v.build_placeholder_reuse() for v in inputs]
+        self._all_placehdrs = [build_or_reuse_placeholder(v) for v in inputs]

     def _get_input_tensors(self):
         return self._all_placehdrs

@@ -110,7 +111,7 @@ class FeedInput(InputSource):
     def _setup(self, inputs):
         # placeholders as input are always safe to reuse.
-        self._all_placehdrs = [v.build_placeholder_reuse() for v in inputs]
+        self._all_placehdrs = [build_or_reuse_placeholder(v) for v in inputs]
         self._cb = self._FeedCallback(self._iter_ds, self._all_placehdrs)

     def _get_input_tensors(self):

@@ -196,7 +197,7 @@ class QueueInput(FeedfreeInput):
         Args:
             ds(DataFlow): the input DataFlow.
             queue (tf.QueueBase): A :class:`tf.QueueBase` whose type
-                should match the corresponding InputDesc of the model.
+                should match the corresponding input signature of the model.
                 Defaults to a FIFO queue of size 50.
         """
         if not isinstance(ds, DataFlow):

@@ -210,12 +211,12 @@ class QueueInput(FeedfreeInput):
         return len(self.ds)

     def _setup(self, inputs):
-        self._input_placehdrs = [v.build_placeholder_reuse() for v in inputs]
+        self._input_placehdrs = [build_or_reuse_placeholder(v) for v in inputs]
         assert len(self._input_placehdrs) > 0, \
             "QueueInput has to be used with some inputs!"
         with self.cached_name_scope():
             if self.queue is None:
-                self.queue = tf.FIFOQueue(
+                self.queue = tfv1.FIFOQueue(
                     50, [x.dtype for x in self._input_placehdrs],
                     name='input_queue')
             logger.info("Setting up the queue '{}' for CPU prefetching ...".format(self.queue.name))

@@ -287,7 +288,7 @@ class BatchQueueInput(QueueInput):
             ds(DataFlow): the input DataFlow.
             batch_size(int): the batch size.
             queue (tf.QueueBase): A :class:`tf.QueueBase` whose type
-                should match the corresponding InputDesc of the model.
+                should match the corresponding input signature of the model.
                 Defaults to a FIFO queue of size 3000.
         """
         super(BatchQueueInput, self).__init__(ds, queue)

@@ -298,9 +299,9 @@ class BatchQueueInput(QueueInput):
     def _setup(self, inputs):
         logger.info("Setting up the queue for CPU prefetching ...")
-        self.input_placehdrs = [v.build_placeholder_reuse() for v in inputs]
+        self.input_placehdrs = [build_or_reuse_placeholder(v) for v in inputs]
         assert len(self.input_placehdrs) > 0, \
-            "BatchQueueInput has to be used with some InputDesc!"
+            "BatchQueueInput has to be used with some input signature!"

         # prepare placeholders without the first dimension
         placehdrs_nobatch = []

@@ -364,8 +365,8 @@ class TensorInput(FeedfreeInput):
             assert size > 0
         self._fixed_size = size

-    def _setup(self, inputs_desc):
-        self._desc = inputs_desc
+    def _setup(self, input_signature):
+        self._spec = input_signature

     def _size(self):
         if self._fixed_size is None:

@@ -376,8 +377,8 @@ class TensorInput(FeedfreeInput):
         with self.cached_name_scope():
             ret = self.get_tensor_fn()
         assert isinstance(ret, (list, tuple)), "get_tensor_fn needs to return a list!"
-        assert len(ret) == len(self._desc), \
-            "get_tensor_fn returns {} tensors but there are {} inputs".format(len(ret), len(self._desc))
+        assert len(ret) == len(self._spec), \
+            "get_tensor_fn returns {} tensors but there are {} inputs".format(len(ret), len(self._spec))
         return ret

@@ -399,7 +400,7 @@ class DummyConstantInput(TensorInput):
             assert len(self.shapes) == len(self._desc)
             for idx, p in enumerate(self._desc):
                 tlist.append(tf.constant(
-                    0, dtype=p.type,
+                    0, dtype=p.dtype,
                     name='dummy-{}-{}'.format(p.name, ctx.index),
                     shape=self.shapes[idx]))
             return tlist

@@ -429,15 +430,14 @@ class ZMQInput(TensorInput):
             return ret
         super(ZMQInput, self).__init__(fn)

-    def _setup(self, inputs_desc):
-        assert len(inputs_desc) > 0, \
-            "ZMQInput has to be used with InputDesc!"
-        self._desc = inputs_desc
+    def _setup(self, input_signature):
+        assert len(input_signature) > 0, \
+            "ZMQInput has to be used with input signature!"

         import zmq_ops
         self._zmq_pull_socket = zmq_ops.ZMQPullSocket(
             self._end_point,
-            [x.type for x in inputs_desc],
+            [x.dtype for x in input_signature],
             hwm=self._hwm,
             bind=self._bind)

@@ -458,23 +458,23 @@ class TFDatasetInput(FeedfreeInput):
             raise ValueError("TFDatasetInput takes a tf.data.Dataset! Got {}".format(dataset))
         self._dataset = dataset

-    def _setup(self, inputs_desc):
-        self._desc = inputs_desc
+    def _setup(self, input_signature):
+        self._spec = input_signature
         types = self._dataset.output_types
-        desc_types = tuple([k.type for k in inputs_desc])
-        assert len(types) == len(desc_types), \
-            "Dataset and InputDesc has different length! {} != {}".format(len(types), len(desc_types))
-        assert types == desc_types, \
-            "Types of dataset and InputDesc don't match! {} != {}".format(str(types), str(desc_types))
+        spec_types = tuple([k.dtype for k in input_signature])
+        assert len(types) == len(spec_types), \
+            "Dataset and input signature have different length! {} != {}".format(len(types), len(spec_types))
+        assert types == spec_types, \
+            "Data types of dataset and input signature don't match! {} != {}".format(str(types), str(spec_types))
         shapes = self._dataset.output_shapes
-        desc_shapes = [k.shape for k in inputs_desc]
-        for idx, (s1, s2) in enumerate(zip(shapes, desc_shapes)):
+        spec_shapes = [k.shape for k in input_signature]
+        for idx, (s1, s2) in enumerate(zip(shapes, spec_shapes)):
             s2 = tf.TensorShape(s2)
             assert s2.is_compatible_with(s1), \
-                "InputDesc '{}' has incompatible shape with dataset! {} vs {}".format(inputs_desc[idx].name, s2, s1)
+                "Input signature '{}' has incompatible shape with dataset! {} vs {}".format(input_signature[idx].name, s2, s1)
         self._iterator = self._dataset.make_initializable_iterator()
         self._init_op = self._iterator.initializer

@@ -482,11 +482,11 @@ class TFDatasetInput(FeedfreeInput):
         self._init_op.run()

     def _get_input_tensors(self):
-        desc_shapes = [k.shape for k in self._desc]
+        spec_shapes = [k.shape for k in self._spec]
         ret = self._iterator.get_next()
-        assert len(ret) == len(desc_shapes), \
-            "Dataset returns {} tensors but there are {} inputs!".format(len(ret), len(desc_shapes))
-        for t, shp in zip(ret, desc_shapes):
+        assert len(ret) == len(spec_shapes), \
+            "Dataset returns {} tensors but there are {} inputs!".format(len(ret), len(spec_shapes))
+        for t, shp in zip(ret, spec_shapes):
             t.set_shape(shp)
         return ret
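All the `_setup` implementations above now go through the new module-level helper instead of the per-`InputDesc` placeholder cache. A quick sketch of its reuse behavior (the spec name is hypothetical):

```python
import tensorflow as tf
from tensorpack.graph_builder.model_desc import build_or_reuse_placeholder

spec = tf.TensorSpec((None, 4), tf.float32, 'feature')
with tf.Graph().as_default():
    p1 = build_or_reuse_placeholder(spec)   # creates placeholder 'feature:0'
    p2 = build_or_reuse_placeholder(spec)   # looks it up by name and reuses it
    assert p1 is p2
```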
tensorpack/input_source/input_source_base.py

@@ -12,6 +12,7 @@ from ..callbacks.base import CallbackFactory
 from ..tfutils.common import get_op_tensor_name
 from ..utils import logger
 from ..utils.argtools import call_only_once, memoized_method
+from ..graph_builder.model_desc import build_or_reuse_placeholder

 __all__ = ['InputSource', 'remap_input_source']

@@ -86,20 +87,20 @@ class InputSource(object):
         pass

     @call_only_once
-    def setup(self, inputs_desc):
+    def setup(self, input_signature):
         """
         Args:
-            inputs_desc (list[InputDesc]): list of input desc
+            input_signature (list[tf.TensorSpec]): list of specs for each input tensor

         Returns:
             list[Callback]: extra callbacks needed by this InputSource.
             callbacks of InputSource cannot use any `trigger*()` method.
         """
-        self._setup(inputs_desc)
+        self._setup(input_signature)
         self._setup_done = True
         return self.get_callbacks()

-    def _setup(self, inputs_desc):
+    def _setup(self, input_signature):
         pass

     def setup_done(self):

@@ -190,8 +191,8 @@ class ProxyInputSource(InputSource):
     def _get_input_tensors(self):
         return self._input.get_input_tensors()

-    def _setup(self, inputs_desc):
-        self._input.setup(inputs_desc)
+    def _setup(self, input_signature):
+        self._input.setup(input_signature)

     def _get_callbacks(self):
         return self._input.get_callbacks()

@@ -226,11 +227,11 @@ def remap_input_source(input, names):
         input1 = QueueInput(ds)
         # assume ds produces 'image' and 'label', but the graph takes more
         # inputs for some reasons, or takes inputs of a different order:
-        inputs_desc = [InputDesc(tf.float32, (None,10), 'score'),
-                       InputDesc(tf.float32, (None,20,20,3), 'label'),
-                       InputDesc(tf.int32, (None,), 'image') ]
+        input_signature = [tf.TensorSpec((None,10), tf.float32, 'score'),
+                           tf.TensorSpec((None,20,20,3), tf.float32, 'label'),
+                           tf.TensorSpec((None,), tf.int32, 'image') ]
         input2 = remap_input_source(input1, ['image', 'label'])
-        input2.setup(inputs_desc)
+        input2.setup(input_signature)
         # now, input2.get_input_tensors() will return a placeholder for 'score',
         # plus the tensors returned by input1.get_input_tensors()
     """

@@ -240,7 +241,7 @@ def remap_input_source(input, names):
             self._names = tuple(names)

         def _setup(self, inputs):
-            self._all_placehdrs = [v.build_placeholder_reuse() for v in inputs]
+            self._all_placehdrs = [build_or_reuse_placeholder(v) for v in inputs]
             inputs_subset = get_sublist_by_names(inputs, self._names)
             self._input.setup(inputs_subset)
tensorpack/predict/base.py

@@ -155,7 +155,7 @@ class OfflinePredictor(OnlinePredictor):
         self.graph = config._maybe_create_graph()
         with self.graph.as_default():
             input = PlaceholderInput()
-            input.setup(config.inputs_desc)
+            input.setup(config.input_signature)
             with PredictTowerContext(''):
                 config.tower_func(*input.get_input_tensors())
tensorpack/predict/config.py

@@ -18,7 +18,7 @@ class PredictConfig(object):
     def __init__(self,
                  model=None,
                  tower_func=None,
-                 inputs_desc=None,
+                 input_signature=None,
                  input_names=None,
                  output_names=None,

@@ -27,11 +27,18 @@ class PredictConfig(object):
                  session_init=None,
                  return_input=False,
-                 create_graph=True):
+                 create_graph=True,
+                 inputs_desc=None):
         """
-        You need to set either `model`, or `inputs_desc` plus `tower_func`.
-        They are needed to construct the graph.
-        You'll also have to set `output_names` as it does not have a default.
+        Users need to provide enough arguments to create a tower function,
+        which will be used to construct the graph.
+        This can be provided in the following ways:
+
+        1. `model`: a :class:`ModelDesc` instance. It will contain a tower function by itself.
+        2. `tower_func`: a :class:`tfutils.TowerFuncWrapper` instance.
+           Provide a tower function instance directly.
+        3. `tower_func`: a symbolic function and `input_signature`: the signature of the function.
+           Provide both a function and its signature.

         Example:

@@ -42,15 +49,14 @@ class PredictConfig(object):
               output_names=['linear/output', 'prediction'])

         Args:
-            model (ModelDescBase): to be used to obtain inputs_desc and tower_func.
+            model (ModelDescBase): to be used to construct a tower function.
             tower_func: a callable which takes input tensors (by positional args) and construct a tower.
-                or a :class:`tfutils.TowerFuncWrapper` instance, which packs both `inputs_desc` and function together.
-            inputs_desc ([InputDesc]): if tower_func is a plain function (instead of a TowerFuncWrapper), this describes
-                the list of inputs it takes.
-            input_names (list): a list of input tensor names. Defaults to match inputs_desc.
-                The name can be either the name of a tensor, or the name of one input defined
-                by `inputs_desc` or by `model`.
+                or a :class:`tfutils.TowerFuncWrapper` instance.
+            input_signature ([tf.TensorSpec]): if tower_func is a plain function (instead of a TowerFuncWrapper),
+                this describes the list of inputs it takes.
+            input_names (list): a list of input tensor names. Defaults to match input_signature.
+                The name can be either the name of a tensor, or the name of one input of the tower.
             output_names (list): a list of names of the output tensors to predict, the
                 tensors can be any tensor in the graph that's computable from the tensors correponding to `input_names`.

@@ -62,23 +68,29 @@ class PredictConfig(object):
             return_input (bool): same as in :attr:`PredictorBase.return_input`.
             create_graph (bool): create a new graph, or use the default graph
                 when predictor is first initialized.
+            inputs_desc (list[tf.TensorSpec]): old (deprecated) name for `input_signature`.
         """

         def assert_type(v, tp, name):
             assert isinstance(v, tp), \
-                "{} has to be type '{}', but an object of type '{}' found.".format(
+                "Argument '{}' has to be type '{}', but an object of type '{}' found.".format(
                     name, tp.__name__, v.__class__.__name__)

+        if inputs_desc is not None:
+            # TODO warn deprecated or not?
+            assert input_signature is None, "Cannot set both inputs_desc and input_signature!"
+            input_signature = inputs_desc
+
         if model is not None:
             assert_type(model, ModelDescBase, 'model')
-            assert inputs_desc is None and tower_func is None
-            self.inputs_desc = model.get_inputs_desc()
-            self.tower_func = TowerFuncWrapper(model.build_graph, self.inputs_desc)
+            assert input_signature is None and tower_func is None
+            self.input_signature = model.get_input_signature()
+            self.tower_func = TowerFuncWrapper(model.build_graph, self.input_signature)
         else:
             if isinstance(tower_func, TowerFuncWrapper):
-                inputs_desc = tower_func.inputs_desc
-            assert inputs_desc is not None and tower_func is not None
-            self.inputs_desc = inputs_desc
-            self.tower_func = TowerFuncWrapper(tower_func, inputs_desc)
+                input_signature = tower_func.input_signature
+            assert input_signature is not None and tower_func is not None
+            self.input_signature = input_signature
+            self.tower_func = TowerFuncWrapper(tower_func, input_signature)

         if session_init is None:
             session_init = JustCurrentSession()

@@ -93,20 +105,22 @@ class PredictConfig(object):
         # inputs & outputs
         self.input_names = input_names
         if self.input_names is None:
-            self.input_names = [k.name for k in self.inputs_desc]
+            self.input_names = [k.name for k in self.input_signature]
+        assert output_names is not None, "Argument 'output_names' is not provided!"
         self.output_names = output_names
         assert_type(self.output_names, list, 'output_names')
         assert_type(self.input_names, list, 'input_names')
         if len(self.input_names) == 0:
             logger.warn('PredictConfig receives empty "input_names".')
-        # assert len(self.input_names), self.input_names
         for v in self.input_names:
             assert_type(v, six.string_types, 'Each item in input_names')
-        assert len(self.output_names), self.output_names
+        assert len(self.output_names), "Argument 'output_names' cannot be empty!"

         self.return_input = bool(return_input)
         self.create_graph = bool(create_graph)
+        self.inputs_desc = input_signature  # TODO a little bit of compatibility

     def _maybe_create_graph(self):
         if self.create_graph:
             return tf.Graph()
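Of the three construction modes above, the third one now reads as follows; a minimal sketch with a hypothetical tower function and tensor names:

```python
import tensorflow as tf
from tensorpack import PredictConfig, OfflinePredictor

def my_tower(image):
    # Hypothetical tower; the names 'linear' and 'prediction' are illustrative.
    logits = tf.layers.dense(tf.layers.flatten(image), 10, name='linear')
    tf.nn.softmax(logits, name='prediction')

config = PredictConfig(
    tower_func=my_tower,
    input_signature=[tf.TensorSpec((None, 28, 28, 1), tf.float32, 'image')],
    input_names=['image'],
    output_names=['prediction'])
predictor = OfflinePredictor(config)
```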
tensorpack/predict/feedfree.py

@@ -21,7 +21,7 @@ class FeedfreePredictor(PredictorBase):
         Args:
             config (PredictConfig): the config to use.
             input_source (InputSource): the feedfree InputSource to use.
-                Must match the inputs_desc in config.
+                Must match the signature of the tower function in config.
         """
         self._config = config
         self._input_source = input_source

@@ -33,7 +33,7 @@ class FeedfreePredictor(PredictorBase):
         self.graph = config._maybe_create_graph()
         with self.graph.as_default():
             self._input_callbacks = Callbacks(
-                self._input_source.setup(config.inputs_desc))
+                self._input_source.setup(config.input_signature))
             with PredictTowerContext(''):
                 self._input_tensors = self._input_source.get_input_tensors()
                 config.tower_func(*self._input_tensors)
tensorpack/predict/multigpu.py

@@ -4,7 +4,6 @@
 import tensorflow as tf

-from ..graph_builder.model_desc import InputDesc
 from ..input_source import PlaceholderInput
 from ..tfutils.tower import PredictTowerContext
 from ..utils import logger

@@ -33,7 +32,7 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
         handles = []

         input = PlaceholderInput()
-        input.setup(config.inputs_desc)
+        input.setup(config.input_signature)

         for idx, t in enumerate(towers):
             tower_name = 'tower' + str(t)

@@ -102,10 +101,10 @@ class DataParallelOfflinePredictor(OnlinePredictor):
             for idx, t in enumerate(towers):
                 tower_name = 'tower' + str(t)
-                inputs_desc = [InputDesc(desc.type, desc.shape, tower_name + '_' + desc.name)
-                               for desc in config.inputs_desc]
+                new_sig = [tf.TensorSpec(dtype=p.dtype, shape=p.shape, name=tower_name + '_' + p.name)
+                           for p in config.input_signature]
                 input = PlaceholderInput()
-                input.setup(inputs_desc)
+                input.setup(new_sig)

                 with tf.variable_scope(tf.get_variable_scope(), reuse=idx > 0), \
                         tf.device('/gpu:{}'.format(t)), \
tensorpack/tfutils/export.py

@@ -28,7 +28,7 @@ class ModelExporter(object):
         Args:
             config (PredictConfig): the config to use.
-                The graph will be built with `config.tower_func` and `config.inputs_desc`.
+                The graph will be built with the tower function defined by this `PredictConfig`.
                 Then the input / output names will be used to export models for inference.
         """
         super(ModelExporter, self).__init__()

@@ -51,7 +51,7 @@ class ModelExporter(object):
         self.graph = self.config._maybe_create_graph()
         with self.graph.as_default():
             input = PlaceholderInput()
-            input.setup(self.config.inputs_desc)
+            input.setup(self.config.input_signature)
             with PredictTowerContext(''):
                 self.config.tower_func(*input.get_input_tensors())

@@ -116,7 +116,7 @@ class ModelExporter(object):
         self.graph = self.config._maybe_create_graph()
         with self.graph.as_default():
             input = PlaceholderInput()
-            input.setup(self.config.inputs_desc)
+            input.setup(self.config.input_signature)
             with PredictTowerContext(''):
                 self.config.tower_func(*input.get_input_tensors())
tensorpack/tfutils/tower.py

@@ -257,24 +257,27 @@ class TowerFuncWrapper(object):
     Conceptually, this class is roughly equivalent to `tf.function` with input signature, introduced in TF 2.0.
     """

-    def __init__(self, tower_fn, inputs_desc):
+    def __init__(self, tower_fn, input_signature):
         """
         Args:
             tower_func: a function which builds one tower in the graph.
                 It takes several input tensors and could return anything.
-            inputs_desc ([InputDesc]): list of :class:`InputDesc`.
+            input_signature ([TensorSpec]): list of :class:`tf.TensorSpec`.
                 They are used to figure out the names for the input tensors.
         """
         assert callable(tower_fn), tower_fn
-        self._inputs_desc_names = [k.name for k in inputs_desc]
-        assert len(set(self._inputs_desc_names)) == len(self._inputs_desc_names), \
-            "Duplicated names in inputs_desc! " + str(self._inputs_desc_names)
+        self._inputs_names = [k.name for k in input_signature]
+        assert len(set(self._inputs_names)) == len(self._inputs_names), \
+            "Duplicated names in input_signature! " + str(self._inputs_names)
+        for name in self._inputs_names:
+            if any(k in name for k in [':', '/', ' ']):
+                raise ValueError("Invalid input name: '{}'".format(name))
         self._tower_fn = tower_fn
-        self._inputs_desc = inputs_desc
+        self._input_signature = input_signature

         self._handles = []

-    def __new__(cls, tower_fn, inputs_desc):
+    def __new__(cls, tower_fn, _):
         # to avoid double-wrapping a function
         if isinstance(tower_fn, TowerFuncWrapper):
             return tower_fn

@@ -285,7 +288,7 @@ class TowerFuncWrapper(object):
         ctx = get_current_tower_context()
         assert ctx is not None, "Function must be called under TowerContext!"
         output = self._tower_fn(*args)
-        handle = TowerTensorHandle(ctx, args, output, self._inputs_desc)
+        handle = TowerTensorHandle(ctx, args, output, self._input_signature)
         self._handles.append(handle)
         return output

@@ -298,9 +301,14 @@ class TowerFuncWrapper(object):
         """
         return TowerTensorHandles(self._handles)

+    @property
+    def input_signature(self):
+        return self._input_signature
+
     @property
     def inputs_desc(self):
-        return self._inputs_desc
+        # TODO mark deprecated
+        return self._input_signature


 class TowerTensorHandles(object):

@@ -354,14 +362,14 @@ class TowerTensorHandle(object):
     """

     @HIDE_DOC
-    def __init__(self, ctx, input, output, inputs_desc=None):
+    def __init__(self, ctx, input, output, input_signature=None):
         self._ctx = ctx

         self._extra_tensor_names = {}
-        if inputs_desc is not None:
-            assert len(inputs_desc) == len(input)
+        if input_signature is not None:
+            assert len(input_signature) == len(input)
             self._extra_tensor_names = {
-                get_op_tensor_name(x.name)[1]: y for x, y in zip(inputs_desc, input)}
+                get_op_tensor_name(x.name)[1]: y for x, y in zip(input_signature, input)}
         self._input = input
         self._output = output

@@ -379,7 +387,7 @@ class TowerTensorHandle(object):
         1. The name of the tensor without any tower prefix.

-        2. The name of an :class:`InputDesc`, if it is used when building the tower.
+        2. A name in the input signature, if it is used when building the tower.

         In the second case, this method will return the tensor that's used as the corresponding
         input to the tower. Note that this tensor may have a different name (e.g. may be an output of a queue).
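The renamed constructor in use; a minimal sketch with a hypothetical one-tensor tower function:

```python
import tensorflow as tf
from tensorpack.tfutils.tower import TowerFuncWrapper, PredictTowerContext

def my_tower(x):
    return tf.identity(x, name='out')

# input_signature (formerly inputs_desc) supplies names for the tower inputs:
func = TowerFuncWrapper(my_tower, [tf.TensorSpec((None, 3), tf.float32, 'x')])
with PredictTowerContext(''):
    func(tf.placeholder(tf.float32, (None, 3), name='x'))

print(func.input_signature)   # new property; .inputs_desc remains as an alias
```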
tensorpack/train/interface.py

@@ -87,7 +87,7 @@ def launch_train_with_config(config, trainer):
     # We should gradually stay away from this unuseful abstraction.
     # TowerFuncWrapper is a better abstraction (similar to tf.defun in the future)
     trainer.setup_graph(
-        model.get_inputs_desc(), input,
+        model.get_input_signature(), input,
         model._build_graph_get_cost, model.get_optimizer)
     _check_unused_regularization()
     trainer.train_with_defaults(
tensorpack/train/tower.py

@@ -56,11 +56,16 @@ class TowerTrainer(Trainer):
     @property
     def inputs_desc(self):
+        # TODO mark deprecated
+        return self.input_signature
+
+    @property
+    def input_signature(self):
         """
         Returns:
-            list[InputDesc]: metainfo about the inputs to the tower.
+            list[tf.TensorSpec]: metainfo about the inputs to the tower.
         """
-        return self.tower_func.inputs_desc
+        return self.tower_func.input_signature

     @property
     def towers(self):

@@ -124,7 +129,7 @@ class TowerTrainer(Trainer):
         if tower is None:
             input = PlaceholderInput()
-            input.setup(self.inputs_desc)
+            input.setup(self.input_signature)
             vs_name = self._vs_name_for_predictor(device_id)
             with tfv1.variable_scope(tfv1.get_variable_scope(), reuse=True), \

@@ -164,7 +169,7 @@ class SingleCostTrainer(TowerTrainer):
     Base class for single-cost trainer.

     Single-cost trainer has a :meth:`setup_graph` method which takes
-    (inputs_desc, input, get_cost_fn, get_opt_fn), and build the training graph from them.
+    (input_signature, input, get_cost_fn, get_opt_fn), and build the training graph from them.

     To use a :class:`SingleCostTrainer` object, call `trainer.setup_graph(...); trainer.train(...)`.
     """

@@ -194,12 +199,12 @@ class SingleCostTrainer(TowerTrainer):
     """

     @call_only_once
-    def setup_graph(self, inputs_desc, input, get_cost_fn, get_opt_fn):
+    def setup_graph(self, input_signature, input, get_cost_fn, get_opt_fn):
         """
         Responsible for building the main training graph for single-cost training.

         Args:
-            inputs_desc ([InputDesc]):
+            input_signature ([TensorSpec]): list of TensorSpec that describe the inputs
             input (InputSource):
             get_cost_fn ([tf.Tensor] -> tf.Tensor): callable, takes some input tensors and return a cost tensor.
             get_opt_fn (-> tf.train.Optimizer): callable which returns an

@@ -210,12 +215,12 @@ class SingleCostTrainer(TowerTrainer):
         It must follows the `rules of tower function.
         <http://tensorpack.readthedocs.io/tutorial/trainer.html#tower-trainer>`_.
         """
-        get_cost_fn = TowerFuncWrapper(get_cost_fn, inputs_desc)
+        get_cost_fn = TowerFuncWrapper(get_cost_fn, input_signature)
         get_opt_fn = memoized(get_opt_fn)
         self.tower_func = get_cost_fn

         # TODO setup may want to register monitor as well??
-        input_callbacks = self._setup_input(inputs_desc, input)
+        input_callbacks = self._setup_input(input_signature, input)
         train_callbacks = self._setup_graph(input, get_cost_fn, get_opt_fn)
         self.register_callback(input_callbacks + train_callbacks)

@@ -229,9 +234,9 @@ class SingleCostTrainer(TowerTrainer):
             [Callback]: list of callbacks needed
         """

-    def _setup_input(self, inputs_desc, input):
+    def _setup_input(self, input_signature, input):
         assert not input.setup_done()
-        return input.setup(inputs_desc)
+        return input.setup(input_signature)

     def _make_get_grad_fn(self, input, get_cost_fn, get_opt_fn):
         """
View file @
b8a50d72
...
...
@@ -272,7 +272,7 @@ class DistributedTrainerReplicated(DistributedTrainerBase):
self
.
_builder
=
DistributedReplicatedBuilder
(
gpus
,
server
)
self
.
is_chief
=
self
.
_builder
.
is_chief
def
_setup_input
(
self
,
input
s_desc
,
input
):
def
_setup_input
(
self
,
input
_signature
,
input
):
with
override_to_local_variable
():
get_global_step_var
()
# gs should be local
# input source may create variable (queue size summary)
...
...
@@ -280,7 +280,7 @@ class DistributedTrainerReplicated(DistributedTrainerBase):
# whether something should be global or local. We now assume
# they should be local.
assert
not
input
.
setup_done
()
return
input
.
setup
(
input
s_desc
)
return
input
.
setup
(
input
_signature
)
def
_setup_graph
(
self
,
input
,
get_cost_fn
,
get_opt_fn
):
assert
isinstance
(
input
,
FeedfreeInput
),
input
...
...
tensorpack/utils/concurrency.py

@@ -132,8 +132,8 @@ class ShareSessionThread(threading.Thread):
             yield None

     def start(self):
-        import tensorflow as tf
-        self._sess = tf.get_default_session()
+        from ..compat import tfv1
+        self._sess = tfv1.get_default_session()
         super(ShareSessionThread, self).start()

     def run(self):