Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
17b34c69
Commit
17b34c69
authored
Sep 01, 2019
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Use SmartInit globally - a simpler interface to initialization
parent
cbd698ad
Changes
43
Hide whitespace changes
Inline
Side-by-side
Showing
43 changed files
with
85 additions
and
105 deletions
+85
-105
docs/tutorial/save-load.md
docs/tutorial/save-load.md
+5
-5
examples/A3C-Gym/train-atari.py
examples/A3C-Gym/train-atari.py
+2
-2
examples/CTC-TIMIT/train-timit.py
examples/CTC-TIMIT/train-timit.py
+1
-2
examples/CaffeModels/load-alexnet.py
examples/CaffeModels/load-alexnet.py
+1
-3
examples/CaffeModels/load-cpm.py
examples/CaffeModels/load-cpm.py
+1
-2
examples/CaffeModels/load-vgg16.py
examples/CaffeModels/load-vgg16.py
+1
-1
examples/CaffeModels/load-vgg19.py
examples/CaffeModels/load-vgg19.py
+1
-1
examples/Char-RNN/char-rnn.py
examples/Char-RNN/char-rnn.py
+2
-3
examples/DeepQNetwork/DQN.py
examples/DeepQNetwork/DQN.py
+2
-3
examples/DoReFa-Net/alexnet-dorefa.py
examples/DoReFa-Net/alexnet-dorefa.py
+4
-5
examples/DoReFa-Net/resnet-dorefa.py
examples/DoReFa-Net/resnet-dorefa.py
+2
-2
examples/DynamicFilterNetwork/steering-filter.py
examples/DynamicFilterNetwork/steering-filter.py
+1
-2
examples/FasterRCNN/predict.py
examples/FasterRCNN/predict.py
+3
-3
examples/FasterRCNN/train.py
examples/FasterRCNN/train.py
+2
-2
examples/GAN/BEGAN.py
examples/GAN/BEGAN.py
+1
-1
examples/GAN/ConditionalGAN-mnist.py
examples/GAN/ConditionalGAN-mnist.py
+2
-2
examples/GAN/CycleGAN.py
examples/GAN/CycleGAN.py
+1
-1
examples/GAN/DCGAN.py
examples/GAN/DCGAN.py
+2
-2
examples/GAN/DiscoGAN-CelebA.py
examples/GAN/DiscoGAN-CelebA.py
+1
-1
examples/GAN/Image2Image.py
examples/GAN/Image2Image.py
+2
-2
examples/GAN/Improved-WGAN.py
examples/GAN/Improved-WGAN.py
+1
-1
examples/GAN/InfoGAN-mnist.py
examples/GAN/InfoGAN-mnist.py
+2
-2
examples/GAN/WGAN.py
examples/GAN/WGAN.py
+1
-1
examples/HED/hed.py
examples/HED/hed.py
+2
-3
examples/ImageNetModels/imagenet_utils.py
examples/ImageNetModels/imagenet_utils.py
+1
-1
examples/ImageNetModels/inception-bn.py
examples/ImageNetModels/inception-bn.py
+1
-2
examples/ImageNetModels/shufflenet.py
examples/ImageNetModels/shufflenet.py
+3
-4
examples/OpticalFlow/flownet2.py
examples/OpticalFlow/flownet2.py
+2
-2
examples/PennTreebank/PTB-LSTM.py
examples/PennTreebank/PTB-LSTM.py
+1
-2
examples/ResNet/cifar10-preact18-mixup.py
examples/ResNet/cifar10-preact18-mixup.py
+1
-1
examples/ResNet/cifar10-resnet.py
examples/ResNet/cifar10-resnet.py
+1
-1
examples/ResNet/imagenet-resnet.py
examples/ResNet/imagenet-resnet.py
+3
-4
examples/ResNet/load-resnet.py
examples/ResNet/load-resnet.py
+2
-2
examples/Saliency/CAM-resnet.py
examples/Saliency/CAM-resnet.py
+2
-3
examples/Saliency/saliency-maps.py
examples/Saliency/saliency-maps.py
+1
-1
examples/SimilarityLearning/mnist-embeddings.py
examples/SimilarityLearning/mnist-embeddings.py
+3
-5
examples/SpatialTransformer/mnist-addition.py
examples/SpatialTransformer/mnist-addition.py
+2
-3
examples/SuperResolution/enet-pat.py
examples/SuperResolution/enet-pat.py
+3
-3
examples/basics/cifar-convnet.py
examples/basics/cifar-convnet.py
+1
-2
examples/basics/export-model.py
examples/basics/export-model.py
+4
-4
examples/basics/svhn-digit-convnet.py
examples/basics/svhn-digit-convnet.py
+1
-1
examples/boilerplate.py
examples/boilerplate.py
+1
-3
tensorpack/tfutils/sessinit.py
tensorpack/tfutils/sessinit.py
+9
-9
No files found.
docs/tutorial/save-load.md
View file @
17b34c69
...
@@ -39,13 +39,13 @@ For inference, use `session_init` in `PredictConfig(...)`.
...
@@ -39,13 +39,13 @@ For inference, use `session_init` in `PredictConfig(...)`.
There are a few ways a session can be initialized:
There are a few ways a session can be initialized:
```
```
session_init=Smart
Restore
("path/to/checkpoint") # load a TF checkpoint
session_init=Smart
Init
("path/to/checkpoint") # load a TF checkpoint
session_init=Smart
Restore
("path/to/model_zoo.npz") # load tensorpack model zoo
session_init=Smart
Init
("path/to/model_zoo.npz") # load tensorpack model zoo
session_init=Smart
Restore
(dict_of_parameters) # load a dictionary
session_init=Smart
Init
(dict_of_parameters) # load a dictionary
session_init=Smart
Restore
(["path1", dict2]) # load them sequentially
session_init=Smart
Init
(["path1", dict2]) # load them sequentially
```
```
[
Smart
Restore
](
../modules/tfutils.html#tensorpack.tfutils.sessinit.SmartRestore
)
[
Smart
Init
](
../modules/tfutils.html#tensorpack.tfutils.sessinit.SmartInit
)
is in fact a small helper which uses some heuristics to return you one of
is in fact a small helper which uses some heuristics to return you one of
[
SaverRestore
](
../modules/tfutils.html#tensorpack.tfutils.sessinit.SaverRestore
)
or
[
SaverRestore
](
../modules/tfutils.html#tensorpack.tfutils.sessinit.SaverRestore
)
or
[
DictRestore
](
../modules/tfutils.html#tensorpack.tfutils.sessinit.DictRestore
)
.
[
DictRestore
](
../modules/tfutils.html#tensorpack.tfutils.sessinit.DictRestore
)
.
...
...
examples/A3C-Gym/train-atari.py
View file @
17b34c69
...
@@ -265,7 +265,7 @@ def train():
...
@@ -265,7 +265,7 @@ def train():
],
],
session_creator
=
sesscreate
.
NewSessionCreator
(
config
=
get_default_sess_config
(
0.5
)),
session_creator
=
sesscreate
.
NewSessionCreator
(
config
=
get_default_sess_config
(
0.5
)),
steps_per_epoch
=
STEPS_PER_EPOCH
,
steps_per_epoch
=
STEPS_PER_EPOCH
,
session_init
=
get_model_loader
(
args
.
load
)
if
args
.
load
else
None
,
session_init
=
SmartInit
(
args
.
load
)
,
max_epoch
=
1000
,
max_epoch
=
1000
,
)
)
trainer
=
SimpleTrainer
()
if
num_gpu
==
1
else
AsyncMultiGPUTrainer
(
train_tower
)
trainer
=
SimpleTrainer
()
if
num_gpu
==
1
else
AsyncMultiGPUTrainer
(
train_tower
)
...
@@ -294,7 +294,7 @@ if __name__ == '__main__':
...
@@ -294,7 +294,7 @@ if __name__ == '__main__':
assert
args
.
load
is
not
None
assert
args
.
load
is
not
None
pred
=
OfflinePredictor
(
PredictConfig
(
pred
=
OfflinePredictor
(
PredictConfig
(
model
=
Model
(),
model
=
Model
(),
session_init
=
get_model_loader
(
args
.
load
),
session_init
=
SmartInit
(
args
.
load
),
input_names
=
[
'state'
],
input_names
=
[
'state'
],
output_names
=
[
'policy'
]))
output_names
=
[
'policy'
]))
if
args
.
task
==
'play'
:
if
args
.
task
==
'play'
:
...
...
examples/CTC-TIMIT/train-timit.py
View file @
17b34c69
...
@@ -119,6 +119,5 @@ if __name__ == '__main__':
...
@@ -119,6 +119,5 @@ if __name__ == '__main__':
ds_test
=
get_data
(
args
.
test
,
False
,
args
.
stat
)
ds_test
=
get_data
(
args
.
test
,
False
,
args
.
stat
)
config
=
get_config
(
ds_train
,
ds_test
)
config
=
get_config
(
ds_train
,
ds_test
)
if
args
.
load
:
config
.
session_init
=
SmartInit
(
args
.
load
)
config
.
session_init
=
SaverRestore
(
args
.
load
)
launch_train_with_config
(
config
,
SimpleTrainer
())
launch_train_with_config
(
config
,
SimpleTrainer
())
examples/CaffeModels/load-alexnet.py
View file @
17b34c69
...
@@ -5,7 +5,6 @@
...
@@ -5,7 +5,6 @@
from
__future__
import
print_function
from
__future__
import
print_function
import
argparse
import
argparse
import
numpy
as
np
import
os
import
os
import
cv2
import
cv2
import
tensorflow
as
tf
import
tensorflow
as
tf
...
@@ -39,11 +38,10 @@ def tower_func(image):
...
@@ -39,11 +38,10 @@ def tower_func(image):
def
run_test
(
path
,
input
):
def
run_test
(
path
,
input
):
param_dict
=
dict
(
np
.
load
(
path
))
predictor
=
OfflinePredictor
(
PredictConfig
(
predictor
=
OfflinePredictor
(
PredictConfig
(
input_signature
=
[
tf
.
TensorSpec
((
None
,
227
,
227
,
3
),
tf
.
float32
,
'input'
)],
input_signature
=
[
tf
.
TensorSpec
((
None
,
227
,
227
,
3
),
tf
.
float32
,
'input'
)],
tower_func
=
tower_func
,
tower_func
=
tower_func
,
session_init
=
DictRestore
(
param_dict
),
session_init
=
SmartInit
(
path
),
input_names
=
[
'input'
],
input_names
=
[
'input'
],
output_names
=
[
'prob'
]
output_names
=
[
'prob'
]
))
))
...
...
examples/CaffeModels/load-cpm.py
View file @
17b34c69
...
@@ -95,11 +95,10 @@ def CPM(image):
...
@@ -95,11 +95,10 @@ def CPM(image):
def
run_test
(
model_path
,
img_file
):
def
run_test
(
model_path
,
img_file
):
param_dict
=
dict
(
np
.
load
(
model_path
))
predict_func
=
OfflinePredictor
(
PredictConfig
(
predict_func
=
OfflinePredictor
(
PredictConfig
(
input_signature
=
[
tf
.
TensorSpec
((
None
,
368
,
368
,
3
),
tf
.
float32
,
'input'
)],
input_signature
=
[
tf
.
TensorSpec
((
None
,
368
,
368
,
3
),
tf
.
float32
,
'input'
)],
tower_func
=
CPM
,
tower_func
=
CPM
,
session_init
=
DictRestore
(
param_dict
),
session_init
=
SmartInit
(
model_path
),
input_names
=
[
'input'
],
input_names
=
[
'input'
],
output_names
=
[
'resized_map'
]
output_names
=
[
'resized_map'
]
))
))
...
...
examples/CaffeModels/load-vgg16.py
View file @
17b34c69
...
@@ -61,7 +61,7 @@ def run_test(path, input):
...
@@ -61,7 +61,7 @@ def run_test(path, input):
predict_func
=
OfflinePredictor
(
PredictConfig
(
predict_func
=
OfflinePredictor
(
PredictConfig
(
input_signature
=
[
tf
.
TensorSpec
((
None
,
224
,
224
,
3
),
tf
.
float32
,
'input'
)],
input_signature
=
[
tf
.
TensorSpec
((
None
,
224
,
224
,
3
),
tf
.
float32
,
'input'
)],
tower_func
=
tower_func
,
tower_func
=
tower_func
,
session_init
=
DictRestore
(
param_dict
),
session_init
=
SmartInit
(
param_dict
),
input_names
=
[
'input'
],
input_names
=
[
'input'
],
output_names
=
[
'prob'
]
# prob:0 is the probability distribution
output_names
=
[
'prob'
]
# prob:0 is the probability distribution
))
))
...
...
examples/CaffeModels/load-vgg19.py
View file @
17b34c69
...
@@ -64,7 +64,7 @@ def run_test(path, input):
...
@@ -64,7 +64,7 @@ def run_test(path, input):
predict_func
=
OfflinePredictor
(
PredictConfig
(
predict_func
=
OfflinePredictor
(
PredictConfig
(
input_signature
=
[
tf
.
TensorSpec
((
None
,
224
,
224
,
3
),
tf
.
float32
,
'input'
)],
input_signature
=
[
tf
.
TensorSpec
((
None
,
224
,
224
,
3
),
tf
.
float32
,
'input'
)],
tower_func
=
tower_func
,
tower_func
=
tower_func
,
session_init
=
DictRestore
(
param_dict
),
session_init
=
SmartInit
(
param_dict
),
input_names
=
[
'input'
],
input_names
=
[
'input'
],
output_names
=
[
'prob'
]
# prob:0 is the probability distribution
output_names
=
[
'prob'
]
# prob:0 is the probability distribution
))
))
...
...
examples/Char-RNN/char-rnn.py
View file @
17b34c69
...
@@ -141,7 +141,7 @@ def sample(path, start, length):
...
@@ -141,7 +141,7 @@ def sample(path, start, length):
pred
=
OfflinePredictor
(
PredictConfig
(
pred
=
OfflinePredictor
(
PredictConfig
(
model
=
Model
(),
model
=
Model
(),
session_init
=
S
averRestore
(
path
),
session_init
=
S
martInit
(
path
),
input_names
=
[
'input'
,
'c0'
,
'h0'
,
'c1'
,
'h1'
],
input_names
=
[
'input'
,
'c0'
,
'h0'
,
'c1'
,
'h1'
],
output_names
=
[
'prob'
,
'last_state'
]))
output_names
=
[
'prob'
,
'last_state'
]))
...
@@ -193,6 +193,5 @@ if __name__ == '__main__':
...
@@ -193,6 +193,5 @@ if __name__ == '__main__':
else
:
else
:
param
.
corpus
=
args
.
corpus
param
.
corpus
=
args
.
corpus
config
=
get_config
()
config
=
get_config
()
if
args
.
load
:
config
.
session_init
=
SmartInit
(
args
.
load
)
config
.
session_init
=
SaverRestore
(
args
.
load
)
launch_train_with_config
(
config
,
SimpleTrainer
())
launch_train_with_config
(
config
,
SimpleTrainer
())
examples/DeepQNetwork/DQN.py
View file @
17b34c69
...
@@ -171,7 +171,7 @@ if __name__ == '__main__':
...
@@ -171,7 +171,7 @@ if __name__ == '__main__':
assert
args
.
load
is
not
None
assert
args
.
load
is
not
None
pred
=
OfflinePredictor
(
PredictConfig
(
pred
=
OfflinePredictor
(
PredictConfig
(
model
=
model
,
model
=
model
,
session_init
=
get_model_loader
(
args
.
load
),
session_init
=
SmartInit
(
args
.
load
),
input_names
=
[
'state'
],
input_names
=
[
'state'
],
output_names
=
[
'Qvalue'
]))
output_names
=
[
'Qvalue'
]))
if
args
.
task
==
'play'
:
if
args
.
task
==
'play'
:
...
@@ -183,6 +183,5 @@ if __name__ == '__main__':
...
@@ -183,6 +183,5 @@ if __name__ == '__main__':
os
.
path
.
join
(
'train_log'
,
'DQN-{}'
.
format
(
os
.
path
.
join
(
'train_log'
,
'DQN-{}'
.
format
(
os
.
path
.
basename
(
args
.
env
)
.
split
(
'.'
)[
0
])))
os
.
path
.
basename
(
args
.
env
)
.
split
(
'.'
)[
0
])))
config
=
get_config
(
model
)
config
=
get_config
(
model
)
if
args
.
load
:
config
.
session_init
=
SmartInit
(
args
.
load
)
config
.
session_init
=
get_model_loader
(
args
.
load
)
launch_train_with_config
(
config
,
SimpleTrainer
())
launch_train_with_config
(
config
,
SimpleTrainer
())
examples/DoReFa-Net/alexnet-dorefa.py
View file @
17b34c69
...
@@ -12,7 +12,7 @@ import tensorflow as tf
...
@@ -12,7 +12,7 @@ import tensorflow as tf
from
tensorpack
import
*
from
tensorpack
import
*
from
tensorpack.dataflow
import
dataset
from
tensorpack.dataflow
import
dataset
from
tensorpack.tfutils.sessinit
import
get_model_loader
from
tensorpack.tfutils.sessinit
import
SmartInit
from
tensorpack.tfutils.summary
import
add_param_summary
from
tensorpack.tfutils.summary
import
add_param_summary
from
tensorpack.tfutils.varreplace
import
remap_variables
from
tensorpack.tfutils.varreplace
import
remap_variables
from
tensorpack.utils.gpu
import
get_num_gpu
from
tensorpack.utils.gpu
import
get_num_gpu
...
@@ -214,12 +214,12 @@ if __name__ == '__main__':
...
@@ -214,12 +214,12 @@ if __name__ == '__main__':
if
args
.
run
:
if
args
.
run
:
assert
args
.
load
.
endswith
(
'.npz'
)
assert
args
.
load
.
endswith
(
'.npz'
)
run_image
(
Model
(),
DictRestore
(
dict
(
np
.
load
(
args
.
load
))
),
args
.
run
)
run_image
(
Model
(),
SmartInit
(
args
.
load
),
args
.
run
)
sys
.
exit
()
sys
.
exit
()
if
args
.
eval
:
if
args
.
eval
:
BATCH_SIZE
=
128
BATCH_SIZE
=
128
ds
=
get_data
(
'val'
)
ds
=
get_data
(
'val'
)
eval_classification
(
Model
(),
get_model_loader
(
args
.
load
),
ds
)
eval_classification
(
Model
(),
SmartInit
(
args
.
load
),
ds
)
sys
.
exit
()
sys
.
exit
()
nr_tower
=
max
(
get_num_gpu
(),
1
)
nr_tower
=
max
(
get_num_gpu
(),
1
)
...
@@ -229,6 +229,5 @@ if __name__ == '__main__':
...
@@ -229,6 +229,5 @@ if __name__ == '__main__':
logger
.
info
(
"Batch per tower: {}"
.
format
(
BATCH_SIZE
))
logger
.
info
(
"Batch per tower: {}"
.
format
(
BATCH_SIZE
))
config
=
get_config
()
config
=
get_config
()
if
args
.
load
:
config
.
session_init
=
SmartInit
(
args
.
load
)
config
.
session_init
=
SaverRestore
(
args
.
load
)
launch_train_with_config
(
config
,
SyncMultiGPUTrainerReplicated
(
nr_tower
))
launch_train_with_config
(
config
,
SyncMultiGPUTrainerReplicated
(
nr_tower
))
examples/DoReFa-Net/resnet-dorefa.py
View file @
17b34c69
...
@@ -163,7 +163,7 @@ if __name__ == '__main__':
...
@@ -163,7 +163,7 @@ if __name__ == '__main__':
ds
=
dataset
.
ILSVRC12
(
args
.
data
,
'val'
,
shuffle
=
False
)
ds
=
dataset
.
ILSVRC12
(
args
.
data
,
'val'
,
shuffle
=
False
)
ds
=
AugmentImageComponent
(
ds
,
get_inference_augmentor
())
ds
=
AugmentImageComponent
(
ds
,
get_inference_augmentor
())
ds
=
BatchData
(
ds
,
192
,
remainder
=
True
)
ds
=
BatchData
(
ds
,
192
,
remainder
=
True
)
eval_classification
(
Model
(),
get_model_loader
(
args
.
load
),
ds
)
eval_classification
(
Model
(),
SmartInit
(
args
.
load
),
ds
)
elif
args
.
run
:
elif
args
.
run
:
assert
args
.
load
.
endswith
(
'.npz'
)
assert
args
.
load
.
endswith
(
'.npz'
)
run_image
(
Model
(),
DictRestore
(
dict
(
np
.
load
(
args
.
load
))
),
args
.
run
)
run_image
(
Model
(),
SmartInit
(
args
.
load
),
args
.
run
)
examples/DynamicFilterNetwork/steering-filter.py
View file @
17b34c69
...
@@ -255,6 +255,5 @@ if __name__ == '__main__':
...
@@ -255,6 +255,5 @@ if __name__ == '__main__':
with
change_gpu
(
args
.
gpu
):
with
change_gpu
(
args
.
gpu
):
NGPU
=
len
(
args
.
gpu
.
split
(
','
))
NGPU
=
len
(
args
.
gpu
.
split
(
','
))
config
=
get_config
()
config
=
get_config
()
if
args
.
load
:
config
.
session_init
=
SmartInit
(
args
.
load
)
config
.
session_init
=
SaverRestore
(
args
.
load
)
launch_train_with_config
(
config
,
SyncMultiGPUTrainer
(
NGPU
))
launch_train_with_config
(
config
,
SyncMultiGPUTrainer
(
NGPU
))
examples/FasterRCNN/predict.py
View file @
17b34c69
...
@@ -14,7 +14,7 @@ assert six.PY3, "This example requires Python 3!"
...
@@ -14,7 +14,7 @@ assert six.PY3, "This example requires Python 3!"
import
tensorpack.utils.viz
as
tpviz
import
tensorpack.utils.viz
as
tpviz
from
tensorpack.predict
import
MultiTowerOfflinePredictor
,
OfflinePredictor
,
PredictConfig
from
tensorpack.predict
import
MultiTowerOfflinePredictor
,
OfflinePredictor
,
PredictConfig
from
tensorpack.tfutils
import
get_model_loader
,
get_tf_version_tuple
from
tensorpack.tfutils
import
SmartInit
,
get_tf_version_tuple
from
tensorpack.tfutils.export
import
ModelExporter
from
tensorpack.tfutils.export
import
ModelExporter
from
tensorpack.utils
import
fs
,
logger
from
tensorpack.utils
import
fs
,
logger
...
@@ -38,7 +38,7 @@ def do_visualize(model, model_path, nr_visualize=100, output_dir='output'):
...
@@ -38,7 +38,7 @@ def do_visualize(model, model_path, nr_visualize=100, output_dir='output'):
pred
=
OfflinePredictor
(
PredictConfig
(
pred
=
OfflinePredictor
(
PredictConfig
(
model
=
model
,
model
=
model
,
session_init
=
get_model_loader
(
model_path
),
session_init
=
SmartInit
(
model_path
),
input_names
=
[
'image'
,
'gt_boxes'
,
'gt_labels'
],
input_names
=
[
'image'
,
'gt_boxes'
,
'gt_labels'
],
output_names
=
[
output_names
=
[
'generate_{}_proposals/boxes'
.
format
(
'fpn'
if
cfg
.
MODE_FPN
else
'rpn'
),
'generate_{}_proposals/boxes'
.
format
(
'fpn'
if
cfg
.
MODE_FPN
else
'rpn'
),
...
@@ -146,7 +146,7 @@ if __name__ == '__main__':
...
@@ -146,7 +146,7 @@ if __name__ == '__main__':
else
:
else
:
predcfg
=
PredictConfig
(
predcfg
=
PredictConfig
(
model
=
MODEL
,
model
=
MODEL
,
session_init
=
get_model_loader
(
args
.
load
),
session_init
=
SmartInit
(
args
.
load
),
input_names
=
MODEL
.
get_inference_tensor_names
()[
0
],
input_names
=
MODEL
.
get_inference_tensor_names
()[
0
],
output_names
=
MODEL
.
get_inference_tensor_names
()[
1
])
output_names
=
MODEL
.
get_inference_tensor_names
()[
1
])
...
...
examples/FasterRCNN/train.py
View file @
17b34c69
...
@@ -103,9 +103,9 @@ if __name__ == '__main__':
...
@@ -103,9 +103,9 @@ if __name__ == '__main__':
else
:
else
:
if
args
.
load
:
if
args
.
load
:
# ignore mismatched values, so you can `--load` a model for fine-tuning
# ignore mismatched values, so you can `--load` a model for fine-tuning
session_init
=
Smart
Restore
(
args
.
load
,
ignore_mismatch
=
True
)
session_init
=
Smart
Init
(
args
.
load
,
ignore_mismatch
=
True
)
else
:
else
:
session_init
=
Smart
Restore
(
cfg
.
BACKBONE
.
WEIGHTS
)
session_init
=
Smart
Init
(
cfg
.
BACKBONE
.
WEIGHTS
)
traincfg
=
TrainConfig
(
traincfg
=
TrainConfig
(
model
=
MODEL
,
model
=
MODEL
,
...
...
examples/GAN/BEGAN.py
View file @
17b34c69
...
@@ -146,5 +146,5 @@ if __name__ == '__main__':
...
@@ -146,5 +146,5 @@ if __name__ == '__main__':
StatMonitorParamSetter
(
StatMonitorParamSetter
(
'learning_rate'
,
'losses/measure'
,
lambda
x
:
x
*
0.5
,
0
,
10
)
'learning_rate'
,
'losses/measure'
,
lambda
x
:
x
*
0.5
,
0
,
10
)
],
],
session_init
=
S
averRestore
(
args
.
load
)
if
args
.
load
else
None
,
session_init
=
S
martInit
(
args
.
load
)
,
steps_per_epoch
=
500
,
max_epoch
=
400
)
steps_per_epoch
=
500
,
max_epoch
=
400
)
examples/GAN/ConditionalGAN-mnist.py
View file @
17b34c69
...
@@ -114,7 +114,7 @@ def get_data():
...
@@ -114,7 +114,7 @@ def get_data():
def
sample
(
model_path
):
def
sample
(
model_path
):
pred
=
PredictConfig
(
pred
=
PredictConfig
(
session_init
=
get_model_loader
(
model_path
),
session_init
=
SmartInit
(
model_path
),
model
=
Model
(),
model
=
Model
(),
input_names
=
[
'label'
,
'z'
],
input_names
=
[
'label'
,
'z'
],
output_names
=
[
'gen/gen'
])
output_names
=
[
'gen/gen'
])
...
@@ -145,5 +145,5 @@ if __name__ == '__main__':
...
@@ -145,5 +145,5 @@ if __name__ == '__main__':
callbacks
=
[
ModelSaver
()],
callbacks
=
[
ModelSaver
()],
steps_per_epoch
=
500
,
steps_per_epoch
=
500
,
max_epoch
=
100
,
max_epoch
=
100
,
session_init
=
S
averRestore
(
args
.
load
)
if
args
.
load
else
None
session_init
=
S
martInit
(
args
.
load
),
)
)
examples/GAN/CycleGAN.py
View file @
17b34c69
...
@@ -224,5 +224,5 @@ if __name__ == '__main__':
...
@@ -224,5 +224,5 @@ if __name__ == '__main__':
],
],
max_epoch
=
195
,
max_epoch
=
195
,
steps_per_epoch
=
data
.
size
(),
steps_per_epoch
=
data
.
size
(),
session_init
=
S
averRestore
(
args
.
load
)
if
args
.
load
else
None
session_init
=
S
martInit
(
args
.
load
)
)
)
examples/GAN/DCGAN.py
View file @
17b34c69
...
@@ -121,7 +121,7 @@ def get_data():
...
@@ -121,7 +121,7 @@ def get_data():
def
sample
(
model
,
model_path
,
output_name
=
'gen/gen'
):
def
sample
(
model
,
model_path
,
output_name
=
'gen/gen'
):
pred
=
PredictConfig
(
pred
=
PredictConfig
(
session_init
=
get_model_loader
(
model_path
),
session_init
=
SmartInit
(
model_path
),
model
=
model
,
model
=
model
,
input_names
=
[
'z'
],
input_names
=
[
'z'
],
output_names
=
[
output_name
,
'z'
])
output_names
=
[
output_name
,
'z'
])
...
@@ -167,5 +167,5 @@ if __name__ == '__main__':
...
@@ -167,5 +167,5 @@ if __name__ == '__main__':
callbacks
=
[
ModelSaver
()],
callbacks
=
[
ModelSaver
()],
steps_per_epoch
=
300
,
steps_per_epoch
=
300
,
max_epoch
=
200
,
max_epoch
=
200
,
session_init
=
S
averRestore
(
args
.
load
)
if
args
.
load
else
None
session_init
=
S
martInit
(
args
.
load
),
)
)
examples/GAN/DiscoGAN-CelebA.py
View file @
17b34c69
...
@@ -211,5 +211,5 @@ if __name__ == '__main__':
...
@@ -211,5 +211,5 @@ if __name__ == '__main__':
callbacks
=
[
ModelSaver
()],
callbacks
=
[
ModelSaver
()],
steps_per_epoch
=
300
,
steps_per_epoch
=
300
,
max_epoch
=
250
,
max_epoch
=
250
,
session_init
=
S
averRestore
(
args
.
load
)
if
args
.
load
else
None
session_init
=
S
martInit
(
args
.
load
),
)
)
examples/GAN/Image2Image.py
View file @
17b34c69
...
@@ -179,7 +179,7 @@ def get_data():
...
@@ -179,7 +179,7 @@ def get_data():
def
sample
(
datadir
,
model_path
):
def
sample
(
datadir
,
model_path
):
pred
=
PredictConfig
(
pred
=
PredictConfig
(
session_init
=
get_model_loader
(
model_path
),
session_init
=
SmartInit
(
model_path
),
model
=
Model
(),
model
=
Model
(),
input_names
=
[
'input'
,
'output'
],
input_names
=
[
'input'
,
'output'
],
output_names
=
[
'viz'
])
output_names
=
[
'viz'
])
...
@@ -226,5 +226,5 @@ if __name__ == '__main__':
...
@@ -226,5 +226,5 @@ if __name__ == '__main__':
],
],
steps_per_epoch
=
data
.
size
(),
steps_per_epoch
=
data
.
size
(),
max_epoch
=
300
,
max_epoch
=
300
,
session_init
=
S
averRestore
(
args
.
load
)
if
args
.
load
else
None
session_init
=
S
martInit
(
args
.
load
)
)
)
examples/GAN/Improved-WGAN.py
View file @
17b34c69
...
@@ -97,5 +97,5 @@ if __name__ == '__main__':
...
@@ -97,5 +97,5 @@ if __name__ == '__main__':
callbacks
=
[
ModelSaver
()],
callbacks
=
[
ModelSaver
()],
steps_per_epoch
=
300
,
steps_per_epoch
=
300
,
max_epoch
=
200
,
max_epoch
=
200
,
session_init
=
S
averRestore
(
args
.
load
)
if
args
.
load
else
None
session_init
=
S
martInit
(
args
.
load
)
)
)
examples/GAN/InfoGAN-mnist.py
View file @
17b34c69
...
@@ -218,7 +218,7 @@ def get_data():
...
@@ -218,7 +218,7 @@ def get_data():
def
sample
(
model_path
):
def
sample
(
model_path
):
pred
=
OfflinePredictor
(
PredictConfig
(
pred
=
OfflinePredictor
(
PredictConfig
(
session_init
=
get_model_loader
(
model_path
),
session_init
=
SmartInit
(
model_path
),
model
=
Model
(),
model
=
Model
(),
input_names
=
[
'z_code'
,
'z_noise'
],
input_names
=
[
'z_code'
,
'z_noise'
],
output_names
=
[
'gen/viz'
]))
output_names
=
[
'gen/viz'
]))
...
@@ -276,5 +276,5 @@ if __name__ == '__main__':
...
@@ -276,5 +276,5 @@ if __name__ == '__main__':
callbacks
=
[
ModelSaver
(
keep_checkpoint_every_n_hours
=
0.1
)],
callbacks
=
[
ModelSaver
(
keep_checkpoint_every_n_hours
=
0.1
)],
steps_per_epoch
=
500
,
steps_per_epoch
=
500
,
max_epoch
=
100
,
max_epoch
=
100
,
session_init
=
S
averRestore
(
args
.
load
)
if
args
.
load
else
None
session_init
=
S
martInit
(
args
.
load
)
)
)
examples/GAN/WGAN.py
View file @
17b34c69
...
@@ -80,5 +80,5 @@ if __name__ == '__main__':
...
@@ -80,5 +80,5 @@ if __name__ == '__main__':
callbacks
=
[
ModelSaver
(),
ClipCallback
()],
callbacks
=
[
ModelSaver
(),
ClipCallback
()],
steps_per_epoch
=
500
,
steps_per_epoch
=
500
,
max_epoch
=
200
,
max_epoch
=
200
,
session_init
=
S
averRestore
(
args
.
load
)
if
args
.
load
else
None
session_init
=
S
martInit
(
args
.
load
)
)
)
examples/HED/hed.py
View file @
17b34c69
...
@@ -271,7 +271,7 @@ def get_config():
...
@@ -271,7 +271,7 @@ def get_config():
def
run
(
model_path
,
image_path
,
output
):
def
run
(
model_path
,
image_path
,
output
):
pred_config
=
PredictConfig
(
pred_config
=
PredictConfig
(
model
=
Model
(),
model
=
Model
(),
session_init
=
get_model_loader
(
model_path
),
session_init
=
SmartInit
(
model_path
),
input_names
=
[
'image'
],
input_names
=
[
'image'
],
output_names
=
[
'output'
+
str
(
k
)
for
k
in
range
(
1
,
7
)])
output_names
=
[
'output'
+
str
(
k
)
for
k
in
range
(
1
,
7
)])
predictor
=
OfflinePredictor
(
pred_config
)
predictor
=
OfflinePredictor
(
pred_config
)
...
@@ -309,8 +309,7 @@ if __name__ == '__main__':
...
@@ -309,8 +309,7 @@ if __name__ == '__main__':
run
(
args
.
load
,
args
.
run
,
args
.
output
)
run
(
args
.
load
,
args
.
run
,
args
.
output
)
else
:
else
:
config
=
get_config
()
config
=
get_config
()
if
args
.
load
:
config
.
session_init
=
SmartInit
(
args
.
load
)
config
.
session_init
=
get_model_loader
(
args
.
load
)
launch_train_with_config
(
launch_train_with_config
(
config
,
config
,
SyncMultiGPUTrainer
(
max
(
get_num_gpu
(),
1
)))
SyncMultiGPUTrainer
(
max
(
get_num_gpu
(),
1
)))
examples/ImageNetModels/imagenet_utils.py
View file @
17b34c69
...
@@ -431,7 +431,7 @@ class ImageNetModel(ModelDesc):
...
@@ -431,7 +431,7 @@ class ImageNetModel(ModelDesc):
Examples:
Examples:
pred = OfflinePredictor(model.create_predict_config(
get_model_loader
(args.load)))
pred = OfflinePredictor(model.create_predict_config(
SmartInit
(args.load)))
prob = pred(NCHW_image)[0] # Nx1000 probabilities
prob = pred(NCHW_image)[0] # Nx1000 probabilities
"""
"""
return
PredictConfig
(
model
=
self
,
input_names
=
[
'input'
],
output_names
=
[
'prob'
],
session_init
=
session_init
)
return
PredictConfig
(
model
=
self
,
input_names
=
[
'input'
],
output_names
=
[
'prob'
],
session_init
=
session_init
)
...
...
examples/ImageNetModels/inception-bn.py
View file @
17b34c69
...
@@ -166,8 +166,7 @@ if __name__ == '__main__':
...
@@ -166,8 +166,7 @@ if __name__ == '__main__':
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
args
.
gpu
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
args
.
gpu
config
=
get_config
()
config
=
get_config
()
if
args
.
load
:
config
.
session_init
=
SmartInit
(
args
.
load
)
config
.
session_init
=
SaverRestore
(
args
.
load
)
nr_tower
=
get_num_gpu
()
nr_tower
=
get_num_gpu
()
assert
nr_tower
==
NUM_GPU
assert
nr_tower
==
NUM_GPU
launch_train_with_config
(
config
,
SyncMultiGPUTrainer
(
NUM_GPU
))
launch_train_with_config
(
config
,
SyncMultiGPUTrainer
(
NUM_GPU
))
examples/ImageNetModels/shufflenet.py
View file @
17b34c69
...
@@ -11,7 +11,7 @@ import tensorflow as tf
...
@@ -11,7 +11,7 @@ import tensorflow as tf
from
tensorpack
import
*
from
tensorpack
import
*
from
tensorpack.dataflow
import
imgaug
from
tensorpack.dataflow
import
imgaug
from
tensorpack.tfutils
import
argscope
,
get_model_loader
,
model_utils
from
tensorpack.tfutils
import
argscope
,
SmartInit
,
model_utils
from
tensorpack.tfutils.scope_utils
import
under_name_scope
from
tensorpack.tfutils.scope_utils
import
under_name_scope
from
tensorpack.utils
import
logger
from
tensorpack.utils
import
logger
from
tensorpack.utils.gpu
import
get_num_gpu
from
tensorpack.utils.gpu
import
get_num_gpu
...
@@ -251,7 +251,7 @@ if __name__ == '__main__':
...
@@ -251,7 +251,7 @@ if __name__ == '__main__':
if
args
.
eval
:
if
args
.
eval
:
batch
=
128
# something that can run on one gpu
batch
=
128
# something that can run on one gpu
ds
=
get_data
(
'val'
,
batch
)
ds
=
get_data
(
'val'
,
batch
)
eval_classification
(
model
,
get_model_loader
(
args
.
load
),
ds
)
eval_classification
(
model
,
SmartInit
(
args
.
load
),
ds
)
elif
args
.
flops
:
elif
args
.
flops
:
# manually build the graph with batch=1
# manually build the graph with batch=1
with
TowerContext
(
''
,
is_training
=
False
):
with
TowerContext
(
''
,
is_training
=
False
):
...
@@ -277,6 +277,5 @@ if __name__ == '__main__':
...
@@ -277,6 +277,5 @@ if __name__ == '__main__':
nr_tower
=
max
(
get_num_gpu
(),
1
)
nr_tower
=
max
(
get_num_gpu
(),
1
)
config
=
get_config
(
model
,
nr_tower
)
config
=
get_config
(
model
,
nr_tower
)
if
args
.
load
:
config
.
session_init
=
SmartInit
(
args
.
load
)
config
.
session_init
=
get_model_loader
(
args
.
load
)
launch_train_with_config
(
config
,
SyncMultiGPUTrainerParameterServer
(
nr_tower
))
launch_train_with_config
(
config
,
SyncMultiGPUTrainerParameterServer
(
nr_tower
))
examples/OpticalFlow/flownet2.py
View file @
17b34c69
...
@@ -24,7 +24,7 @@ def apply(model, model_path, images, ground_truth=None):
...
@@ -24,7 +24,7 @@ def apply(model, model_path, images, ground_truth=None):
predict_func
=
OfflinePredictor
(
PredictConfig
(
predict_func
=
OfflinePredictor
(
PredictConfig
(
model
=
model
(
height
=
newh
,
width
=
neww
),
model
=
model
(
height
=
newh
,
width
=
neww
),
session_init
=
get_model_loader
(
model_path
),
session_init
=
SmartInit
(
model_path
),
input_names
=
[
'left'
,
'right'
],
input_names
=
[
'left'
,
'right'
],
output_names
=
[
'prediction'
]))
output_names
=
[
'prediction'
]))
...
@@ -102,7 +102,7 @@ def inference(model, model_path, sintel_path):
...
@@ -102,7 +102,7 @@ def inference(model, model_path, sintel_path):
pred
=
PredictConfig
(
pred
=
PredictConfig
(
model
=
model
(
height
=
h
,
width
=
w
),
model
=
model
(
height
=
h
,
width
=
w
),
session_init
=
get_model_loader
(
model_path
),
session_init
=
SmartInit
(
model_path
),
input_names
=
[
'left'
,
'right'
,
'gt_flow'
],
input_names
=
[
'left'
,
'right'
,
'gt_flow'
],
output_names
=
[
'epe'
,
'prediction'
])
output_names
=
[
'epe'
,
'prediction'
])
pred
=
SimpleDatasetPredictor
(
pred
,
ds
)
pred
=
SimpleDatasetPredictor
(
pred
,
ds
)
...
...
examples/PennTreebank/PTB-LSTM.py
View file @
17b34c69
...
@@ -174,6 +174,5 @@ if __name__ == '__main__':
...
@@ -174,6 +174,5 @@ if __name__ == '__main__':
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
str
(
args
.
gpu
)
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
str
(
args
.
gpu
)
config
=
get_config
()
config
=
get_config
()
if
args
.
load
:
config
.
session_init
=
SmartInit
(
args
.
load
)
config
.
session_init
=
SaverRestore
(
args
.
load
)
launch_train_with_config
(
config
,
SimpleTrainer
())
launch_train_with_config
(
config
,
SimpleTrainer
())
examples/ResNet/cifar10-preact18-mixup.py
View file @
17b34c69
...
@@ -149,6 +149,6 @@ if __name__ == '__main__':
...
@@ -149,6 +149,6 @@ if __name__ == '__main__':
],
],
max_epoch
=
200
,
max_epoch
=
200
,
steps_per_epoch
=
len
(
dataset_train
),
steps_per_epoch
=
len
(
dataset_train
),
session_init
=
S
averRestore
(
args
.
load
)
if
args
.
load
else
None
session_init
=
S
martInit
(
args
.
load
)
)
)
launch_train_with_config
(
config
,
SimpleTrainer
())
launch_train_with_config
(
config
,
SimpleTrainer
())
examples/ResNet/cifar10-resnet.py
View file @
17b34c69
...
@@ -166,7 +166,7 @@ if __name__ == '__main__':
...
@@ -166,7 +166,7 @@ if __name__ == '__main__':
[(
1
,
0.1
),
(
82
,
0.01
),
(
123
,
0.001
),
(
300
,
0.0002
)])
[(
1
,
0.1
),
(
82
,
0.01
),
(
123
,
0.001
),
(
300
,
0.0002
)])
],
],
max_epoch
=
400
,
max_epoch
=
400
,
session_init
=
S
averRestore
(
args
.
load
)
if
args
.
load
else
None
session_init
=
S
martInit
(
args
.
load
),
)
)
num_gpu
=
max
(
get_num_gpu
(),
1
)
num_gpu
=
max
(
get_num_gpu
(),
1
)
launch_train_with_config
(
config
,
SyncMultiGPUTrainerParameterServer
(
num_gpu
))
launch_train_with_config
(
config
,
SyncMultiGPUTrainerParameterServer
(
num_gpu
))
examples/ResNet/imagenet-resnet.py
View file @
17b34c69
...
@@ -9,7 +9,7 @@ from tensorpack import QueueInput, TFDatasetInput, logger
...
@@ -9,7 +9,7 @@ from tensorpack import QueueInput, TFDatasetInput, logger
from
tensorpack.callbacks
import
*
from
tensorpack.callbacks
import
*
from
tensorpack.dataflow
import
FakeData
from
tensorpack.dataflow
import
FakeData
from
tensorpack.models
import
*
from
tensorpack.models
import
*
from
tensorpack.tfutils
import
argscope
,
get_model_loader
from
tensorpack.tfutils
import
argscope
,
SmartInit
from
tensorpack.train
import
SyncMultiGPUTrainerReplicated
,
TrainConfig
,
launch_train_with_config
from
tensorpack.train
import
SyncMultiGPUTrainerReplicated
,
TrainConfig
,
launch_train_with_config
from
tensorpack.utils.gpu
import
get_num_gpu
from
tensorpack.utils.gpu
import
get_num_gpu
...
@@ -136,7 +136,7 @@ if __name__ == '__main__':
...
@@ -136,7 +136,7 @@ if __name__ == '__main__':
if
args
.
eval
:
if
args
.
eval
:
batch
=
128
# something that can run on one gpu
batch
=
128
# something that can run on one gpu
ds
=
get_imagenet_dataflow
(
args
.
data
,
'val'
,
batch
)
ds
=
get_imagenet_dataflow
(
args
.
data
,
'val'
,
batch
)
eval_classification
(
model
,
get_model_loader
(
args
.
load
),
ds
)
eval_classification
(
model
,
SmartInit
(
args
.
load
),
ds
)
else
:
else
:
if
args
.
fake
:
if
args
.
fake
:
logger
.
set_logger_dir
(
os
.
path
.
join
(
'train_log'
,
'tmp'
),
'd'
)
logger
.
set_logger_dir
(
os
.
path
.
join
(
'train_log'
,
'tmp'
),
'd'
)
...
@@ -147,7 +147,6 @@ if __name__ == '__main__':
...
@@ -147,7 +147,6 @@ if __name__ == '__main__':
args
.
mode
,
args
.
depth
,
args
.
batch
)))
args
.
mode
,
args
.
depth
,
args
.
batch
)))
config
=
get_config
(
model
)
config
=
get_config
(
model
)
if
args
.
load
:
config
.
session_init
=
SmartInit
(
args
.
load
)
config
.
session_init
=
get_model_loader
(
args
.
load
)
trainer
=
SyncMultiGPUTrainerReplicated
(
max
(
get_num_gpu
(),
1
))
trainer
=
SyncMultiGPUTrainerReplicated
(
max
(
get_num_gpu
(),
1
))
launch_train_with_config
(
config
,
trainer
)
launch_train_with_config
(
config
,
trainer
)
examples/ResNet/load-resnet.py
View file @
17b34c69
...
@@ -79,7 +79,7 @@ def get_inference_augmentor():
...
@@ -79,7 +79,7 @@ def get_inference_augmentor():
def
run_test
(
params
,
input
):
def
run_test
(
params
,
input
):
pred_config
=
PredictConfig
(
pred_config
=
PredictConfig
(
model
=
Model
(),
model
=
Model
(),
session_init
=
DictRestore
(
params
),
session_init
=
SmartInit
(
params
),
input_names
=
[
'input'
],
input_names
=
[
'input'
],
output_names
=
[
'prob'
]
output_names
=
[
'prob'
]
)
)
...
@@ -172,6 +172,6 @@ if __name__ == '__main__':
...
@@ -172,6 +172,6 @@ if __name__ == '__main__':
if
args
.
eval
:
if
args
.
eval
:
ds
=
get_imagenet_dataflow
(
args
.
eval
,
'val'
,
128
,
get_inference_augmentor
())
ds
=
get_imagenet_dataflow
(
args
.
eval
,
'val'
,
128
,
get_inference_augmentor
())
eval_classification
(
Model
(),
Dic
tRestore
(
param
),
ds
)
eval_classification
(
Model
(),
Smar
tRestore
(
param
),
ds
)
elif
args
.
input
:
elif
args
.
input
:
run_test
(
param
,
args
.
input
)
run_test
(
param
,
args
.
input
)
examples/Saliency/CAM-resnet.py
View file @
17b34c69
...
@@ -97,7 +97,7 @@ def viz_cam(model_file, data_dir):
...
@@ -97,7 +97,7 @@ def viz_cam(model_file, data_dir):
ds
=
get_data
(
'val'
)
ds
=
get_data
(
'val'
)
pred_config
=
PredictConfig
(
pred_config
=
PredictConfig
(
model
=
Model
(),
model
=
Model
(),
session_init
=
get_model_loader
(
model_file
),
session_init
=
SmartInit
(
model_file
),
input_names
=
[
'input'
,
'label'
],
input_names
=
[
'input'
,
'label'
],
output_names
=
[
'wrong-top1'
,
'group3new/bnlast/Relu'
,
'linearnew/W'
],
output_names
=
[
'wrong-top1'
,
'group3new/bnlast/Relu'
,
'linearnew/W'
],
return_input
=
True
return_input
=
True
...
@@ -151,6 +151,5 @@ if __name__ == '__main__':
...
@@ -151,6 +151,5 @@ if __name__ == '__main__':
logger
.
auto_set_dir
()
logger
.
auto_set_dir
()
config
=
get_config
()
config
=
get_config
()
if
args
.
load
:
config
.
session_init
=
SmartInit
(
args
.
load
)
config
.
session_init
=
get_model_loader
(
args
.
load
)
launch_train_with_config
(
config
,
SyncMultiGPUTrainerParameterServer
(
num_gpu
))
launch_train_with_config
(
config
,
SyncMultiGPUTrainerParameterServer
(
num_gpu
))
examples/Saliency/saliency-maps.py
View file @
17b34c69
...
@@ -68,7 +68,7 @@ class Model(tp.ModelDescBase):
...
@@ -68,7 +68,7 @@ class Model(tp.ModelDescBase):
def
run
(
model_path
,
image_path
):
def
run
(
model_path
,
image_path
):
predictor
=
tp
.
OfflinePredictor
(
tp
.
PredictConfig
(
predictor
=
tp
.
OfflinePredictor
(
tp
.
PredictConfig
(
model
=
Model
(),
model
=
Model
(),
session_init
=
tp
.
get_model_loader
(
model_path
),
session_init
=
tp
.
SmartInit
(
model_path
),
input_names
=
[
'image'
],
input_names
=
[
'image'
],
output_names
=
[
'saliency'
]))
output_names
=
[
'saliency'
]))
im
=
cv2
.
imread
(
image_path
)
im
=
cv2
.
imread
(
image_path
)
...
...
examples/SimilarityLearning/mnist-embeddings.py
View file @
17b34c69
...
@@ -364,7 +364,7 @@ def visualize(model_path, model, algo_name):
...
@@ -364,7 +364,7 @@ def visualize(model_path, model, algo_name):
logger
.
error
(
"visualize requires matplotlib package ..."
)
logger
.
error
(
"visualize requires matplotlib package ..."
)
return
return
pred
=
OfflinePredictor
(
PredictConfig
(
pred
=
OfflinePredictor
(
PredictConfig
(
session_init
=
get_model_loader
(
model_path
),
session_init
=
SmartInit
(
model_path
),
model
=
model
(),
model
=
model
(),
input_names
=
[
'input'
],
input_names
=
[
'input'
],
output_names
=
[
'emb'
]))
output_names
=
[
'emb'
]))
...
@@ -432,7 +432,5 @@ if __name__ == '__main__':
...
@@ -432,7 +432,5 @@ if __name__ == '__main__':
visualize
(
args
.
load
,
ALGO_CONFIGS
[
args
.
algorithm
],
args
.
algorithm
)
visualize
(
args
.
load
,
ALGO_CONFIGS
[
args
.
algorithm
],
args
.
algorithm
)
else
:
else
:
config
=
get_config
(
ALGO_CONFIGS
[
args
.
algorithm
],
args
.
algorithm
)
config
=
get_config
(
ALGO_CONFIGS
[
args
.
algorithm
],
args
.
algorithm
)
if
args
.
load
:
config
.
session_init
=
SmartInit
(
args
.
load
)
config
.
session_init
=
SaverRestore
(
args
.
load
)
launch_train_with_config
(
config
,
SimpleTrainer
())
else
:
launch_train_with_config
(
config
,
SimpleTrainer
())
examples/SpatialTransformer/mnist-addition.py
View file @
17b34c69
...
@@ -201,7 +201,7 @@ def get_data(isTrain):
...
@@ -201,7 +201,7 @@ def get_data(isTrain):
def
view_warp
(
modelpath
):
def
view_warp
(
modelpath
):
pred
=
OfflinePredictor
(
PredictConfig
(
pred
=
OfflinePredictor
(
PredictConfig
(
session_init
=
get_model_loader
(
modelpath
),
session_init
=
SmartInit
(
modelpath
),
model
=
Model
(),
model
=
Model
(),
input_names
=
[
'input'
],
input_names
=
[
'input'
],
output_names
=
[
'visualization/viz'
,
'STN1/affine'
,
'STN2/affine'
]))
output_names
=
[
'visualization/viz'
,
'STN1/affine'
,
'STN2/affine'
]))
...
@@ -265,6 +265,5 @@ if __name__ == '__main__':
...
@@ -265,6 +265,5 @@ if __name__ == '__main__':
view_warp
(
args
.
load
)
view_warp
(
args
.
load
)
else
:
else
:
config
=
get_config
()
config
=
get_config
()
if
args
.
load
:
config
.
session_init
=
SmartInit
(
args
.
load
)
config
.
session_init
=
SaverRestore
(
args
.
load
)
launch_train_with_config
(
config
,
SimpleTrainer
())
launch_train_with_config
(
config
,
SimpleTrainer
())
examples/SuperResolution/enet-pat.py
View file @
17b34c69
...
@@ -227,7 +227,7 @@ def apply(model_path, lowres_path="", output_path='.'):
...
@@ -227,7 +227,7 @@ def apply(model_path, lowres_path="", output_path='.'):
predict_func
=
OfflinePredictor
(
PredictConfig
(
predict_func
=
OfflinePredictor
(
PredictConfig
(
model
=
Model
(
LR_SIZE_H
,
LR_SIZE_W
),
model
=
Model
(
LR_SIZE_H
,
LR_SIZE_W
),
session_init
=
get_model_loader
(
model_path
),
session_init
=
SmartInit
(
model_path
),
input_names
=
[
'Ilr'
],
input_names
=
[
'Ilr'
],
output_names
=
[
'prediction'
]))
output_names
=
[
'prediction'
]))
...
@@ -279,12 +279,12 @@ if __name__ == '__main__':
...
@@ -279,12 +279,12 @@ if __name__ == '__main__':
logger
.
auto_set_dir
()
logger
.
auto_set_dir
()
if
args
.
load
:
if
args
.
load
:
session_init
=
S
averRestore
(
args
.
load
)
session_init
=
S
martInit
(
args
.
load
)
else
:
else
:
assert
os
.
path
.
isfile
(
args
.
vgg19
)
assert
os
.
path
.
isfile
(
args
.
vgg19
)
param_dict
=
dict
(
np
.
load
(
args
.
vgg19
))
param_dict
=
dict
(
np
.
load
(
args
.
vgg19
))
param_dict
=
{
'VGG19/'
+
name
:
value
for
name
,
value
in
six
.
iteritems
(
param_dict
)}
param_dict
=
{
'VGG19/'
+
name
:
value
for
name
,
value
in
six
.
iteritems
(
param_dict
)}
session_init
=
DictRestore
(
param_dict
)
session_init
=
SmartInit
(
param_dict
)
nr_tower
=
max
(
get_num_gpu
(),
1
)
nr_tower
=
max
(
get_num_gpu
(),
1
)
data
=
QueueInput
(
get_data
(
args
.
data
))
data
=
QueueInput
(
get_data
(
args
.
data
))
...
...
examples/basics/cifar-convnet.py
View file @
17b34c69
...
@@ -143,8 +143,7 @@ if __name__ == '__main__':
...
@@ -143,8 +143,7 @@ if __name__ == '__main__':
with
tf
.
Graph
()
.
as_default
():
with
tf
.
Graph
()
.
as_default
():
logger
.
set_logger_dir
(
os
.
path
.
join
(
'train_log'
,
'cifar'
+
str
(
args
.
classnum
)))
logger
.
set_logger_dir
(
os
.
path
.
join
(
'train_log'
,
'cifar'
+
str
(
args
.
classnum
)))
config
=
get_config
(
args
.
classnum
)
config
=
get_config
(
args
.
classnum
)
if
args
.
load
:
config
.
session_init
=
SmartInit
(
args
.
load
)
config
.
session_init
=
SaverRestore
(
args
.
load
)
num_gpu
=
get_num_gpu
()
num_gpu
=
get_num_gpu
()
trainer
=
SimpleTrainer
()
if
num_gpu
<=
1
\
trainer
=
SimpleTrainer
()
if
num_gpu
<=
1
\
...
...
examples/basics/export-model.py
View file @
17b34c69
...
@@ -106,7 +106,7 @@ class InferenceOnlyModel(Model):
...
@@ -106,7 +106,7 @@ class InferenceOnlyModel(Model):
def
export_serving
(
model_path
):
def
export_serving
(
model_path
):
"""Export trained model to use it in TensorFlow Serving or cloudML. """
"""Export trained model to use it in TensorFlow Serving or cloudML. """
pred_config
=
PredictConfig
(
pred_config
=
PredictConfig
(
session_init
=
get_model_loader
(
model_path
),
session_init
=
SmartInit
(
model_path
),
model
=
InferenceOnlyModel
(),
model
=
InferenceOnlyModel
(),
input_names
=
[
'input_img_bytes'
],
input_names
=
[
'input_img_bytes'
],
output_names
=
[
'prediction_img_bytes'
])
output_names
=
[
'prediction_img_bytes'
])
...
@@ -117,7 +117,7 @@ def export_compact(model_path):
...
@@ -117,7 +117,7 @@ def export_compact(model_path):
"""Export trained model to use it as a frozen and pruned inference graph in
"""Export trained model to use it as a frozen and pruned inference graph in
mobile applications. """
mobile applications. """
pred_config
=
PredictConfig
(
pred_config
=
PredictConfig
(
session_init
=
get_model_loader
(
model_path
),
session_init
=
SmartInit
(
model_path
),
model
=
Model
(),
model
=
Model
(),
input_names
=
[
'input_img'
],
input_names
=
[
'input_img'
],
output_names
=
[
'prediction_img'
])
output_names
=
[
'prediction_img'
])
...
@@ -127,7 +127,7 @@ def export_compact(model_path):
...
@@ -127,7 +127,7 @@ def export_compact(model_path):
def
apply
(
model_path
):
def
apply
(
model_path
):
"""Run inference from a training model checkpoint. """
"""Run inference from a training model checkpoint. """
pred_config
=
PredictConfig
(
pred_config
=
PredictConfig
(
session_init
=
get_model_loader
(
model_path
),
session_init
=
SmartInit
(
model_path
),
model
=
Model
(),
model
=
Model
(),
input_names
=
[
'input_img'
],
input_names
=
[
'input_img'
],
output_names
=
[
'prediction_img'
])
output_names
=
[
'prediction_img'
])
...
@@ -141,7 +141,7 @@ def apply(model_path):
...
@@ -141,7 +141,7 @@ def apply(model_path):
def
apply_inference_graph
(
model_path
):
def
apply_inference_graph
(
model_path
):
"""Run inference from a different graph, which receives encoded images buffers. """
"""Run inference from a different graph, which receives encoded images buffers. """
pred_config
=
PredictConfig
(
pred_config
=
PredictConfig
(
session_init
=
get_model_loader
(
model_path
),
session_init
=
SmartInit
(
model_path
),
model
=
InferenceOnlyModel
(),
model
=
InferenceOnlyModel
(),
input_names
=
[
'input_img_bytes'
],
input_names
=
[
'input_img_bytes'
],
output_names
=
[
'prediction_img_bytes'
])
output_names
=
[
'prediction_img_bytes'
])
...
...
examples/basics/svhn-digit-convnet.py
View file @
17b34c69
...
@@ -107,6 +107,6 @@ if __name__ == '__main__':
...
@@ -107,6 +107,6 @@ if __name__ == '__main__':
ScalarStats
([
'cost'
,
'accuracy'
]))
ScalarStats
([
'cost'
,
'accuracy'
]))
],
],
max_epoch
=
350
,
max_epoch
=
350
,
session_init
=
S
averRestore
(
args
.
load
)
if
args
.
load
else
None
session_init
=
S
martInit
(
args
.
load
)
)
)
launch_train_with_config
(
config
,
SimpleTrainer
())
launch_train_with_config
(
config
,
SimpleTrainer
())
examples/boilerplate.py
View file @
17b34c69
...
@@ -70,8 +70,6 @@ if __name__ == '__main__':
...
@@ -70,8 +70,6 @@ if __name__ == '__main__':
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
args
.
gpu
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
args
.
gpu
config
=
get_config
()
config
=
get_config
()
config
.
session_init
=
SmartInit
(
args
.
load
)
if
args
.
load
:
config
.
session_init
=
SaverRestore
(
args
.
load
)
launch_train_with_config
(
config
,
SimpleTrainer
())
launch_train_with_config
(
config
,
SimpleTrainer
())
tensorpack/tfutils/sessinit.py
View file @
17b34c69
...
@@ -12,7 +12,7 @@ from .varmanip import SessionUpdate, get_checkpoint_path, get_savename_from_varn
...
@@ -12,7 +12,7 @@ from .varmanip import SessionUpdate, get_checkpoint_path, get_savename_from_varn
__all__
=
[
'SessionInit'
,
'ChainInit'
,
__all__
=
[
'SessionInit'
,
'ChainInit'
,
'SaverRestore'
,
'SaverRestoreRelaxed'
,
'DictRestore'
,
'SaverRestore'
,
'SaverRestoreRelaxed'
,
'DictRestore'
,
'JustCurrentSession'
,
'get_model_loader'
,
'Smart
Restore
'
]
'JustCurrentSession'
,
'get_model_loader'
,
'Smart
Init
'
]
class
SessionInit
(
object
):
class
SessionInit
(
object
):
...
@@ -260,7 +260,7 @@ class ChainInit(SessionInit):
...
@@ -260,7 +260,7 @@ class ChainInit(SessionInit):
i
.
_run_init
(
sess
)
i
.
_run_init
(
sess
)
def
Smart
Restore
(
obj
,
ignore_mismatch
=
False
):
def
Smart
Init
(
obj
,
ignore_mismatch
=
False
):
"""
"""
Create a :class:`SessionInit` to be loaded to a session,
Create a :class:`SessionInit` to be loaded to a session,
automatically from any supported objects, with some smart heuristics.
automatically from any supported objects, with some smart heuristics.
...
@@ -268,9 +268,9 @@ def SmartRestore(obj, ignore_mismatch=False):
...
@@ -268,9 +268,9 @@ def SmartRestore(obj, ignore_mismatch=False):
+ A TF checkpoint
+ A TF checkpoint
+ A dict of numpy arrays
+ A dict of numpy arrays
+ A npz file
+ A npz file
, to be interpreted as a dict
+ An empty string or None
+ An empty string or None
, in which case the sessinit will be a no-op
+ A list of supported objects
+ A list of supported objects
, to be initialized one by one
Args:
Args:
obj: a supported object
obj: a supported object
...
@@ -285,7 +285,7 @@ def SmartRestore(obj, ignore_mismatch=False):
...
@@ -285,7 +285,7 @@ def SmartRestore(obj, ignore_mismatch=False):
if
not
obj
:
if
not
obj
:
return
JustCurrentSession
()
return
JustCurrentSession
()
if
isinstance
(
obj
,
list
):
if
isinstance
(
obj
,
list
):
return
ChainInit
([
Smart
Restore
(
x
,
ignore_mismatch
=
ignore_mismatch
)
for
x
in
obj
])
return
ChainInit
([
Smart
Init
(
x
,
ignore_mismatch
=
ignore_mismatch
)
for
x
in
obj
])
if
isinstance
(
obj
,
six
.
string_types
):
if
isinstance
(
obj
,
six
.
string_types
):
obj
=
os
.
path
.
expanduser
(
obj
)
obj
=
os
.
path
.
expanduser
(
obj
)
if
obj
.
endswith
(
".npy"
)
or
obj
.
endswith
(
".npz"
):
if
obj
.
endswith
(
".npy"
)
or
obj
.
endswith
(
".npz"
):
...
@@ -301,11 +301,11 @@ def SmartRestore(obj, ignore_mismatch=False):
...
@@ -301,11 +301,11 @@ def SmartRestore(obj, ignore_mismatch=False):
# A TF checkpoint must be a prefix of an actual file.
# A TF checkpoint must be a prefix of an actual file.
return
(
SaverRestoreRelaxed
if
ignore_mismatch
else
SaverRestore
)(
obj
)
return
(
SaverRestoreRelaxed
if
ignore_mismatch
else
SaverRestore
)(
obj
)
else
:
else
:
raise
ValueError
(
"Invalid argument to Smart
Restore
: "
+
obj
)
raise
ValueError
(
"Invalid argument to Smart
Init
: "
+
obj
)
if
isinstance
(
obj
,
dict
):
if
isinstance
(
obj
,
dict
):
return
DictRestore
(
obj
,
ignore_mismatch
=
ignore_mismatch
)
return
DictRestore
(
obj
,
ignore_mismatch
=
ignore_mismatch
)
raise
ValueError
(
"Invalid argument to Smart
Restore
: "
+
type
(
obj
))
raise
ValueError
(
"Invalid argument to Smart
Init
: "
+
type
(
obj
))
get_model_loader
=
Smart
Restore
get_model_loader
=
Smart
Init
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment