Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
ac9ac2a4
Commit
ac9ac2a4
authored
Dec 21, 2018
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
isort -y -sp tox.ini
parent
9c2be2ad
Changes
197
Show whitespace changes
Inline
Side-by-side
Showing
197 changed files
with
724 additions
and
795 deletions
+724
-795
README.md
README.md
+1
-1
examples/A3C-Gym/simulator.py
examples/A3C-Gym/simulator.py
+5
-7
examples/A3C-Gym/train-atari.py
examples/A3C-Gym/train-atari.py
+8
-11
examples/CTC-TIMIT/create-lmdb.py
examples/CTC-TIMIT/create-lmdb.py
+5
-5
examples/CTC-TIMIT/timitdata.py
examples/CTC-TIMIT/timitdata.py
+2
-1
examples/CTC-TIMIT/train-timit.py
examples/CTC-TIMIT/train-timit.py
+4
-4
examples/CaffeModels/load-alexnet.py
examples/CaffeModels/load-alexnet.py
+4
-4
examples/CaffeModels/load-cpm.py
examples/CaffeModels/load-cpm.py
+3
-2
examples/CaffeModels/load-vgg16.py
examples/CaffeModels/load-vgg16.py
+3
-3
examples/CaffeModels/load-vgg19.py
examples/CaffeModels/load-vgg19.py
+3
-3
examples/Char-RNN/char-rnn.py
examples/Char-RNN/char-rnn.py
+4
-5
examples/DeepQNetwork/DQN.py
examples/DeepQNetwork/DQN.py
+6
-6
examples/DeepQNetwork/DQNModel.py
examples/DeepQNetwork/DQNModel.py
+3
-3
examples/DeepQNetwork/atari.py
examples/DeepQNetwork/atari.py
+7
-8
examples/DeepQNetwork/atari_wrapper.py
examples/DeepQNetwork/atari_wrapper.py
+0
-1
examples/DeepQNetwork/common.py
examples/DeepQNetwork/common.py
+3
-3
examples/DeepQNetwork/expreplay.py
examples/DeepQNetwork/expreplay.py
+5
-5
examples/DisturbLabel/mnist-disturb.py
examples/DisturbLabel/mnist-disturb.py
+5
-5
examples/DisturbLabel/svhn-disturb.py
examples/DisturbLabel/svhn-disturb.py
+2
-3
examples/DoReFa-Net/alexnet-dorefa.py
examples/DoReFa-Net/alexnet-dorefa.py
+5
-7
examples/DoReFa-Net/dorefa.py
examples/DoReFa-Net/dorefa.py
+1
-0
examples/DoReFa-Net/resnet-dorefa.py
examples/DoReFa-Net/resnet-dorefa.py
+3
-3
examples/DoReFa-Net/svhn-digit-dorefa.py
examples/DoReFa-Net/svhn-digit-dorefa.py
+2
-2
examples/DynamicFilterNetwork/steering-filter.py
examples/DynamicFilterNetwork/steering-filter.py
+4
-5
examples/FasterRCNN/basemodel.py
examples/FasterRCNN/basemodel.py
+2
-3
examples/FasterRCNN/coco.py
examples/FasterRCNN/coco.py
+3
-4
examples/FasterRCNN/config.py
examples/FasterRCNN/config.py
+1
-0
examples/FasterRCNN/data.py
examples/FasterRCNN/data.py
+10
-11
examples/FasterRCNN/eval.py
examples/FasterRCNN/eval.py
+7
-8
examples/FasterRCNN/model_cascade.py
examples/FasterRCNN/model_cascade.py
+3
-3
examples/FasterRCNN/model_fpn.py
examples/FasterRCNN/model_fpn.py
+7
-8
examples/FasterRCNN/model_frcnn.py
examples/FasterRCNN/model_frcnn.py
+4
-5
examples/FasterRCNN/model_mrcnn.py
examples/FasterRCNN/model_mrcnn.py
+3
-4
examples/FasterRCNN/model_rpn.py
examples/FasterRCNN/model_rpn.py
+4
-5
examples/FasterRCNN/train.py
examples/FasterRCNN/train.py
+26
-39
examples/FasterRCNN/utils/box_ops.py
examples/FasterRCNN/utils/box_ops.py
+2
-0
examples/FasterRCNN/utils/generate_anchors.py
examples/FasterRCNN/utils/generate_anchors.py
+7
-16
examples/FasterRCNN/viz.py
examples/FasterRCNN/viz.py
+2
-2
examples/GAN/BEGAN.py
examples/GAN/BEGAN.py
+4
-3
examples/GAN/ConditionalGAN-mnist.py
examples/GAN/ConditionalGAN-mnist.py
+6
-6
examples/GAN/CycleGAN.py
examples/GAN/CycleGAN.py
+5
-5
examples/GAN/DCGAN.py
examples/GAN/DCGAN.py
+5
-5
examples/GAN/DiscoGAN-CelebA.py
examples/GAN/DiscoGAN-CelebA.py
+6
-6
examples/GAN/GAN.py
examples/GAN/GAN.py
+4
-4
examples/GAN/Image2Image.py
examples/GAN/Image2Image.py
+8
-8
examples/GAN/Improved-WGAN.py
examples/GAN/Improved-WGAN.py
+4
-3
examples/GAN/InfoGAN-mnist.py
examples/GAN/InfoGAN-mnist.py
+8
-8
examples/GAN/WGAN.py
examples/GAN/WGAN.py
+4
-2
examples/HED/hed.py
examples/HED/hed.py
+5
-6
examples/ImageNetModels/alexnet.py
examples/ImageNetModels/alexnet.py
+1
-2
examples/ImageNetModels/imagenet_utils.py
examples/ImageNetModels/imagenet_utils.py
+8
-10
examples/ImageNetModels/inception-bn.py
examples/ImageNetModels/inception-bn.py
+1
-2
examples/ImageNetModels/shufflenet.py
examples/ImageNetModels/shufflenet.py
+3
-7
examples/ImageNetModels/vgg16.py
examples/ImageNetModels/vgg16.py
+1
-3
examples/OpticalFlow/flownet2.py
examples/OpticalFlow/flownet2.py
+3
-3
examples/OpticalFlow/flownet_models.py
examples/OpticalFlow/flownet_models.py
+1
-0
examples/PennTreebank/PTB-LSTM.py
examples/PennTreebank/PTB-LSTM.py
+4
-5
examples/PennTreebank/reader.py
examples/PennTreebank/reader.py
+1
-5
examples/ResNet/cifar10-preact18-mixup.py
examples/ResNet/cifar10-preact18-mixup.py
+2
-2
examples/ResNet/cifar10-resnet.py
examples/ResNet/cifar10-resnet.py
+2
-4
examples/ResNet/imagenet-resnet.py
examples/ResNet/imagenet-resnet.py
+6
-10
examples/ResNet/load-resnet.py
examples/ResNet/load-resnet.py
+7
-7
examples/ResNet/resnet_model.py
examples/ResNet/resnet_model.py
+1
-2
examples/Saliency/CAM-resnet.py
examples/Saliency/CAM-resnet.py
+9
-13
examples/Saliency/saliency-maps.py
examples/Saliency/saliency-maps.py
+2
-3
examples/SimilarityLearning/embedding_data.py
examples/SimilarityLearning/embedding_data.py
+2
-1
examples/SimilarityLearning/mnist-embeddings.py
examples/SimilarityLearning/mnist-embeddings.py
+2
-3
examples/SpatialTransformer/mnist-addition.py
examples/SpatialTransformer/mnist-addition.py
+4
-5
examples/SuperResolution/data_sampler.py
examples/SuperResolution/data_sampler.py
+4
-3
examples/SuperResolution/enet-pat.py
examples/SuperResolution/enet-pat.py
+6
-6
examples/basics/cifar-convnet.py
examples/basics/cifar-convnet.py
+3
-2
examples/basics/export-model.py
examples/basics/export-model.py
+1
-0
examples/basics/mnist-convnet.py
examples/basics/mnist-convnet.py
+5
-6
examples/basics/mnist-tflayers.py
examples/basics/mnist-tflayers.py
+5
-6
examples/basics/mnist-tfslim.py
examples/basics/mnist-tfslim.py
+3
-2
examples/basics/mnist-visualizations.py
examples/basics/mnist-visualizations.py
+1
-0
examples/basics/svhn-digit-convnet.py
examples/basics/svhn-digit-convnet.py
+2
-2
examples/boilerplate.py
examples/boilerplate.py
+1
-1
examples/keras/imagenet-resnet-keras.py
examples/keras/imagenet-resnet-keras.py
+6
-7
examples/keras/mnist-keras-v2.py
examples/keras/mnist-keras-v2.py
+4
-6
examples/keras/mnist-keras.py
examples/keras/mnist-keras.py
+6
-6
scripts/checkpoint-manipulate.py
scripts/checkpoint-manipulate.py
+2
-2
scripts/checkpoint-prof.py
scripts/checkpoint-prof.py
+4
-3
scripts/dump-model-params.py
scripts/dump-model-params.py
+2
-2
scripts/ls-checkpoint.py
scripts/ls-checkpoint.py
+3
-3
setup.py
setup.py
+3
-3
tensorpack/callbacks/base.py
tensorpack/callbacks/base.py
+2
-1
tensorpack/callbacks/concurrency.py
tensorpack/callbacks/concurrency.py
+3
-2
tensorpack/callbacks/graph.py
tensorpack/callbacks/graph.py
+3
-3
tensorpack/callbacks/group.py
tensorpack/callbacks/group.py
+4
-4
tensorpack/callbacks/hooks.py
tensorpack/callbacks/hooks.py
+1
-0
tensorpack/callbacks/inference.py
tensorpack/callbacks/inference.py
+3
-3
tensorpack/callbacks/inference_runner.py
tensorpack/callbacks/inference_runner.py
+6
-11
tensorpack/callbacks/misc.py
tensorpack/callbacks/misc.py
+3
-3
tensorpack/callbacks/monitor.py
tensorpack/callbacks/monitor.py
+7
-7
tensorpack/callbacks/param.py
tensorpack/callbacks/param.py
+6
-6
tensorpack/callbacks/prof.py
tensorpack/callbacks/prof.py
+5
-5
tensorpack/callbacks/saver.py
tensorpack/callbacks/saver.py
+3
-3
tensorpack/callbacks/stats.py
tensorpack/callbacks/stats.py
+2
-2
tensorpack/callbacks/steps.py
tensorpack/callbacks/steps.py
+3
-4
tensorpack/callbacks/summary.py
tensorpack/callbacks/summary.py
+1
-1
tensorpack/callbacks/trigger.py
tensorpack/callbacks/trigger.py
+1
-1
tensorpack/contrib/keras.py
tensorpack/contrib/keras.py
+10
-14
tensorpack/dataflow/base.py
tensorpack/dataflow/base.py
+2
-1
tensorpack/dataflow/common.py
tensorpack/dataflow/common.py
+8
-8
tensorpack/dataflow/dataset/bsds500.py
tensorpack/dataflow/dataset/bsds500.py
+1
-1
tensorpack/dataflow/dataset/cifar.py
tensorpack/dataflow/dataset/cifar.py
+1
-1
tensorpack/dataflow/dataset/ilsvrc.py
tensorpack/dataflow/dataset/ilsvrc.py
+2
-2
tensorpack/dataflow/dataset/mnist.py
tensorpack/dataflow/dataset/mnist.py
+1
-1
tensorpack/dataflow/dataset/svhn.py
tensorpack/dataflow/dataset/svhn.py
+2
-2
tensorpack/dataflow/dftools.py
tensorpack/dataflow/dftools.py
+1
-3
tensorpack/dataflow/format.py
tensorpack/dataflow/format.py
+8
-7
tensorpack/dataflow/image.py
tensorpack/dataflow/image.py
+4
-3
tensorpack/dataflow/imgaug/_test.py
tensorpack/dataflow/imgaug/_test.py
+3
-3
tensorpack/dataflow/imgaug/base.py
tensorpack/dataflow/imgaug/base.py
+2
-2
tensorpack/dataflow/imgaug/convert.py
tensorpack/dataflow/imgaug/convert.py
+3
-2
tensorpack/dataflow/imgaug/crop.py
tensorpack/dataflow/imgaug/crop.py
+1
-2
tensorpack/dataflow/imgaug/deform.py
tensorpack/dataflow/imgaug/deform.py
+3
-2
tensorpack/dataflow/imgaug/external.py
tensorpack/dataflow/imgaug/external.py
+0
-1
tensorpack/dataflow/imgaug/geometry.py
tensorpack/dataflow/imgaug/geometry.py
+1
-1
tensorpack/dataflow/imgaug/imgproc.py
tensorpack/dataflow/imgaug/imgproc.py
+2
-1
tensorpack/dataflow/imgaug/misc.py
tensorpack/dataflow/imgaug/misc.py
+1
-1
tensorpack/dataflow/imgaug/noise.py
tensorpack/dataflow/imgaug/noise.py
+2
-1
tensorpack/dataflow/imgaug/paste.py
tensorpack/dataflow/imgaug/paste.py
+3
-3
tensorpack/dataflow/imgaug/transform.py
tensorpack/dataflow/imgaug/transform.py
+3
-3
tensorpack/dataflow/parallel.py
tensorpack/dataflow/parallel.py
+12
-14
tensorpack/dataflow/parallel_map.py
tensorpack/dataflow/parallel_map.py
+8
-12
tensorpack/dataflow/raw.py
tensorpack/dataflow/raw.py
+2
-1
tensorpack/dataflow/remote.py
tensorpack/dataflow/remote.py
+4
-4
tensorpack/dataflow/serialize.py
tensorpack/dataflow/serialize.py
+6
-7
tensorpack/graph_builder/distributed.py
tensorpack/graph_builder/distributed.py
+4
-7
tensorpack/graph_builder/model_desc.py
tensorpack/graph_builder/model_desc.py
+2
-2
tensorpack/graph_builder/predict.py
tensorpack/graph_builder/predict.py
+1
-1
tensorpack/graph_builder/training.py
tensorpack/graph_builder/training.py
+10
-13
tensorpack/graph_builder/utils.py
tensorpack/graph_builder/utils.py
+4
-5
tensorpack/input_source/input_source.py
tensorpack/input_source/input_source.py
+13
-12
tensorpack/input_source/input_source_base.py
tensorpack/input_source/input_source_base.py
+4
-4
tensorpack/models/_old_batch_norm.py
tensorpack/models/_old_batch_norm.py
+6
-4
tensorpack/models/_test.py
tensorpack/models/_test.py
+1
-1
tensorpack/models/batch_norm.py
tensorpack/models/batch_norm.py
+6
-6
tensorpack/models/common.py
tensorpack/models/common.py
+2
-2
tensorpack/models/conv2d.py
tensorpack/models/conv2d.py
+6
-5
tensorpack/models/fc.py
tensorpack/models/fc.py
+3
-3
tensorpack/models/layer_norm.py
tensorpack/models/layer_norm.py
+2
-1
tensorpack/models/linearwrap.py
tensorpack/models/linearwrap.py
+2
-1
tensorpack/models/nonlin.py
tensorpack/models/nonlin.py
+1
-1
tensorpack/models/pool.py
tensorpack/models/pool.py
+4
-5
tensorpack/models/registry.py
tensorpack/models/registry.py
+3
-3
tensorpack/models/regularize.py
tensorpack/models/regularize.py
+3
-3
tensorpack/models/shapes.py
tensorpack/models/shapes.py
+1
-0
tensorpack/models/tflayer.py
tensorpack/models/tflayer.py
+3
-3
tensorpack/predict/base.py
tensorpack/predict/base.py
+3
-3
tensorpack/predict/concurrency.py
tensorpack/predict/concurrency.py
+5
-5
tensorpack/predict/config.py
tensorpack/predict/config.py
+2
-2
tensorpack/predict/dataset.py
tensorpack/predict/dataset.py
+5
-6
tensorpack/predict/feedfree.py
tensorpack/predict/feedfree.py
+3
-4
tensorpack/predict/multigpu.py
tensorpack/predict/multigpu.py
+2
-1
tensorpack/tfutils/argscope.py
tensorpack/tfutils/argscope.py
+4
-4
tensorpack/tfutils/collection.py
tensorpack/tfutils/collection.py
+2
-2
tensorpack/tfutils/common.py
tensorpack/tfutils/common.py
+1
-0
tensorpack/tfutils/dependency.py
tensorpack/tfutils/dependency.py
+1
-0
tensorpack/tfutils/export.py
tensorpack/tfutils/export.py
+2
-2
tensorpack/tfutils/gradproc.py
tensorpack/tfutils/gradproc.py
+5
-4
tensorpack/tfutils/model_utils.py
tensorpack/tfutils/model_utils.py
+1
-1
tensorpack/tfutils/optimizer.py
tensorpack/tfutils/optimizer.py
+2
-2
tensorpack/tfutils/scope_utils.py
tensorpack/tfutils/scope_utils.py
+1
-1
tensorpack/tfutils/sesscreate.py
tensorpack/tfutils/sesscreate.py
+2
-1
tensorpack/tfutils/sessinit.py
tensorpack/tfutils/sessinit.py
+2
-3
tensorpack/tfutils/summary.py
tensorpack/tfutils/summary.py
+4
-5
tensorpack/tfutils/tower.py
tensorpack/tfutils/tower.py
+3
-3
tensorpack/tfutils/varmanip.py
tensorpack/tfutils/varmanip.py
+3
-2
tensorpack/tfutils/varreplace.py
tensorpack/tfutils/varreplace.py
+1
-1
tensorpack/train/base.py
tensorpack/train/base.py
+12
-14
tensorpack/train/config.py
tensorpack/train/config.py
+4
-6
tensorpack/train/interface.py
tensorpack/train/interface.py
+1
-4
tensorpack/train/tower.py
tensorpack/train/tower.py
+6
-8
tensorpack/train/trainers.py
tensorpack/train/trainers.py
+11
-16
tensorpack/train/utility.py
tensorpack/train/utility.py
+1
-3
tensorpack/utils/argtools.py
tensorpack/utils/argtools.py
+2
-0
tensorpack/utils/compatible_serialize.py
tensorpack/utils/compatible_serialize.py
+2
-1
tensorpack/utils/concurrency.py
tensorpack/utils/concurrency.py
+4
-4
tensorpack/utils/develop.py
tensorpack/utils/develop.py
+2
-2
tensorpack/utils/fs.py
tensorpack/utils/fs.py
+3
-2
tensorpack/utils/gpu.py
tensorpack/utils/gpu.py
+3
-2
tensorpack/utils/loadcaffe.py
tensorpack/utils/loadcaffe.py
+4
-4
tensorpack/utils/logger.py
tensorpack/utils/logger.py
+3
-3
tensorpack/utils/nvml.py
tensorpack/utils/nvml.py
+1
-3
tensorpack/utils/rect.py
tensorpack/utils/rect.py
+1
-0
tensorpack/utils/serialize.py
tensorpack/utils/serialize.py
+3
-2
tensorpack/utils/timer.py
tensorpack/utils/timer.py
+4
-4
tensorpack/utils/utils.py
tensorpack/utils/utils.py
+2
-3
tensorpack/utils/viz.py
tensorpack/utils/viz.py
+3
-2
tests/run-tests.sh
tests/run-tests.sh
+1
-6
tests/test_char_rnn.py
tests/test_char_rnn.py
+2
-1
tests/test_infogan.py
tests/test_infogan.py
+1
-0
tests/test_serializer.py
tests/test_serializer.py
+5
-4
tox.ini
tox.ini
+10
-0
No files found.
README.md
View file @
ac9ac2a4
...
...
@@ -32,7 +32,7 @@ It's Yet Another TF high-level API, with __speed__, and __flexibility__ built to
See
[
tutorials and documentations
](
http://tensorpack.readthedocs.io/tutorial/index.html#user-tutorials
)
to know more about these features.
##
[Examples](examples)
:
##
Examples
:
We refuse toy examples.
Instead of showing you 10 arbitrary networks trained on toy datasets,
...
...
examples/A3C-Gym/simulator.py
View file @
ac9ac2a4
...
...
@@ -4,20 +4,18 @@
# Author: Yuxin Wu
import
multiprocessing
as
mp
import
time
import
os
import
threading
from
abc
import
abstractmethod
,
ABCMeta
import
time
from
abc
import
ABCMeta
,
abstractmethod
from
collections
import
defaultdict
import
six
from
six.moves
import
queue
import
zmq
from
six.moves
import
queue
from
tensorpack.utils
import
logger
from
tensorpack.utils.serialize
import
loads
,
dumps
from
tensorpack.utils.concurrency
import
(
LoopThread
,
ensure_proc_terminate
,
enable_death_signal
)
from
tensorpack.utils.concurrency
import
LoopThread
,
enable_death_signal
,
ensure_proc_terminate
from
tensorpack.utils.serialize
import
dumps
,
loads
__all__
=
[
'SimulatorProcess'
,
'SimulatorMaster'
,
'SimulatorProcessStateExchange'
,
...
...
examples/A3C-Gym/train-atari.py
View file @
ac9ac2a4
...
...
@@ -3,29 +3,26 @@
# File: train-atari.py
# Author: Yuxin Wu
import
argparse
import
numpy
as
np
import
sys
import
os
import
sys
import
uuid
import
argparse
import
cv2
import
tensorflow
as
tf
import
gym
import
six
import
tensorflow
as
tf
from
six.moves
import
queue
from
tensorpack
import
*
from
tensorpack.utils.concurrency
import
ensure_proc_terminate
,
start_proc_mask_signal
from
tensorpack.utils.serialize
import
dumps
from
tensorpack.tfutils.gradproc
import
MapGradient
,
SummaryGradient
from
tensorpack.utils.concurrency
import
ensure_proc_terminate
,
start_proc_mask_signal
from
tensorpack.utils.gpu
import
get_num_gpu
from
tensorpack.utils.serialize
import
dumps
import
gym
from
simulator
import
SimulatorProcess
,
SimulatorMaster
,
TransitionExperience
from
atari_wrapper
import
FireResetEnv
,
FrameStack
,
LimitLength
,
MapState
from
common
import
Evaluator
,
eval_model_multithread
,
play_n_episodes
from
atari_wrapper
import
MapState
,
FrameStack
,
FireResetEnv
,
LimitLength
from
simulator
import
SimulatorMaster
,
SimulatorProcess
,
TransitionExperience
if
six
.
PY3
:
from
concurrent
import
futures
...
...
examples/CTC-TIMIT/create-lmdb.py
View file @
ac9ac2a4
...
...
@@ -2,17 +2,17 @@
# -*- coding: utf-8 -*-
# File: create-lmdb.py
# Author: Yuxin Wu
import
argparse
import
numpy
as
np
import
os
import
scipy.io.wavfile
as
wavfile
import
string
import
numpy
as
np
import
argparse
import
bob.ap
import
scipy.io.wavfile
as
wavfile
from
tensorpack.dataflow
import
DataFlow
,
LMDBSerializer
from
tensorpack.utils
import
fs
,
logger
,
serialize
from
tensorpack.utils.argtools
import
memoized
from
tensorpack.utils.stats
import
OnlineMoments
from
tensorpack.utils
import
serialize
,
fs
,
logger
from
tensorpack.utils.utils
import
get_tqdm
CHARSET
=
set
(
string
.
ascii_lowercase
+
' '
)
...
...
examples/CTC-TIMIT/timitdata.py
View file @
ac9ac2a4
...
...
@@ -2,10 +2,11 @@
# File: timitdata.py
# Author: Yuxin Wu
from
tensorpack
import
ProxyDataFlow
import
numpy
as
np
from
six.moves
import
range
from
tensorpack
import
ProxyDataFlow
__all__
=
[
'TIMITBatch'
]
...
...
examples/CTC-TIMIT/train-timit.py
View file @
ac9ac2a4
...
...
@@ -3,17 +3,17 @@
# File: train-timit.py
# Author: Yuxin Wu
import
os
import
argparse
import
os
import
tensorflow
as
tf
from
six.moves
import
range
from
tensorpack
import
*
from
tensorpack.tfutils.gradproc
import
SummaryGradient
,
GlobalNormClip
from
tensorpack.tfutils.gradproc
import
GlobalNormClip
,
SummaryGradient
from
tensorpack.utils
import
serialize
import
tensorflow
as
tf
from
timitdata
import
TIMITBatch
rnn
=
tf
.
contrib
.
rnn
...
...
examples/CaffeModels/load-alexnet.py
View file @
ac9ac2a4
...
...
@@ -4,16 +4,16 @@
# Author: Yuxin Wu
from
__future__
import
print_function
import
argparse
import
numpy
as
np
import
os
import
cv2
import
argparse
import
tensorflow
as
tf
from
tensorpack
import
*
from
tensorpack.tfutils.symbolic_functions
import
*
from
tensorpack.tfutils.summary
import
*
from
tensorpack.dataflow.dataset
import
ILSVRCMeta
import
tensorflow
as
tf
from
tensorpack.tfutils.summary
import
*
from
tensorpack.tfutils.symbolic_functions
import
*
def
tower_func
(
image
):
...
...
examples/CaffeModels/load-cpm.py
View file @
ac9ac2a4
...
...
@@ -3,15 +3,16 @@
# File: load-cpm.py
# Author: Yuxin Wu
import
argparse
import
numpy
as
np
import
cv2
import
tensorflow
as
tf
import
numpy
as
np
import
argparse
from
tensorpack
import
*
from
tensorpack.utils
import
viz
from
tensorpack.utils.argtools
import
memoized
"""
15 channels:
0-1 head, neck
...
...
examples/CaffeModels/load-vgg16.py
View file @
ac9ac2a4
...
...
@@ -3,12 +3,12 @@
# File: load-vgg16.py
from
__future__
import
print_function
import
cv2
import
tensorflow
as
tf
import
argparse
import
numpy
as
np
import
os
import
cv2
import
six
import
argparse
import
tensorflow
as
tf
from
tensorpack
import
*
from
tensorpack.dataflow.dataset
import
ILSVRCMeta
...
...
examples/CaffeModels/load-vgg19.py
View file @
ac9ac2a4
...
...
@@ -3,12 +3,12 @@
# File: load-vgg19.py
from
__future__
import
print_function
import
cv2
import
tensorflow
as
tf
import
argparse
import
numpy
as
np
import
os
import
cv2
import
six
import
argparse
import
tensorflow
as
tf
from
tensorpack
import
*
from
tensorpack.dataflow.dataset
import
ILSVRCMeta
...
...
examples/Char-RNN/char-rnn.py
View file @
ac9ac2a4
...
...
@@ -3,21 +3,20 @@
# File: char-rnn.py
# Author: Yuxin Wu
import
argparse
import
numpy
as
np
import
operator
import
os
import
sys
import
argparse
from
collections
import
Counter
import
operator
import
six
import
tensorflow
as
tf
from
six.moves
import
range
from
tensorpack
import
*
from
tensorpack.tfutils
import
summary
,
optimizer
from
tensorpack.tfutils
import
optimizer
,
summary
from
tensorpack.tfutils.gradproc
import
GlobalNormClip
import
tensorflow
as
tf
rnn
=
tf
.
contrib
.
rnn
class
_NS
:
pass
# noqa
...
...
examples/DeepQNetwork/DQN.py
View file @
ac9ac2a4
...
...
@@ -3,20 +3,20 @@
# File: DQN.py
# Author: Yuxin Wu
import
os
import
argparse
import
cv2
import
numpy
as
np
import
tensorflow
as
tf
import
os
import
cv2
import
gym
import
tensorflow
as
tf
from
tensorpack
import
*
from
DQNModel
import
Model
as
DQNModel
from
atari
import
AtariPlayer
from
atari_wrapper
import
FireResetEnv
,
FrameStack
,
LimitLength
,
MapState
from
common
import
Evaluator
,
eval_model_multithread
,
play_n_episodes
from
atari_wrapper
import
FrameStack
,
MapState
,
FireResetEnv
,
LimitLength
from
DQNModel
import
Model
as
DQNModel
from
expreplay
import
ExpReplay
from
atari
import
AtariPlayer
BATCH_SIZE
=
64
IMAGE_SIZE
=
(
84
,
84
)
...
...
examples/DeepQNetwork/DQNModel.py
View file @
ac9ac2a4
...
...
@@ -4,11 +4,11 @@
import
abc
import
tensorflow
as
tf
from
tensorpack
import
ModelDesc
from
tensorpack.utils
import
logger
from
tensorpack.tfutils
import
(
varreplace
,
summary
,
get_current_tower_context
,
optimizer
,
gradproc
)
from
tensorpack.tfutils
import
get_current_tower_context
,
gradproc
,
optimizer
,
summary
,
varreplace
from
tensorpack.tfutils.scope_utils
import
auto_reuse_variable_scope
from
tensorpack.utils
import
logger
class
Model
(
ModelDesc
):
...
...
examples/DeepQNetwork/atari.py
View file @
ac9ac2a4
...
...
@@ -4,19 +4,18 @@
import
numpy
as
np
import
os
import
cv2
import
threading
import
six
from
six.moves
import
range
from
tensorpack.utils
import
logger
from
tensorpack.utils.utils
import
get_rng
,
execute_only_once
from
tensorpack.utils.fs
import
get_dataset_path
import
cv2
import
gym
import
six
from
ale_python_interface
import
ALEInterface
from
gym
import
spaces
from
gym.envs.atari.atari_env
import
ACTION_MEANING
from
six.moves
import
range
from
ale_python_interface
import
ALEInterface
from
tensorpack.utils
import
logger
from
tensorpack.utils.fs
import
get_dataset_path
from
tensorpack.utils.utils
import
execute_only_once
,
get_rng
__all__
=
[
'AtariPlayer'
]
...
...
examples/DeepQNetwork/atari_wrapper.py
View file @
ac9ac2a4
...
...
@@ -3,7 +3,6 @@
import
numpy
as
np
from
collections
import
deque
import
gym
from
gym
import
spaces
...
...
examples/DeepQNetwork/common.py
View file @
ac9ac2a4
# -*- coding: utf-8 -*-
# File: common.py
# Author: Yuxin Wu
import
multiprocessing
import
random
import
time
import
multiprocessing
from
tqdm
import
tqdm
from
six.moves
import
queue
from
tqdm
import
tqdm
from
tensorpack.utils.concurrency
import
StoppableThread
,
ShareSessionThread
from
tensorpack.callbacks
import
Callback
from
tensorpack.utils
import
logger
from
tensorpack.utils.concurrency
import
ShareSessionThread
,
StoppableThread
from
tensorpack.utils.stats
import
StatCounter
from
tensorpack.utils.utils
import
get_tqdm_kwargs
...
...
examples/DeepQNetwork/expreplay.py
View file @
ac9ac2a4
...
...
@@ -2,18 +2,18 @@
# File: expreplay.py
# Author: Yuxin Wu
import
numpy
as
np
import
copy
from
collections
import
deque
,
namedtuple
import
numpy
as
np
import
threading
from
collections
import
deque
,
namedtuple
from
six.moves
import
queue
,
range
from
tensorpack.callbacks.base
import
Callback
from
tensorpack.dataflow
import
DataFlow
from
tensorpack.utils
import
logger
from
tensorpack.utils.utils
import
get_tqdm
,
get_rng
from
tensorpack.utils.stats
import
StatCounter
from
tensorpack.utils.concurrency
import
LoopThread
,
ShareSessionThread
from
tensorpack.callbacks.base
import
Callback
from
tensorpack.utils.stats
import
StatCounter
from
tensorpack.utils.utils
import
get_rng
,
get_tqdm
__all__
=
[
'ExpReplay'
]
...
...
examples/DisturbLabel/mnist-disturb.py
View file @
ac9ac2a4
...
...
@@ -2,17 +2,17 @@
# -*- coding: utf-8 -*-
# File: mnist-disturb.py
import
os
import
argparse
import
imp
import
os
import
tensorflow
as
tf
from
tensorpack
import
*
from
tensorpack.utils
import
logger
from
tensorpack.dataflow
import
dataset
import
tensorflow
as
tf
from
tensorpack.utils
import
logger
from
disturb
import
DisturbLabel
import
imp
mnist_example
=
imp
.
load_source
(
'mnist_example'
,
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'..'
,
'basics'
,
'mnist-convnet.py'
))
get_config
=
mnist_example
.
get_config
...
...
examples/DisturbLabel/svhn-disturb.py
View file @
ac9ac2a4
...
...
@@ -3,13 +3,12 @@
# File: svhn-disturb.py
import
argparse
import
os
import
imp
import
os
from
tensorpack
import
*
from
tensorpack.utils
import
logger
from
tensorpack.dataflow
import
dataset
from
tensorpack.utils
import
logger
from
disturb
import
DisturbLabel
...
...
examples/DoReFa-Net/alexnet-dorefa.py
View file @
ac9ac2a4
...
...
@@ -3,24 +3,22 @@
# File: alexnet-dorefa.py
# Author: Yuxin Wu, Yuheng Zou ({wyx,zyh}@megvii.com)
import
cv2
import
tensorflow
as
tf
import
argparse
import
numpy
as
np
import
os
import
sys
import
cv2
import
tensorflow
as
tf
from
tensorpack
import
*
from
tensorpack.
tfutils.summary
import
add_param_summary
from
tensorpack.
dataflow
import
dataset
from
tensorpack.tfutils.sessinit
import
get_model_loader
from
tensorpack.tfutils.summary
import
add_param_summary
from
tensorpack.tfutils.varreplace
import
remap_variables
from
tensorpack.dataflow
import
dataset
from
tensorpack.utils.gpu
import
get_num_gpu
from
imagenet_utils
import
(
get_imagenet_dataflow
,
fbresnet_augmentor
,
ImageNetModel
,
eval_on_ILSVRC12
)
from
dorefa
import
get_dorefa
,
ternarize
from
imagenet_utils
import
ImageNetModel
,
eval_on_ILSVRC12
,
fbresnet_augmentor
,
get_imagenet_dataflow
"""
This is a tensorpack script for the ImageNet results in paper:
...
...
examples/DoReFa-Net/dorefa.py
View file @
ac9ac2a4
...
...
@@ -3,6 +3,7 @@
# Author: Yuxin Wu
import
tensorflow
as
tf
from
tensorpack.utils.argtools
import
graph_memoized
...
...
examples/DoReFa-Net/resnet-dorefa.py
View file @
ac9ac2a4
...
...
@@ -2,18 +2,18 @@
# -*- coding: utf-8 -*-
# File: resnet-dorefa.py
import
cv2
import
tensorflow
as
tf
import
argparse
import
numpy
as
np
import
os
import
cv2
import
tensorflow
as
tf
from
tensorpack
import
*
from
tensorpack.dataflow
import
dataset
from
tensorpack.tfutils.varreplace
import
remap_variables
from
imagenet_utils
import
ImageNetModel
,
eval_on_ILSVRC12
,
fbresnet_augmentor
from
dorefa
import
get_dorefa
from
imagenet_utils
import
ImageNetModel
,
eval_on_ILSVRC12
,
fbresnet_augmentor
"""
This script loads the pre-trained ResNet-18 model with (W,A,G) = (1,4,32)
...
...
examples/DoReFa-Net/svhn-digit-dorefa.py
View file @
ac9ac2a4
...
...
@@ -3,13 +3,13 @@
# File: svhn-digit-dorefa.py
# Author: Yuxin Wu
import
os
import
argparse
import
os
import
tensorflow
as
tf
from
tensorpack
import
*
from
tensorpack.tfutils.summary
import
add_moving_summary
,
add_param_summary
from
tensorpack.dataflow
import
dataset
from
tensorpack.tfutils.summary
import
add_moving_summary
,
add_param_summary
from
tensorpack.tfutils.varreplace
import
remap_variables
from
dorefa
import
get_dorefa
...
...
examples/DynamicFilterNetwork/steering-filter.py
View file @
ac9ac2a4
...
...
@@ -3,19 +3,18 @@
# File: steering-filter.py
import
argparse
import
multiprocessing
import
numpy
as
np
import
tensorflow
as
tf
import
cv2
import
tensorflow
as
tf
from
scipy.signal
import
convolve2d
from
six.moves
import
range
,
zip
import
multiprocessing
from
tensorpack
import
*
from
tensorpack.dataflow
import
dataset
from
tensorpack.utils
import
logger
from
tensorpack.utils.viz
import
*
from
tensorpack.utils.argtools
import
shape2d
,
shape4d
from
tensorpack.
dataflow
import
dataset
from
tensorpack.
utils.viz
import
*
BATCH
=
32
SHAPE
=
64
...
...
examples/FasterRCNN/basemodel.py
View file @
ac9ac2a4
# -*- coding: utf-8 -*-
# File: basemodel.py
from
contextlib
import
contextmanager
,
ExitStack
import
numpy
as
np
from
contextlib
import
ExitStack
,
contextmanager
import
tensorflow
as
tf
from
tensorpack.models
import
BatchNorm
,
Conv2D
,
MaxPooling
,
layer_register
from
tensorpack.tfutils
import
argscope
from
tensorpack.tfutils.scope_utils
import
auto_reuse_variable_scope
from
tensorpack.tfutils.varreplace
import
custom_getter_scope
,
freeze_variables
from
tensorpack.models
import
(
Conv2D
,
MaxPooling
,
BatchNorm
,
layer_register
)
from
config
import
config
as
cfg
...
...
examples/FasterRCNN/coco.py
View file @
ac9ac2a4
...
...
@@ -3,17 +3,16 @@
import
numpy
as
np
import
os
from
termcolor
import
colored
from
tabulate
import
tabulate
import
tqdm
from
tabulate
import
tabulate
from
termcolor
import
colored
from
tensorpack.utils
import
logger
from
tensorpack.utils.timer
import
timed_operation
from
tensorpack.utils.argtools
import
log_once
from
tensorpack.utils.timer
import
timed_operation
from
config
import
config
as
cfg
__all__
=
[
'COCODetection'
,
'COCOMeta'
]
...
...
examples/FasterRCNN/config.py
View file @
ac9ac2a4
...
...
@@ -4,6 +4,7 @@
import
numpy
as
np
import
os
import
pprint
from
tensorpack.utils
import
logger
from
tensorpack.utils.gpu
import
get_num_gpu
...
...
examples/FasterRCNN/data.py
View file @
ac9ac2a4
# -*- coding: utf-8 -*-
# File: data.py
import
cv2
import
numpy
as
np
import
copy
import
numpy
as
np
import
cv2
from
tensorpack.utils.argtools
import
memoized
,
log_once
from
tensorpack.dataflow
import
(
imgaug
,
TestDataSpeed
,
MultiProcessMapDataZMQ
,
MultiThreadMapData
,
MapDataComponent
,
DataFromList
)
DataFromList
,
MapDataComponent
,
MultiProcessMapDataZMQ
,
MultiThreadMapData
,
TestDataSpeed
,
imgaug
)
from
tensorpack.utils
import
logger
# import tensorpack.utils.viz as tpviz
from
tensorpack.utils.argtools
import
log_once
,
memoized
from
coco
import
COCODetection
from
common
import
(
CustomResize
,
DataFromListOfDict
,
box_to_point8
,
filter_boxes_inside_shape
,
point8_to_box
,
segmentation_to_mask
)
from
config
import
config
as
cfg
from
utils.generate_anchors
import
generate_anchors
from
utils.np_box_ops
import
area
as
np_area
from
utils.np_box_ops
import
ioa
as
np_ioa
from
common
import
(
DataFromListOfDict
,
CustomResize
,
filter_boxes_inside_shape
,
box_to_point8
,
point8_to_box
,
segmentation_to_mask
)
from
config
import
config
as
cfg
# import tensorpack.utils.viz as tpviz
try
:
import
pycocotools.mask
as
cocomask
...
...
examples/FasterRCNN/eval.py
View file @
ac9ac2a4
# -*- coding: utf-8 -*-
# File: eval.py
import
tqdm
import
itertools
import
numpy
as
np
import
os
from
collections
import
namedtuple
from
concurrent.futures
import
ThreadPoolExecutor
from
contextlib
import
ExitStack
import
itertools
import
numpy
as
np
import
cv2
from
concurrent.futures
import
ThreadPoolExecutor
from
tensorpack.utils.utils
import
get_tqdm_kwargs
import
pycocotools.mask
as
cocomask
import
tqdm
from
pycocotools.coco
import
COCO
from
pycocotools.cocoeval
import
COCOeval
import
pycocotools.mask
as
cocomask
from
tensorpack.utils.utils
import
get_tqdm_kwargs
from
coco
import
COCOMeta
from
common
import
CustomResize
,
clip_boxes
...
...
examples/FasterRCNN/model_cascade.py
View file @
ac9ac2a4
...
...
@@ -2,10 +2,10 @@ import tensorflow as tf
from
tensorpack.tfutils
import
get_current_tower_context
from
utils.box_ops
import
pairwise_iou
from
model_box
import
clip_boxes
from
model_frcnn
import
FastRCNNHead
,
BoxProposals
,
fastrcnn_outputs
from
config
import
config
as
cfg
from
model_box
import
clip_boxes
from
model_frcnn
import
BoxProposals
,
FastRCNNHead
,
fastrcnn_outputs
from
utils.box_ops
import
pairwise_iou
class
CascadeRCNNHead
(
object
):
...
...
examples/FasterRCNN/model_fpn.py
View file @
ac9ac2a4
# -*- coding: utf-8 -*-
import
itertools
import
numpy
as
np
import
tensorflow
as
tf
import
itertools
from
tensorpack.
tfutils.summary
import
add_moving_summary
from
tensorpack.
models
import
Conv2D
,
FixedUnPooling
,
MaxPooling
,
layer_register
from
tensorpack.tfutils.argscope
import
argscope
from
tensorpack.tfutils.tower
import
get_current_tower_context
from
tensorpack.tfutils.scope_utils
import
under_name_scope
from
tensorpack.
models
import
(
Conv2D
,
layer_register
,
FixedUnPooling
,
MaxPooling
)
from
tensorpack.
tfutils.summary
import
add_moving_summary
from
tensorpack.tfutils.tower
import
get_current_tower_context
from
model_rpn
import
rpn_losses
,
generate_rpn_proposals
from
basemodel
import
GroupNorm
from
config
import
config
as
cfg
from
model_box
import
roi_align
from
model_rpn
import
generate_rpn_proposals
,
rpn_losses
from
utils.box_ops
import
area
as
tf_area
from
config
import
config
as
cfg
from
basemodel
import
GroupNorm
@
layer_register
(
log_shape
=
True
)
...
...
examples/FasterRCNN/model_frcnn.py
View file @
ac9ac2a4
...
...
@@ -3,18 +3,17 @@
import
tensorflow
as
tf
from
tensorpack.
tfutils.summary
import
add_moving_summary
from
tensorpack.
models
import
Conv2D
,
FullyConnected
,
layer_register
from
tensorpack.tfutils.argscope
import
argscope
from
tensorpack.tfutils.common
import
get_tf_version_tuple
from
tensorpack.tfutils.scope_utils
import
under_name_scope
from
tensorpack.models
import
(
Conv2D
,
FullyConnected
,
layer_register
)
from
tensorpack.tfutils.summary
import
add_moving_summary
from
tensorpack.utils.argtools
import
memoized_method
from
basemodel
import
GroupNorm
from
utils.box_ops
import
pairwise_iou
from
model_box
import
encode_bbox_target
,
decode_bbox_target
from
config
import
config
as
cfg
from
model_box
import
decode_bbox_target
,
encode_bbox_target
from
utils.box_ops
import
pairwise_iou
@
under_name_scope
()
...
...
examples/FasterRCNN/model_mrcnn.py
View file @
ac9ac2a4
...
...
@@ -2,12 +2,11 @@
import
tensorflow
as
tf
from
tensorpack.models
import
(
Conv2D
,
layer_register
,
Conv2DTranspose
)
from
tensorpack.tfutils.scope_utils
import
under_name_scope
from
tensorpack.models
import
Conv2D
,
Conv2DTranspose
,
layer_register
from
tensorpack.tfutils.argscope
import
argscope
from
tensorpack.tfutils.summary
import
add_moving_summary
from
tensorpack.tfutils.common
import
get_tf_version_tuple
from
tensorpack.tfutils.scope_utils
import
under_name_scope
from
tensorpack.tfutils.summary
import
add_moving_summary
from
basemodel
import
GroupNorm
from
config
import
config
as
cfg
...
...
examples/FasterRCNN/model_rpn.py
View file @
ac9ac2a4
...
...
@@ -2,14 +2,13 @@
import
tensorflow
as
tf
from
tensorpack.tfutils.summary
import
add_moving_summary
from
tensorpack.tfutils.argscope
import
argscope
from
tensorpack.tfutils.scope_utils
import
under_name_scope
,
auto_reuse_variable_scope
from
tensorpack.models
import
Conv2D
,
layer_register
from
tensorpack.tfutils.argscope
import
argscope
from
tensorpack.tfutils.scope_utils
import
auto_reuse_variable_scope
,
under_name_scope
from
tensorpack.tfutils.summary
import
add_moving_summary
from
model_box
import
clip_boxes
from
config
import
config
as
cfg
from
model_box
import
clip_boxes
@
layer_register
(
log_shape
=
True
)
...
...
examples/FasterRCNN/train.py
View file @
ac9ac2a4
...
...
@@ -2,58 +2,45 @@
# -*- coding: utf-8 -*-
# File: train.py
import
os
import
argparse
import
cv2
import
shutil
import
itertools
import
tqdm
import
numpy
as
np
import
json
import
numpy
as
np
import
os
import
shutil
import
cv2
import
six
import
tensorflow
as
tf
try
:
import
horovod.tensorflow
as
hvd
except
ImportError
:
pass
assert
six
.
PY3
,
"FasterRCNN requires Python 3!"
import
tqdm
import
tensorpack.utils.viz
as
tpviz
from
tensorpack
import
*
from
tensorpack.tfutils.summary
import
add_moving_summary
from
tensorpack.tfutils
import
optimizer
from
tensorpack.tfutils.common
import
get_tf_version_tuple
import
tensorpack.utils.viz
as
tpviz
from
coco
import
COCODetection
from
basemodel
import
(
image_preprocess
,
resnet_c4_backbone
,
resnet_conv5
,
resnet_fpn_backbone
)
from
tensorpack.tfutils.summary
import
add_moving_summary
import
model_frcnn
import
model_mrcnn
from
model_frcnn
import
(
sample_fast_rcnn_targets
,
fastrcnn_outputs
,
fastrcnn_predictions
,
BoxProposals
,
FastRCNNHead
)
from
model_mrcnn
import
maskrcnn_upXconv_head
,
maskrcnn_loss
from
model_rpn
import
rpn_head
,
rpn_losses
,
generate_rpn_proposals
from
model_fpn
import
(
fpn_model
,
multilevel_roi_align
,
multilevel_rpn_losses
,
generate_fpn_proposals
)
from
basemodel
import
image_preprocess
,
resnet_c4_backbone
,
resnet_conv5
,
resnet_fpn_backbone
from
coco
import
COCODetection
from
config
import
config
as
cfg
from
config
import
finalize_configs
from
data
import
get_all_anchors
,
get_all_anchors_fpn
,
get_eval_dataflow
,
get_train_dataflow
from
eval
import
DetectionResult
,
detect_one_image
,
eval_coco
,
multithread_eval_coco
,
print_coco_metrics
from
model_box
import
RPNAnchors
,
clip_boxes
,
crop_and_resize
,
roi_align
from
model_cascade
import
CascadeRCNNHead
from
model_box
import
(
clip_boxes
,
crop_and_resize
,
roi_align
,
RPNAnchors
)
from
data
import
(
get_train_dataflow
,
get_eval_dataflow
,
get_all_anchors
,
get_all_anchors_fpn
)
from
viz
import
(
draw_annotation
,
draw_proposal_recall
,
draw_predictions
,
draw_final_outputs
)
from
eval
import
(
eval_coco
,
multithread_eval_coco
,
detect_one_image
,
print_coco_metrics
,
DetectionResult
)
from
config
import
finalize_configs
,
config
as
cfg
from
model_fpn
import
fpn_model
,
generate_fpn_proposals
,
multilevel_roi_align
,
multilevel_rpn_losses
from
model_frcnn
import
BoxProposals
,
FastRCNNHead
,
fastrcnn_outputs
,
fastrcnn_predictions
,
sample_fast_rcnn_targets
from
model_mrcnn
import
maskrcnn_loss
,
maskrcnn_upXconv_head
from
model_rpn
import
generate_rpn_proposals
,
rpn_head
,
rpn_losses
from
viz
import
draw_annotation
,
draw_final_outputs
,
draw_predictions
,
draw_proposal_recall
try
:
import
horovod.tensorflow
as
hvd
except
ImportError
:
pass
assert
six
.
PY3
,
"FasterRCNN requires Python 3!"
class
DetectionModel
(
ModelDesc
):
...
...
examples/FasterRCNN/utils/box_ops.py
View file @
ac9ac2a4
...
...
@@ -2,8 +2,10 @@
# File: box_ops.py
import
tensorflow
as
tf
from
tensorpack.tfutils.scope_utils
import
under_name_scope
"""
This file is modified from
https://github.com/tensorflow/models/blob/master/object_detection/core/box_list_ops.py
...
...
examples/FasterRCNN/utils/generate_anchors.py
View file @
ac9ac2a4
...
...
@@ -7,8 +7,8 @@
# Written by Ross Girshick and Sean Bell
# --------------------------------------------------------
from
six.moves
import
range
import
numpy
as
np
from
six.moves
import
range
# Verify that we compute the same anchors as Shaoqing's matlab implementation:
#
...
...
@@ -27,7 +27,7 @@ import numpy as np
# -79 -167 96 184
# -167 -343 184 360
#array([[ -83., -39., 100., 56.],
#
array([[ -83., -39., 100., 56.],
# [-175., -87., 192., 104.],
# [-359., -183., 376., 200.],
# [ -55., -55., 72., 72.],
...
...
@@ -37,6 +37,7 @@ import numpy as np
# [ -79., -167., 96., 184.],
# [-167., -343., 184., 360.]])
def
generate_anchors
(
base_size
=
16
,
ratios
=
[
0.5
,
1
,
2
],
scales
=
2
**
np
.
arange
(
3
,
6
)):
"""
...
...
@@ -50,6 +51,7 @@ def generate_anchors(base_size=16, ratios=[0.5, 1, 2],
for
i
in
range
(
ratio_anchors
.
shape
[
0
])])
return
anchors
def
_whctrs
(
anchor
):
"""
Return width, height, x center, and y center for an anchor (window).
...
...
@@ -61,6 +63,7 @@ def _whctrs(anchor):
y_ctr
=
anchor
[
1
]
+
0.5
*
(
h
-
1
)
return
w
,
h
,
x_ctr
,
y_ctr
def
_mkanchors
(
ws
,
hs
,
x_ctr
,
y_ctr
):
"""
Given a vector of widths (ws) and heights (hs) around a center
...
...
@@ -75,6 +78,7 @@ def _mkanchors(ws, hs, x_ctr, y_ctr):
y_ctr
+
0.5
*
(
hs
-
1
)))
return
anchors
def
_ratio_enum
(
anchor
,
ratios
):
"""
Enumerate a set of anchors for each aspect ratio wrt an anchor.
...
...
@@ -88,6 +92,7 @@ def _ratio_enum(anchor, ratios):
anchors
=
_mkanchors
(
ws
,
hs
,
x_ctr
,
y_ctr
)
return
anchors
def
_scale_enum
(
anchor
,
scales
):
"""
Enumerate a set of anchors for each scale wrt an anchor.
...
...
@@ -98,17 +103,3 @@ def _scale_enum(anchor, scales):
hs
=
h
*
scales
anchors
=
_mkanchors
(
ws
,
hs
,
x_ctr
,
y_ctr
)
return
anchors
if
__name__
==
'__main__'
:
#import time
#t = time.time()
#a = generate_anchors()
#print(time.time() - t)
#print(a)
#from IPython import embed; embed()
anchors
=
generate_anchors
(
16
,
scales
=
np
.
asarray
((
2
,
4
,
8
,
16
,
32
),
'float32'
),
ratios
=
[
0.5
,
1
,
2
])
print
(
anchors
)
import
IPython
as
IP
;
IP
.
embed
()
examples/FasterRCNN/viz.py
View file @
ac9ac2a4
# -*- coding: utf-8 -*-
# File: viz.py
from
six.moves
import
zip
import
numpy
as
np
from
six.moves
import
zip
from
tensorpack.utils
import
viz
from
tensorpack.utils.palette
import
PALETTE_RGB
from
utils.np_box_ops
import
iou
as
np_iou
from
config
import
config
as
cfg
from
utils.np_box_ops
import
iou
as
np_iou
def
draw_annotation
(
img
,
boxes
,
klass
,
is_crowd
=
None
):
...
...
examples/GAN/BEGAN.py
View file @
ac9ac2a4
...
...
@@ -3,12 +3,14 @@
# File: BEGAN.py
# Author: Yuxin Wu
import
tensorflow
as
tf
from
tensorpack
import
*
from
tensorpack.tfutils.scope_utils
import
auto_reuse_variable_scope
from
tensorpack.tfutils.summary
import
add_moving_summary
from
tensorpack.utils.gpu
import
get_num_gpu
from
tensorpack.tfutils.scope_utils
import
auto_reuse_variable_scope
import
tensorflow
as
tf
import
DCGAN
from
GAN
import
GANModelDesc
,
GANTrainer
,
MultiGPUGANTrainer
"""
...
...
@@ -19,7 +21,6 @@ A pretrained model on CelebA is at http://models.tensorpack.com/GAN/
"""
import
DCGAN
NH
=
64
NF
=
64
GAMMA
=
0.5
...
...
examples/GAN/ConditionalGAN-mnist.py
View file @
ac9ac2a4
...
...
@@ -3,18 +3,18 @@
# File: ConditionalGAN-mnist.py
# Author: Yuxin Wu
import
argparse
import
numpy
as
np
import
tensorflow
as
tf
import
os
import
cv2
import
argparse
import
tensorflow
as
tf
from
tensorpack
import
*
from
tensorpack.utils.viz
import
interactive_imshow
,
stack_patches
from
tensorpack.tfutils.scope_utils
import
auto_reuse_variable_scope
from
tensorpack.dataflow
import
dataset
from
GAN
import
GANTrainer
,
RandomZData
,
GANModelDesc
from
tensorpack.tfutils.scope_utils
import
auto_reuse_variable_scope
from
tensorpack.utils.viz
import
interactive_imshow
,
stack_patches
from
GAN
import
GANModelDesc
,
GANTrainer
,
RandomZData
"""
To train:
...
...
examples/GAN/CycleGAN.py
View file @
ac9ac2a4
...
...
@@ -3,17 +3,17 @@
# File: CycleGAN.py
# Author: Yuxin Wu
import
os
import
argparse
import
glob
import
os
import
tensorflow
as
tf
from
six.moves
import
range
from
tensorpack
import
*
from
tensorpack.tfutils.summary
import
add_moving_summary
from
tensorpack.tfutils.scope_utils
import
auto_reuse_variable_scope
import
tensorflow
as
tf
from
GAN
import
GANTrainer
,
GANModelDesc
from
tensorpack.tfutils.summary
import
add_moving_summary
from
GAN
import
GANModelDesc
,
GANTrainer
"""
1. Download the dataset following the original project: https://github.com/junyanz/CycleGAN#train
...
...
examples/GAN/DCGAN.py
View file @
ac9ac2a4
...
...
@@ -3,18 +3,18 @@
# File: DCGAN.py
# Author: Yuxin Wu
import
argparse
import
glob
import
numpy
as
np
import
os
import
argparse
import
tensorflow
as
tf
from
tensorpack
import
*
from
tensorpack.utils.viz
import
stack_patches
from
tensorpack.tfutils.scope_utils
import
auto_reuse_variable_scope
import
tensorflow
as
tf
from
tensorpack.utils.viz
import
stack_patches
from
GAN
import
GANModelDesc
,
GANTrainer
,
RandomZData
from
GAN
import
GANTrainer
,
RandomZData
,
GANModelDesc
"""
1. Download the 'aligned&cropped' version of CelebA dataset
...
...
examples/GAN/DiscoGAN-CelebA.py
View file @
ac9ac2a4
...
...
@@ -3,17 +3,17 @@
# File: DiscoGAN-CelebA.py
# Author: Yuxin Wu
import
os
import
argparse
from
six.moves
import
map
,
zip
import
numpy
as
np
import
os
import
tensorflow
as
tf
from
six.moves
import
map
,
zip
from
tensorpack
import
*
from
tensorpack.tfutils.summary
import
add_moving_summary
from
tensorpack.tfutils.scope_utils
import
auto_reuse_variable_scope
import
tensorflow
as
tf
from
GAN
import
SeparateGANTrainer
,
GANModelDesc
from
tensorpack.tfutils.summary
import
add_moving_summary
from
GAN
import
GANModelDesc
,
SeparateGANTrainer
"""
1. Download "aligned&cropped" version of celebA to /path/to/img_align_celeba.
...
...
examples/GAN/GAN.py
View file @
ac9ac2a4
...
...
@@ -2,13 +2,13 @@
# File: GAN.py
# Author: Yuxin Wu
import
tensorflow
as
tf
import
numpy
as
np
from
tensorpack
import
(
TowerTrainer
,
StagingInput
,
ModelDescBase
,
DataFlow
,
argscope
,
BatchNorm
)
from
tensorpack
.tfutils.tower
import
TowerContext
,
TowerFuncWrapper
import
tensorflow
as
tf
from
tensorpack
import
BatchNorm
,
DataFlow
,
ModelDescBase
,
StagingInput
,
TowerTrainer
,
argscope
from
tensorpack.graph_builder
import
DataParallelBuilder
,
LeastLoadedDeviceSetter
from
tensorpack.tfutils.summary
import
add_moving_summary
from
tensorpack.tfutils.tower
import
TowerContext
,
TowerFuncWrapper
from
tensorpack.utils
import
logger
from
tensorpack.utils.argtools
import
memoized_method
from
tensorpack.utils.develop
import
deprecated
...
...
examples/GAN/Image2Image.py
View file @
ac9ac2a4
...
...
@@ -3,20 +3,20 @@
# File: Image2Image.py
# Author: Yuxin Wu
import
cv2
import
numpy
as
np
import
tensorflow
as
tf
import
argparse
import
glob
import
numpy
as
np
import
os
import
argparse
import
cv2
import
tensorflow
as
tf
from
tensorpack
import
*
from
tensorpack.tfutils.scope_utils
import
auto_reuse_variable_scope
from
tensorpack.tfutils.summary
import
add_moving_summary
from
tensorpack.utils.gpu
import
get_num_gpu
from
tensorpack.utils.viz
import
stack_patches
from
tensorpack.tfutils.summary
import
add_moving_summary
from
tensorpack.tfutils.scope_utils
import
auto_reuse_variable_scope
from
GAN
import
GANTrainer
,
GANModelDesc
from
GAN
import
GANModelDesc
,
GANTrainer
"""
To train Image-to-Image translation model with image pairs:
...
...
examples/GAN/Improved-WGAN.py
View file @
ac9ac2a4
...
...
@@ -3,12 +3,14 @@
# File: Improved-WGAN.py
# Author: Yuxin Wu
import
tensorflow
as
tf
from
tensorpack
import
*
from
tensorpack.tfutils
import
get_tf_version_tuple
from
tensorpack.tfutils.summary
import
add_moving_summary
from
tensorpack.tfutils.scope_utils
import
auto_reuse_variable_scope
import
tensorflow
as
tf
from
tensorpack.tfutils.summary
import
add_moving_summary
import
DCGAN
from
GAN
import
SeparateGANTrainer
"""
...
...
@@ -18,7 +20,6 @@ See the docstring in DCGAN.py for usage.
# Don't want to mix two examples together, but want to reuse the code.
# So here just import stuff from DCGAN.
import
DCGAN
class
Model
(
DCGAN
.
Model
):
...
...
examples/GAN/InfoGAN-mnist.py
View file @
ac9ac2a4
...
...
@@ -3,19 +3,19 @@
# File: InfoGAN-mnist.py
# Author: Yuxin Wu
import
cv2
import
argparse
import
numpy
as
np
import
tensorflow
as
tf
import
os
import
argparse
import
cv2
import
tensorflow
as
tf
from
tensorpack
import
*
from
tensorpack.utils
import
viz
from
tensorpack.tfutils.scope_utils
import
auto_reuse_variable_scope
,
under_name_scope
from
tensorpack.tfutils
import
optimizer
,
summary
,
gradproc
from
tensorpack.dataflow
import
dataset
from
GAN
import
GANTrainer
,
GANModelDesc
from
tensorpack.tfutils
import
gradproc
,
optimizer
,
summary
from
tensorpack.tfutils.scope_utils
import
auto_reuse_variable_scope
,
under_name_scope
from
tensorpack.utils
import
viz
from
GAN
import
GANModelDesc
,
GANTrainer
"""
To train:
...
...
examples/GAN/WGAN.py
View file @
ac9ac2a4
...
...
@@ -3,9 +3,12 @@
# File: WGAN.py
# Author: Yuxin Wu
import
tensorflow
as
tf
from
tensorpack
import
*
from
tensorpack.tfutils.summary
import
add_moving_summary
import
tensorflow
as
tf
import
DCGAN
from
GAN
import
SeparateGANTrainer
"""
...
...
@@ -15,7 +18,6 @@ See the docstring in DCGAN.py for usage.
# Don't want to mix two examples together, but want to reuse the code.
# So here just import stuff from DCGAN
import
DCGAN
class
Model
(
DCGAN
.
Model
):
...
...
examples/HED/hed.py
View file @
ac9ac2a4
...
...
@@ -3,19 +3,18 @@
# File: hed.py
# Author: Yuxin Wu
import
argparse
import
numpy
as
np
import
os
import
cv2
import
tensorflow
as
tf
import
numpy
as
np
import
argparse
from
six.moves
import
zip
import
os
from
tensorpack
import
*
from
tensorpack.dataflow
import
dataset
from
tensorpack.utils.gpu
import
get_num_gpu
from
tensorpack.tfutils
import
optimizer
,
gradproc
from
tensorpack.tfutils
import
gradproc
,
optimizer
from
tensorpack.tfutils.summary
import
add_moving_summary
,
add_param_summary
from
tensorpack.utils.gpu
import
get_num_gpu
def
class_balanced_sigmoid_cross_entropy
(
logits
,
label
,
name
=
'cross_entropy_loss'
):
...
...
examples/ImageNetModels/alexnet.py
View file @
ac9ac2a4
...
...
@@ -3,10 +3,9 @@
# File: alexnet.py
import
argparse
import
numpy
as
np
import
os
import
cv2
import
numpy
as
np
import
tensorflow
as
tf
from
tensorpack
import
*
...
...
examples/ImageNetModels/imagenet_utils.py
View file @
ac9ac2a4
...
...
@@ -2,24 +2,22 @@
# File: imagenet_utils.py
import
cv2
import
os
import
numpy
as
np
import
tqdm
import
multiprocessing
import
tensorflow
as
tf
import
numpy
as
np
import
os
from
abc
import
abstractmethod
import
cv2
import
tensorflow
as
tf
import
tqdm
from
tensorpack
import
ModelDesc
from
tensorpack.dataflow
import
AugmentImageComponent
,
BatchData
,
MultiThreadMapData
,
PrefetchDataZMQ
,
dataset
,
imgaug
from
tensorpack.input_source
import
QueueInput
,
StagingInput
from
tensorpack.dataflow
import
(
imgaug
,
dataset
,
AugmentImageComponent
,
PrefetchDataZMQ
,
BatchData
,
MultiThreadMapData
)
from
tensorpack.predict
import
PredictConfig
,
FeedfreePredictor
from
tensorpack.utils.stats
import
RatioCounter
from
tensorpack.models
import
regularize_cost
from
tensorpack.predict
import
FeedfreePredictor
,
PredictConfig
from
tensorpack.tfutils.summary
import
add_moving_summary
from
tensorpack.utils
import
logger
from
tensorpack.utils.stats
import
RatioCounter
"""
...
...
examples/ImageNetModels/inception-bn.py
View file @
ac9ac2a4
...
...
@@ -7,10 +7,9 @@ import argparse
import
os
import
tensorflow
as
tf
from
tensorpack
import
*
from
tensorpack.tfutils.summary
import
add_moving_summary
from
tensorpack.dataflow
import
dataset
from
tensorpack.tfutils.summary
import
add_moving_summary
from
tensorpack.utils.gpu
import
get_num_gpu
from
imagenet_utils
import
fbresnet_augmentor
,
get_imagenet_dataflow
...
...
examples/ImageNetModels/shufflenet.py
View file @
ac9ac2a4
...
...
@@ -3,24 +3,20 @@
# File: shufflenet.py
import
argparse
import
numpy
as
np
import
math
import
numpy
as
np
import
os
import
cv2
import
tensorflow
as
tf
from
tensorpack
import
*
from
tensorpack.dataflow
import
imgaug
from
tensorpack.tfutils
import
argscope
,
get_model_loader
,
model_utils
from
tensorpack.tfutils.scope_utils
import
under_name_scope
from
tensorpack.utils.gpu
import
get_num_gpu
from
tensorpack.utils
import
logger
from
tensorpack.utils.gpu
import
get_num_gpu
from
imagenet_utils
import
(
get_imagenet_dataflow
,
ImageNetModel
,
GoogleNetResize
,
eval_on_ILSVRC12
)
from
imagenet_utils
import
GoogleNetResize
,
ImageNetModel
,
eval_on_ILSVRC12
,
get_imagenet_dataflow
@
layer_register
(
log_shape
=
True
)
...
...
examples/ImageNetModels/vgg16.py
View file @
ac9ac2a4
...
...
@@ -4,7 +4,6 @@
import
argparse
import
os
import
tensorflow
as
tf
from
tensorpack
import
*
...
...
@@ -12,8 +11,7 @@ from tensorpack.tfutils import argscope
from
tensorpack.tfutils.summary
import
*
from
tensorpack.utils.gpu
import
get_num_gpu
from
imagenet_utils
import
(
ImageNetModel
,
get_imagenet_dataflow
,
fbresnet_augmentor
)
from
imagenet_utils
import
ImageNetModel
,
fbresnet_augmentor
,
get_imagenet_dataflow
def
GroupNorm
(
x
,
group
,
gamma_initializer
=
tf
.
constant_initializer
(
1.
)):
...
...
examples/OpticalFlow/flownet2.py
View file @
ac9ac2a4
...
...
@@ -2,16 +2,16 @@
# -*- coding: utf-8 -*-
# Author: Patrick Wieschollek <mail@patwie.com>
import
argparse
import
glob
import
os
import
cv2
import
glob
from
helper
import
Flow
import
argparse
from
tensorpack
import
*
from
tensorpack.utils
import
viz
import
flownet_models
as
models
from
helper
import
Flow
def
apply
(
model
,
model_path
,
left
,
right
,
ground_truth
=
None
):
...
...
examples/OpticalFlow/flownet_models.py
View file @
ac9ac2a4
...
...
@@ -4,6 +4,7 @@
import
tensorflow
as
tf
from
tensorpack
import
ModelDesc
,
argscope
,
enable_argscope_for_module
enable_argscope_for_module
(
tf
.
layers
)
...
...
examples/PennTreebank/PTB-LSTM.py
View file @
ac9ac2a4
...
...
@@ -3,21 +3,20 @@
# File: PTB-LSTM.py
# Author: Yuxin Wu
import
argparse
import
numpy
as
np
import
os
import
argparse
import
tensorflow
as
tf
from
tensorpack
import
*
from
tensorpack.tfutils
import
optimizer
,
summary
,
gradproc
from
tensorpack.tfutils
import
gradproc
,
optimizer
,
summary
from
tensorpack.utils
import
logger
from
tensorpack.utils.fs
import
download
,
get_dataset_path
from
tensorpack.utils.argtools
import
memoized_ignoreargs
from
tensorpack.utils.fs
import
download
,
get_dataset_path
import
reader
as
tfreader
from
reader
import
ptb_producer
import
tensorflow
as
tf
rnn
=
tf
.
contrib
.
rnn
SEQ_LEN
=
35
...
...
examples/PennTreebank/reader.py
View file @
ac9ac2a4
...
...
@@ -16,13 +16,9 @@
"""Utilities for parsing PTB text files."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
absolute_import
,
division
,
print_function
import
collections
import
os
import
tensorflow
as
tf
...
...
examples/ResNet/cifar10-preact18-mixup.py
View file @
ac9ac2a4
...
...
@@ -3,14 +3,14 @@
# File: cifar10-preact18-mixup.py
# Author: Tao Hu <taohu620@gmail.com>, Yauheni Selivonchyk <y.selivonchyk@gmail.com>
import
numpy
as
np
import
argparse
import
numpy
as
np
import
os
import
tensorflow
as
tf
from
tensorpack
import
*
from
tensorpack.tfutils.summary
import
*
from
tensorpack.dataflow
import
dataset
from
tensorpack.tfutils.summary
import
*
BATCH_SIZE
=
128
CLASS_NUM
=
10
...
...
examples/ResNet/cifar10-resnet.py
View file @
ac9ac2a4
...
...
@@ -5,14 +5,12 @@
import
argparse
import
os
import
tensorflow
as
tf
from
tensorpack
import
*
from
tensorpack.dataflow
import
dataset
from
tensorpack.tfutils.summary
import
add_moving_summary
,
add_param_summary
from
tensorpack.utils.gpu
import
get_num_gpu
from
tensorpack.dataflow
import
dataset
import
tensorflow
as
tf
"""
CIFAR10 ResNet example. See:
...
...
examples/ResNet/imagenet-resnet.py
View file @
ac9ac2a4
...
...
@@ -5,22 +5,18 @@
import
argparse
import
os
from
tensorpack
import
logger
,
QueueInput
,
TFDatasetInput
from
tensorpack.models
import
*
from
tensorpack
import
QueueInput
,
TFDatasetInput
,
logger
from
tensorpack.callbacks
import
*
from
tensorpack.train
import
(
TrainConfig
,
SyncMultiGPUTrainerReplicated
,
launch_train_with_config
)
from
tensorpack.dataflow
import
FakeData
from
tensorpack.models
import
*
from
tensorpack.tfutils
import
argscope
,
get_model_loader
from
tensorpack.train
import
SyncMultiGPUTrainerReplicated
,
TrainConfig
,
launch_train_with_config
from
tensorpack.utils.gpu
import
get_num_gpu
from
imagenet_utils
import
(
get_imagenet_dataflow
,
get_imagenet_tfdata
,
ImageNetModel
,
eval_on_ILSVRC12
)
from
imagenet_utils
import
ImageNetModel
,
eval_on_ILSVRC12
,
get_imagenet_dataflow
,
get_imagenet_tfdata
from
resnet_model
import
(
preresnet_group
,
preresnet_basicblock
,
preresnet_bottleneck
,
resnet_group
,
resnet_basicblock
,
resnet_bottleneck
,
se_resnet_bottleneck
,
resnet_backbone
)
preresnet_basicblock
,
preresnet_bottleneck
,
preresnet_group
,
resnet_backbone
,
resnet_basicblock
,
resnet_bottleneck
,
resnet_group
,
se_resnet_bottleneck
)
class
Model
(
ImageNetModel
):
...
...
examples/ResNet/load-resnet.py
View file @
ac9ac2a4
...
...
@@ -4,20 +4,20 @@
# Author: Eric Yujia Huang <yujiah1@andrew.cmu.edu>
# Yuxin Wu
import
cv2
import
functools
import
tensorflow
as
tf
import
argparse
import
re
import
functools
import
numpy
as
np
import
re
import
cv2
import
six
import
tensorflow
as
tf
from
tensorpack
import
*
from
tensorpack.utils
import
logger
from
tensorpack.dataflow.dataset
import
ILSVRCMeta
from
tensorpack.utils
import
logger
from
imagenet_utils
import
eval_on_ILSVRC12
,
get_imagenet_dataflow
,
ImageNetModel
from
resnet_model
import
resnet_
group
,
resnet_bottleneck
from
imagenet_utils
import
ImageNetModel
,
eval_on_ILSVRC12
,
get_imagenet_dataflow
from
resnet_model
import
resnet_
bottleneck
,
resnet_group
DEPTH
=
None
CFG
=
{
...
...
examples/ResNet/resnet_model.py
View file @
ac9ac2a4
...
...
@@ -3,9 +3,8 @@
import
tensorflow
as
tf
from
tensorpack.models
import
BatchNorm
,
BNReLU
,
Conv2D
,
FullyConnected
,
GlobalAvgPooling
,
MaxPooling
from
tensorpack.tfutils.argscope
import
argscope
,
get_arg_scope
from
tensorpack.models
import
(
Conv2D
,
MaxPooling
,
GlobalAvgPooling
,
BatchNorm
,
BNReLU
,
FullyConnected
)
def
resnet_shortcut
(
l
,
n_out
,
stride
,
activation
=
tf
.
identity
):
...
...
examples/Saliency/CAM-resnet.py
View file @
ac9ac2a4
...
...
@@ -2,28 +2,24 @@
# -*- coding: utf-8 -*-
# File: CAM-resnet.py
import
cv2
import
sys
import
argparse
import
multiprocessing
import
numpy
as
np
import
os
import
multiprocessing
import
sys
import
cv2
import
tensorflow
as
tf
from
tensorpack
import
*
from
tensorpack.dataflow
import
dataset
from
tensorpack.tfutils
import
optimizer
,
gradproc
from
tensorpack.tfutils.symbolic_functions
import
*
from
tensorpack.tfutils
import
gradproc
,
optimizer
from
tensorpack.tfutils.summary
import
*
from
tensorpack.
utils.gpu
import
get_num_gpu
from
tensorpack.
tfutils.symbolic_functions
import
*
from
tensorpack.utils
import
viz
from
tensorpack.utils.gpu
import
get_num_gpu
from
imagenet_utils
import
(
fbresnet_augmentor
,
ImageNetModel
)
from
resnet_model
import
(
preresnet_basicblock
,
preresnet_group
)
from
imagenet_utils
import
ImageNetModel
,
fbresnet_augmentor
from
resnet_model
import
preresnet_basicblock
,
preresnet_group
TOTAL_BATCH_SIZE
=
256
DEPTH
=
None
...
...
examples/Saliency/saliency-maps.py
View file @
ac9ac2a4
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import
cv2
import
numpy
as
np
import
sys
from
contextlib
import
contextmanager
import
numpy
as
np
import
cv2
import
tensorflow
as
tf
import
tensorflow.contrib.slim
as
slim
from
tensorflow.contrib.slim.nets
import
resnet_v1
...
...
examples/SimilarityLearning/embedding_data.py
View file @
ac9ac2a4
...
...
@@ -3,7 +3,8 @@
# Author: tensorpack contributors
import
numpy
as
np
from
tensorpack.dataflow
import
dataset
,
BatchData
from
tensorpack.dataflow
import
BatchData
,
dataset
def
get_test_data
(
batch
=
128
):
...
...
examples/SimilarityLearning/mnist-embeddings.py
View file @
ac9ac2a4
...
...
@@ -2,17 +2,16 @@
# -*- coding: utf-8 -*-
# File: mnist-embeddings.py
import
numpy
as
np
import
argparse
import
numpy
as
np
import
tensorflow
as
tf
import
tensorflow.contrib.slim
as
slim
from
tensorpack
import
*
from
tensorpack.tfutils.summary
import
add_moving_summary
from
tensorpack.utils.gpu
import
change_gpu
from
embedding_data
import
get_test_data
,
MnistPairs
,
MnistTriplets
from
embedding_data
import
MnistPairs
,
MnistTriplets
,
get_test_data
MATPLOTLIB_AVAIBLABLE
=
False
try
:
...
...
examples/SpatialTransformer/mnist-addition.py
View file @
ac9ac2a4
...
...
@@ -3,16 +3,15 @@
# File: mnist-addition.py
# Author: Yuxin Wu
import
cv2
import
argparse
import
numpy
as
np
import
tensorflow
as
tf
import
os
import
argparse
import
cv2
import
tensorflow
as
tf
from
tensorpack
import
*
from
tensorpack.dataflow
import
dataset
from
tensorpack.tfutils
import
optimizer
,
summary
,
gradproc
from
tensorpack.tfutils
import
gradproc
,
optimizer
,
summary
IMAGE_SIZE
=
42
WARP_TARGET_SIZE
=
28
...
...
examples/SuperResolution/data_sampler.py
View file @
ac9ac2a4
import
cv2
import
os
import
argparse
import
numpy
as
np
import
os
import
zipfile
from
tensorpack
import
RNGDataFlow
,
MapDataComponent
,
LMDBSerializer
import
cv2
from
tensorpack
import
LMDBSerializer
,
MapDataComponent
,
RNGDataFlow
class
ImageDataFromZIPFile
(
RNGDataFlow
):
...
...
examples/SuperResolution/enet-pat.py
View file @
ac9ac2a4
...
...
@@ -2,11 +2,11 @@
# -*- coding: utf-8 -*-
# Author: Patrick Wieschollek <mail@patwie.com>
import
os
import
argparse
import
numpy
as
np
import
os
import
cv2
import
six
import
numpy
as
np
import
tensorflow
as
tf
from
tensorpack
import
*
...
...
@@ -14,10 +14,10 @@ from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope
from
tensorpack.tfutils.summary
import
add_moving_summary
from
tensorpack.utils
import
logger
from
tensorpack.utils.gpu
import
get_num_gpu
from
data_sampler
import
(
ImageDecode
,
ImageDataFromZIPFile
,
RejectTooSmallImages
,
CenterSquareResize
)
from
GAN
import
SeparateGANTrainer
,
GANModelDesc
from
data_sampler
import
CenterSquareResize
,
ImageDataFromZIPFile
,
ImageDecode
,
RejectTooSmallImages
from
GAN
import
GANModelDesc
,
SeparateGANTrainer
Reduction
=
tf
.
losses
.
Reduction
BATCH_SIZE
=
16
...
...
examples/basics/cifar-convnet.py
View file @
ac9ac2a4
...
...
@@ -2,15 +2,16 @@
# -*- coding: utf-8 -*-
# File: cifar-convnet.py
# Author: Yuxin Wu
import
tensorflow
as
tf
import
argparse
import
os
import
tensorflow
as
tf
from
tensorpack
import
*
from
tensorpack.tfutils.summary
import
*
from
tensorpack.dataflow
import
dataset
from
tensorpack.tfutils.summary
import
*
from
tensorpack.utils.gpu
import
get_num_gpu
"""
A small convnet model for Cifar10 or Cifar100 dataset.
...
...
examples/basics/export-model.py
View file @
ac9ac2a4
...
...
@@ -4,6 +4,7 @@
import
argparse
import
cv2
import
tensorflow
as
tf
from
tensorpack
import
*
from
tensorpack.tfutils.export
import
ModelExporter
...
...
examples/basics/mnist-convnet.py
View file @
ac9ac2a4
...
...
@@ -3,17 +3,16 @@
# File: mnist-convnet.py
import
tensorflow
as
tf
from
tensorpack
import
*
from
tensorpack.dataflow
import
dataset
from
tensorpack.tfutils
import
summary
"""
MNIST ConvNet example.
about 0.6
%
validation error after 30 epochs.
"""
# Just import everything into current namespace
from
tensorpack
import
*
from
tensorpack.tfutils
import
summary
from
tensorpack.dataflow
import
dataset
IMAGE_SIZE
=
28
...
...
examples/basics/mnist-tflayers.py
View file @
ac9ac2a4
...
...
@@ -3,6 +3,11 @@
# File: mnist-tflayers.py
import
tensorflow
as
tf
from
tensorpack
import
*
from
tensorpack.dataflow
import
dataset
from
tensorpack.tfutils
import
get_current_tower_context
,
summary
"""
MNIST ConvNet example using tf.layers
Mostly the same as 'mnist-convnet.py',
...
...
@@ -11,12 +16,6 @@ the only differences are:
2. use tf.layers variable names to summarize weights
"""
# Just import everything into current namespace
from
tensorpack
import
*
from
tensorpack.tfutils
import
summary
,
get_current_tower_context
from
tensorpack.dataflow
import
dataset
IMAGE_SIZE
=
28
# Monkey-patch tf.layers to support argscope.
enable_argscope_for_module
(
tf
.
layers
)
...
...
examples/basics/mnist-tfslim.py
View file @
ac9ac2a4
...
...
@@ -11,11 +11,12 @@ the only differences are:
"""
from
tensorpack
import
*
from
tensorpack.dataflow
import
dataset
import
tensorflow
as
tf
import
tensorflow.contrib.slim
as
slim
from
tensorpack
import
*
from
tensorpack.dataflow
import
dataset
IMAGE_SIZE
=
28
...
...
examples/basics/mnist-visualizations.py
View file @
ac9ac2a4
...
...
@@ -7,6 +7,7 @@ The same MNIST ConvNet example, but with weights/activations visualization.
"""
import
tensorflow
as
tf
from
tensorpack
import
*
from
tensorpack.dataflow
import
dataset
...
...
examples/basics/svhn-digit-convnet.py
View file @
ac9ac2a4
...
...
@@ -5,12 +5,12 @@
import
argparse
import
os
import
tensorflow
as
tf
from
tensorpack
import
*
from
tensorpack.dataflow
import
dataset
from
tensorpack.tfutils.summary
import
*
import
tensorflow
as
tf
"""
A very small SVHN convnet model (only 0.8m parameters).
...
...
examples/boilerplate.py
View file @
ac9ac2a4
# -*- coding: utf-8 -*-
# Author: Your Name <your@email.com>
import
os
import
argparse
import
os
import
tensorflow
as
tf
from
tensorpack
import
*
...
...
examples/keras/imagenet-resnet-keras.py
View file @
ac9ac2a4
...
...
@@ -3,22 +3,21 @@
# File: imagenet-resnet-keras.py
# Author: Yuxin Wu
import
argparse
import
numpy
as
np
import
os
import
tensorflow
as
tf
import
argparse
from
tensorflow.python.keras.layers
import
*
from
tensorpack
import
InputDesc
,
SyncMultiGPUTrainerReplicated
from
tensorpack.callbacks
import
*
from
tensorpack.contrib.keras
import
KerasModel
from
tensorpack.dataflow
import
FakeData
,
MapDataComponent
from
tensorpack.tfutils.common
import
get_tf_version_tuple
from
tensorpack.utils
import
logger
from
tensorpack.utils.gpu
import
get_num_gpu
from
tensorpack.contrib.keras
import
KerasModel
from
tensorpack.callbacks
import
*
from
tensorflow.python.keras.layers
import
*
from
tensorpack.tfutils.common
import
get_tf_version_tuple
from
imagenet_utils
import
get_imagenet_dataflow
,
fbresnet_augmentor
from
imagenet_utils
import
fbresnet_augmentor
,
get_imagenet_dataflow
TOTAL_BATCH_SIZE
=
512
BASE_LR
=
0.1
*
(
TOTAL_BATCH_SIZE
//
256
)
...
...
examples/keras/mnist-keras-v2.py
View file @
ac9ac2a4
...
...
@@ -5,17 +5,15 @@
import
numpy
as
np
import
tensorflow
as
tf
from
tensorflow
import
keras
KL
=
keras
.
layers
from
tensorpack
import
InputDesc
,
QueueInput
from
tensorpack.dataflow
import
dataset
,
BatchData
,
MapData
from
tensorpack.utils
import
logger
from
tensorpack.contrib.keras
import
KerasModel
from
tensorpack.callbacks
import
ModelSaver
from
tensorpack.contrib.keras
import
KerasModel
from
tensorpack.dataflow
import
BatchData
,
MapData
,
dataset
from
tensorpack.utils
import
logger
KL
=
keras
.
layers
IMAGE_SIZE
=
28
...
...
examples/keras/mnist-keras.py
View file @
ac9ac2a4
...
...
@@ -5,6 +5,12 @@
import
tensorflow
as
tf
from
tensorflow
import
keras
from
tensorpack
import
*
from
tensorpack.contrib.keras
import
KerasPhaseCallback
from
tensorpack.dataflow
import
dataset
from
tensorpack.utils.argtools
import
memoized
KL
=
keras
.
layers
"""
...
...
@@ -14,12 +20,6 @@ This way you can define models in Keras-style, and benefit from the more efficei
Note: this example does not work for replicated-style data-parallel trainers.
"""
from
tensorpack
import
*
from
tensorpack.dataflow
import
dataset
from
tensorpack.utils.argtools
import
memoized
from
tensorpack.contrib.keras
import
KerasPhaseCallback
IMAGE_SIZE
=
28
...
...
scripts/checkpoint-manipulate.py
View file @
ac9ac2a4
...
...
@@ -3,11 +3,11 @@
# File: checkpoint-manipulate.py
import
argparse
import
numpy
as
np
from
tensorpack.tfutils.varmanip
import
load_chkpt_vars
from
tensorpack.utils
import
logger
import
argparse
if
__name__
==
'__main__'
:
parser
=
argparse
.
ArgumentParser
()
...
...
scripts/checkpoint-prof.py
View file @
ac9ac2a4
...
...
@@ -2,12 +2,13 @@
# -*- coding: utf-8 -*-
# File: checkpoint-prof.py
import
tensorflow
as
tf
import
argparse
import
numpy
as
np
import
tensorflow
as
tf
from
tensorpack
import
get_default_sess_config
,
get_op_tensor_name
from
tensorpack.utils
import
logger
from
tensorpack.tfutils.sessinit
import
get_model_loader
import
argparse
from
tensorpack.utils
import
logger
if
__name__
==
'__main__'
:
parser
=
argparse
.
ArgumentParser
()
...
...
scripts/dump-model-params.py
View file @
ac9ac2a4
...
...
@@ -2,10 +2,10 @@
# -*- coding: utf-8 -*-
# File: dump-model-params.py
import
numpy
as
np
import
six
import
argparse
import
numpy
as
np
import
os
import
six
import
tensorflow
as
tf
from
tensorpack.tfutils
import
varmanip
...
...
scripts/ls-checkpoint.py
View file @
ac9ac2a4
...
...
@@ -2,11 +2,11 @@
# -*- coding: utf-8 -*-
# File: ls-checkpoint.py
import
tensorflow
as
tf
import
numpy
as
np
import
six
import
sys
import
pprint
import
sys
import
six
import
tensorflow
as
tf
from
tensorpack.tfutils.varmanip
import
get_checkpoint_path
...
...
setup.py
View file @
ac9ac2a4
import
platform
from
os
import
path
import
setuptools
from
setuptools
import
setup
from
os
import
path
import
platform
version
=
int
(
setuptools
.
__version__
.
split
(
'.'
)[
0
])
assert
version
>
30
,
"Tensorpack installation requires setuptools > 30"
...
...
@@ -24,7 +24,7 @@ def add_git_version():
from
subprocess
import
check_output
try
:
return
check_output
(
"git describe --tags --long --dirty"
.
split
())
.
decode
(
'utf-8'
)
.
strip
()
except
:
except
Exception
:
return
__version__
# noqa
newlibinfo_content
=
[
l
for
l
in
libinfo_content
if
not
l
.
startswith
(
'__git_version__'
)]
...
...
tensorpack/callbacks/base.py
View file @
ac9ac2a4
...
...
@@ -2,9 +2,10 @@
# File: base.py
import
tensorflow
as
tf
from
abc
import
ABCMeta
import
six
import
tensorflow
as
tf
from
..tfutils.common
import
get_op_or_tensor_by_name
__all__
=
[
'Callback'
,
'ProxyCallback'
,
'CallbackFactory'
]
...
...
tensorpack/callbacks/concurrency.py
View file @
ac9ac2a4
...
...
@@ -3,9 +3,10 @@
import
multiprocessing
as
mp
from
.base
import
Callback
from
..utils.concurrency
import
start_proc_mask_signal
,
StoppableThread
from
..utils
import
logger
from
..utils.concurrency
import
StoppableThread
,
start_proc_mask_signal
from
.base
import
Callback
__all__
=
[
'StartProcOrThread'
]
...
...
tensorpack/callbacks/graph.py
View file @
ac9ac2a4
...
...
@@ -4,14 +4,14 @@
""" Graph related callbacks"""
import
tensorflow
as
tf
import
os
import
numpy
as
np
import
os
import
tensorflow
as
tf
from
six.moves
import
zip
from
..tfutils.common
import
get_op_tensor_name
from
..utils
import
logger
from
.base
import
Callback
from
..tfutils.common
import
get_op_tensor_name
__all__
=
[
'RunOp'
,
'RunUpdateOps'
,
'ProcessTensors'
,
'DumpTensors'
,
'DumpTensor'
,
'DumpTensorAsImage'
,
'DumpParamAsImage'
]
...
...
tensorpack/callbacks/group.py
View file @
ac9ac2a4
...
...
@@ -2,16 +2,16 @@
# File: group.py
import
t
ensorflow
as
tf
import
t
raceback
from
contextlib
import
contextmanager
from
time
import
time
as
timer
import
traceback
import
six
import
tensorflow
as
tf
from
.base
import
Callback
from
.hooks
import
CallbackToHook
from
..utils
import
logger
from
..utils.utils
import
humanize_time_delta
from
.base
import
Callback
from
.hooks
import
CallbackToHook
if
six
.
PY3
:
from
time
import
perf_counter
as
timer
# noqa
...
...
tensorpack/callbacks/hooks.py
View file @
ac9ac2a4
...
...
@@ -5,6 +5,7 @@
""" Compatible layers between tf.train.SessionRunHook and Callback"""
import
tensorflow
as
tf
from
.base
import
Callback
__all__
=
[
'CallbackToHook'
,
'HookToCallback'
]
...
...
tensorpack/callbacks/inference.py
View file @
ac9ac2a4
...
...
@@ -7,10 +7,10 @@ from abc import ABCMeta
import
six
from
six.moves
import
zip
from
.base
import
Callback
from
..utils
import
logger
from
..utils.stats
import
RatioCounter
,
BinaryStatistics
from
..tfutils.common
import
get_op_tensor_name
from
..utils
import
logger
from
..utils.stats
import
BinaryStatistics
,
RatioCounter
from
.base
import
Callback
__all__
=
[
'ScalarStats'
,
'Inferencer'
,
'ClassificationError'
,
'BinaryClassificationStats'
]
...
...
tensorpack/callbacks/inference_runner.py
View file @
ac9ac2a4
...
...
@@ -2,24 +2,19 @@
# File: inference_runner.py
import
sys
import
tensorflow
as
tf
from
tensorflow.python.training.monitored_session
\
import
_HookedSession
as
HookedSession
import
itertools
import
sys
from
contextlib
import
contextmanager
import
tensorflow
as
tf
import
tqdm
from
six.moves
import
range
from
tensorflow.python.training.monitored_session
import
_HookedSession
as
HookedSession
from
..utils
import
logger
from
..utils.utils
import
get_tqdm_kwargs
from
..dataflow.base
import
DataFlow
from
..input_source
import
FeedInput
,
InputSource
,
QueueInput
,
StagingInput
from
..tfutils.tower
import
PredictTowerContext
from
..input_source
import
(
InputSource
,
FeedInput
,
QueueInput
,
StagingInput
)
from
..utils
import
logger
from
..utils.utils
import
get_tqdm_kwargs
from
.base
import
Callback
from
.group
import
Callbacks
from
.inference
import
Inferencer
...
...
tensorpack/callbacks/misc.py
View file @
ac9ac2a4
...
...
@@ -2,14 +2,14 @@
# File: misc.py
import
numpy
as
np
import
os
import
time
from
collections
import
deque
import
numpy
as
np
from
.base
import
Callback
from
..utils.utils
import
humanize_time_delta
from
..utils
import
logger
from
..utils.utils
import
humanize_time_delta
from
.base
import
Callback
__all__
=
[
'SendStat'
,
'InjectShell'
,
'EstimatedTimeLeft'
]
...
...
tensorpack/callbacks/monitor.py
View file @
ac9ac2a4
...
...
@@ -2,20 +2,20 @@
# File: monitor.py
import
os
import
json
import
numpy
as
np
import
operator
import
os
import
re
import
shutil
import
time
from
datetime
import
datetime
import
operator
from
collections
import
defaultdict
from
datetime
import
datetime
import
six
import
json
import
re
import
tensorflow
as
tf
from
..tfutils.summary
import
create_image_summary
,
create_scalar_summary
from
..utils
import
logger
from
..tfutils.summary
import
create_scalar_summary
,
create_image_summary
from
..utils.develop
import
HIDE_DOC
from
.base
import
Callback
...
...
tensorpack/callbacks/param.py
View file @
ac9ac2a4
...
...
@@ -2,16 +2,16 @@
# File: param.py
import
tensorflow
as
tf
from
collections
import
deque
from
abc
import
abstractmethod
,
ABCMeta
import
operator
import
six
import
os
from
abc
import
ABCMeta
,
abstractmethod
from
collections
import
deque
import
six
import
tensorflow
as
tf
from
.base
import
Callback
from
..utils
import
logger
from
..tfutils.common
import
get_op_tensor_name
from
..utils
import
logger
from
.base
import
Callback
__all__
=
[
'HyperParam'
,
'GraphVarParam'
,
'ObjAttrParam'
,
'HyperParamSetter'
,
'HumanHyperParamSetter'
,
...
...
tensorpack/callbacks/prof.py
View file @
ac9ac2a4
...
...
@@ -2,20 +2,20 @@
# File: prof.py
import
os
import
numpy
as
np
import
multiprocessing
as
mp
import
numpy
as
np
import
os
import
time
from
six.moves
import
map
import
tensorflow
as
tf
from
six.moves
import
map
from
tensorflow.python.client
import
timeline
from
.
base
import
Callback
from
.
.tfutils.common
import
gpu_available_in_session
from
..utils
import
logger
from
..utils.concurrency
import
ensure_proc_terminate
,
start_proc_mask_signal
from
..utils.gpu
import
get_num_gpu
from
..utils.nvml
import
NVMLContext
from
.
.tfutils.common
import
gpu_available_in_session
from
.
base
import
Callback
__all__
=
[
'GPUUtilizationTracker'
,
'GraphProfiler'
,
'PeakMemoryTracker'
]
...
...
tensorpack/callbacks/saver.py
View file @
ac9ac2a4
...
...
@@ -2,12 +2,12 @@
# File: saver.py
import
tensorflow
as
tf
from
datetime
import
datetime
import
os
from
datetime
import
datetime
import
tensorflow
as
tf
from
.base
import
Callback
from
..utils
import
logger
from
.base
import
Callback
__all__
=
[
'ModelSaver'
,
'MinSaver'
,
'MaxSaver'
]
...
...
tensorpack/callbacks/stats.py
View file @
ac9ac2a4
# -*- coding: utf-8 -*-
# File: stats.py
from
.graph
import
DumpParamAsImage
# noqa
# for compatibility only
from
.misc
import
InjectShell
,
SendStat
# noqa
from
.graph
import
DumpParamAsImage
# noqa
__all__
=
[]
tensorpack/callbacks/steps.py
View file @
ac9ac2a4
...
...
@@ -4,14 +4,13 @@
""" Some common step callbacks. """
import
tensorflow
as
tf
from
six.moves
import
zip
import
tqdm
from
six.moves
import
zip
from
..tfutils.common
import
get_global_step_var
,
get_op_tensor_name
from
..utils
import
logger
from
..utils.utils
import
get_tqdm_kwargs
from
..utils.naming
import
GLOBAL_STEP_INCR_OP_NAME
from
..tfutils.common
import
(
get_op_tensor_name
,
get_global_step_var
)
from
..utils.utils
import
get_tqdm_kwargs
from
.base
import
Callback
__all__
=
[
'TensorPrinter'
,
'ProgressBar'
,
'SessionRunTimeout'
]
...
...
tensorpack/callbacks/summary.py
View file @
ac9ac2a4
...
...
@@ -2,9 +2,9 @@
# File: summary.py
import
tensorflow
as
tf
import
numpy
as
np
from
collections
import
deque
import
tensorflow
as
tf
from
..tfutils.common
import
get_op_tensor_name
from
..utils
import
logger
...
...
tensorpack/callbacks/trigger.py
View file @
ac9ac2a4
...
...
@@ -2,8 +2,8 @@
# File: trigger.py
from
.base
import
ProxyCallback
,
Callback
from
..utils.develop
import
log_deprecated
from
.base
import
Callback
,
ProxyCallback
__all__
=
[
'PeriodicTrigger'
,
'PeriodicCallback'
,
'EnableCallbackIf'
]
...
...
tensorpack/contrib/keras.py
View file @
ac9ac2a4
# -*- coding: utf-8 -*-
# File: keras.py
import
tensorflow
as
tf
from
contextlib
import
contextmanager
import
six
from
tensorflow
import
keras
import
tensorflow
as
tf
import
tensorflow.keras.backend
as
K
from
tensorflow
import
keras
from
tensorflow.python.keras
import
metrics
as
metrics_module
from
contextlib
import
contextmanager
from
..callbacks
import
Callback
,
CallbackToHook
,
InferenceRunner
,
InferenceRunnerBase
,
ScalarStats
from
..models.regularize
import
regularize_cost_from_collection
from
..train
import
Trainer
,
SimpleTrainer
,
SyncMultiGPUTrainerParameterServer
from
..train.trainers
import
DistributedTrainerBase
from
..train.interface
import
apply_default_prefetch
from
..callbacks
import
(
Callback
,
InferenceRunnerBase
,
InferenceRunner
,
CallbackToHook
,
ScalarStats
)
from
..tfutils.common
import
get_op_tensor_name
from
..tfutils.collection
import
backup_collection
,
restore_collection
from
..tfutils.
tower
import
get_current_tower_context
from
..tfutils.
common
import
get_op_tensor_name
from
..tfutils.scope_utils
import
cached_name_scope
from
..tfutils.summary
import
add_moving_summary
from
..utils.gpu
import
get_nr_gpu
from
..tfutils.tower
import
get_current_tower_context
from
..train
import
SimpleTrainer
,
SyncMultiGPUTrainerParameterServer
,
Trainer
from
..train.interface
import
apply_default_prefetch
from
..train.trainers
import
DistributedTrainerBase
from
..utils
import
logger
from
..utils.gpu
import
get_nr_gpu
__all__
=
[
'KerasPhaseCallback'
,
'setup_keras_trainer'
,
'KerasModel'
]
...
...
tensorpack/dataflow/base.py
View file @
ac9ac2a4
...
...
@@ -3,8 +3,9 @@
import
threading
from
abc
import
abstractmethod
,
ABCMeta
from
abc
import
ABCMeta
,
abstractmethod
import
six
from
..utils.utils
import
get_rng
__all__
=
[
'DataFlow'
,
'ProxyDataFlow'
,
'RNGDataFlow'
,
'DataFlowTerminated'
]
...
...
tensorpack/dataflow/common.py
View file @
ac9ac2a4
...
...
@@ -2,20 +2,20 @@
# File: common.py
from
__future__
import
division
import
six
import
itertools
import
numpy
as
np
from
copy
import
copy
import
pprint
import
itertools
from
termcolor
import
colored
from
collections
import
deque
,
defaultdict
from
six.moves
import
range
,
map
from
collections
import
defaultdict
,
deque
from
copy
import
copy
import
six
import
tqdm
from
six.moves
import
map
,
range
from
termcolor
import
colored
from
.base
import
DataFlow
,
ProxyDataFlow
,
RNGDataFlow
,
DataFlowReentrantGuard
from
..utils
import
logger
from
..utils.utils
import
get_tqdm
,
get_rng
,
get_tqdm_kwargs
from
..utils.develop
import
log_deprecated
from
..utils.utils
import
get_rng
,
get_tqdm
,
get_tqdm_kwargs
from
.base
import
DataFlow
,
DataFlowReentrantGuard
,
ProxyDataFlow
,
RNGDataFlow
__all__
=
[
'TestDataSpeed'
,
'PrintData'
,
'BatchData'
,
'BatchDataByShape'
,
'FixedSizeData'
,
'MapData'
,
'MapDataComponent'
,
'RepeatedData'
,
'RepeatedDataPoint'
,
'RandomChooseData'
,
...
...
tensorpack/dataflow/dataset/bsds500.py
View file @
ac9ac2a4
...
...
@@ -2,9 +2,9 @@
# File: bsds500.py
import
os
import
glob
import
numpy
as
np
import
os
from
...utils.fs
import
download
,
get_dataset_path
from
..base
import
RNGDataFlow
...
...
tensorpack/dataflow/dataset/cifar.py
View file @
ac9ac2a4
...
...
@@ -3,9 +3,9 @@
# Yukun Chen <cykustc@gmail.com>
import
numpy
as
np
import
os
import
pickle
import
numpy
as
np
import
tarfile
import
six
from
six.moves
import
range
...
...
tensorpack/dataflow/dataset/ilsvrc.py
View file @
ac9ac2a4
# -*- coding: utf-8 -*-
# File: ilsvrc.py
import
numpy
as
np
import
os
import
tarfile
import
numpy
as
np
import
tqdm
from
...utils
import
logger
from
...utils.fs
import
download
,
get_dataset_path
,
mkdir_p
from
...utils.loadcaffe
import
get_caffe_pb
from
...utils.fs
import
mkdir_p
,
download
,
get_dataset_path
from
...utils.timer
import
timed_operation
from
..base
import
RNGDataFlow
...
...
tensorpack/dataflow/dataset/mnist.py
View file @
ac9ac2a4
...
...
@@ -2,9 +2,9 @@
# File: mnist.py
import
os
import
gzip
import
numpy
import
os
from
six.moves
import
range
from
...utils
import
logger
...
...
tensorpack/dataflow/dataset/svhn.py
View file @
ac9ac2a4
...
...
@@ -2,11 +2,11 @@
# File: svhn.py
import
os
import
numpy
as
np
import
os
from
...utils
import
logger
from
...utils.fs
import
get_dataset_path
,
download
from
...utils.fs
import
download
,
get_dataset_path
from
..base
import
RNGDataFlow
__all__
=
[
'SVHNDigit'
]
...
...
tensorpack/dataflow/dftools.py
View file @
ac9ac2a4
...
...
@@ -3,15 +3,13 @@
from
..utils.develop
import
deprecated
from
.remote
import
dump_dataflow_to_process_queue
from
.serialize
import
LMDBSerializer
,
TFRecordSerializer
__all__
=
[
'dump_dataflow_to_process_queue'
,
'dump_dataflow_to_lmdb'
,
'dump_dataflow_to_tfrecord'
]
from
.remote
import
dump_dataflow_to_process_queue
@
deprecated
(
"Use LMDBSerializer.save instead!"
,
"2019-01-31"
)
def
dump_dataflow_to_lmdb
(
df
,
lmdb_path
,
write_frequency
=
5000
):
LMDBSerializer
.
save
(
df
,
lmdb_path
,
write_frequency
)
...
...
tensorpack/dataflow/format.py
View file @
ac9ac2a4
...
...
@@ -3,18 +3,19 @@
import
numpy
as
np
import
os
import
six
from
six.moves
import
range
import
os
from
..utils
import
logger
from
..utils.utils
import
get_tqdm
from
..utils.timer
import
timed_operation
from
..utils.loadcaffe
import
get_caffe_pb
from
..utils.compatible_serialize
import
loads
from
..utils.argtools
import
log_once
from
..utils.compatible_serialize
import
loads
from
..utils.develop
import
create_dummy_class
# noqa
from
..utils.develop
import
log_deprecated
from
.base
import
RNGDataFlow
,
DataFlow
,
DataFlowReentrantGuard
from
..utils.loadcaffe
import
get_caffe_pb
from
..utils.timer
import
timed_operation
from
..utils.utils
import
get_tqdm
from
.base
import
DataFlow
,
DataFlowReentrantGuard
,
RNGDataFlow
from
.common
import
MapData
__all__
=
[
'HDF5Data'
,
'LMDBData'
,
'LMDBDataDecoder'
,
'LMDBDataPoint'
,
...
...
@@ -258,7 +259,7 @@ class TFRecordData(DataFlow):
for
dp
in
gen
:
yield
loads
(
dp
)
from
..utils.develop
import
create_dummy_class
# noqa
try
:
import
h5py
except
ImportError
:
...
...
tensorpack/dataflow/image.py
View file @
ac9ac2a4
...
...
@@ -2,13 +2,14 @@
# File: image.py
import
numpy
as
np
import
copy
as
copy_mod
import
numpy
as
np
from
contextlib
import
contextmanager
from
.base
import
RNGDataFlow
from
.common
import
MapDataComponent
,
MapData
from
..utils
import
logger
from
..utils.argtools
import
shape2d
from
.base
import
RNGDataFlow
from
.common
import
MapData
,
MapDataComponent
__all__
=
[
'ImageFromFile'
,
'AugmentImageComponent'
,
'AugmentImageCoordinates'
,
'AugmentImageComponents'
]
...
...
tensorpack/dataflow/imgaug/_test.py
View file @
ac9ac2a4
...
...
@@ -4,13 +4,13 @@
import
sys
import
cv2
from
.
import
AugmentorList
from
.crop
import
*
from
.imgproc
import
*
from
.noname
import
*
from
.deform
import
*
from
.imgproc
import
*
from
.noise
import
SaltPepperNoise
from
.noname
import
*
anchors
=
[(
0.2
,
0.2
),
(
0.7
,
0.2
),
(
0.8
,
0.8
),
(
0.5
,
0.5
),
(
0.2
,
0.5
)]
augmentors
=
AugmentorList
([
...
...
tensorpack/dataflow/imgaug/base.py
View file @
ac9ac2a4
...
...
@@ -4,12 +4,12 @@
import
inspect
import
pprint
from
abc
import
abstractmethod
,
ABCMeta
from
abc
import
ABCMeta
,
abstractmethod
import
six
from
six.moves
import
zip
from
...utils.utils
import
get_rng
from
...utils.argtools
import
log_once
from
...utils.utils
import
get_rng
from
..image
import
check_dtype
__all__
=
[
'Augmentor'
,
'ImageAugmentor'
,
'AugmentorList'
]
...
...
tensorpack/dataflow/imgaug/convert.py
View file @
ac9ac2a4
# -*- coding: utf-8 -*-
# File: convert.py
from
.base
import
ImageAugmentor
from
.meta
import
MapImage
import
numpy
as
np
import
cv2
from
.base
import
ImageAugmentor
from
.meta
import
MapImage
__all__
=
[
'ColorSpace'
,
'Grayscale'
,
'ToUint8'
,
'ToFloat32'
]
...
...
tensorpack/dataflow/imgaug/crop.py
View file @
ac9ac2a4
...
...
@@ -3,8 +3,7 @@
from
...utils.argtools
import
shape2d
from
.transform
import
TransformAugmentorBase
,
CropTransform
from
.transform
import
CropTransform
,
TransformAugmentorBase
__all__
=
[
'RandomCrop'
,
'CenterCrop'
,
'RandomCropRandomShape'
]
...
...
tensorpack/dataflow/imgaug/deform.py
View file @
ac9ac2a4
...
...
@@ -2,10 +2,11 @@
# File: deform.py
from
.base
import
ImageAugmentor
from
...utils
import
logger
import
numpy
as
np
from
...utils
import
logger
from
.base
import
ImageAugmentor
__all__
=
[]
# Code was temporarily kept here for a future reference in case someone needs it
...
...
tensorpack/dataflow/imgaug/external.py
View file @
ac9ac2a4
...
...
@@ -4,7 +4,6 @@ import numpy as np
from
.base
import
ImageAugmentor
__all__
=
[
'IAAugmentor'
,
'Albumentations'
]
...
...
tensorpack/dataflow/imgaug/geometry.py
View file @
ac9ac2a4
...
...
@@ -3,8 +3,8 @@
import
math
import
cv2
import
numpy
as
np
import
cv2
from
.base
import
ImageAugmentor
from
.transform
import
TransformAugmentorBase
,
WarpAffineTransform
...
...
tensorpack/dataflow/imgaug/imgproc.py
View file @
ac9ac2a4
...
...
@@ -2,10 +2,11 @@
# File: imgproc.py
from
.base
import
ImageAugmentor
import
numpy
as
np
import
cv2
from
.base
import
ImageAugmentor
__all__
=
[
'Hue'
,
'Brightness'
,
'BrightnessScale'
,
'Contrast'
,
'MeanVarianceNormalize'
,
'GaussianBlur'
,
'Gamma'
,
'Clip'
,
'Saturation'
,
'Lighting'
,
'MinMaxNormalize'
]
...
...
tensorpack/dataflow/imgaug/misc.py
View file @
ac9ac2a4
...
...
@@ -5,9 +5,9 @@
import
numpy
as
np
import
cv2
from
.base
import
ImageAugmentor
from
...utils
import
logger
from
...utils.argtools
import
shape2d
from
.base
import
ImageAugmentor
from
.transform
import
ResizeTransform
,
TransformAugmentorBase
__all__
=
[
'Flip'
,
'Resize'
,
'RandomResize'
,
'ResizeShortestEdge'
,
'Transpose'
]
...
...
tensorpack/dataflow/imgaug/noise.py
View file @
ac9ac2a4
...
...
@@ -2,10 +2,11 @@
# File: noise.py
from
.base
import
ImageAugmentor
import
numpy
as
np
import
cv2
from
.base
import
ImageAugmentor
__all__
=
[
'JpegNoise'
,
'GaussianNoise'
,
'SaltPepperNoise'
]
...
...
tensorpack/dataflow/imgaug/paste.py
View file @
ac9ac2a4
...
...
@@ -2,10 +2,10 @@
# File: paste.py
from
.base
import
ImageAugmentor
from
abc
import
abstractmethod
import
numpy
as
np
from
abc
import
abstractmethod
from
.base
import
ImageAugmentor
__all__
=
[
'CenterPaste'
,
'BackgroundFiller'
,
'ConstantBackgroundFiller'
,
'RandomPaste'
]
...
...
tensorpack/dataflow/imgaug/transform.py
View file @
ac9ac2a4
# -*- coding: utf-8 -*-
# File: transform.py
from
abc
import
abstractmethod
,
ABCMeta
import
six
import
cv2
import
numpy
as
np
from
abc
import
ABCMeta
,
abstractmethod
import
cv2
import
six
from
.base
import
ImageAugmentor
...
...
tensorpack/dataflow/parallel.py
View file @
ac9ac2a4
# -*- coding: utf-8 -*-
# File: parallel.py
import
atexit
import
errno
import
itertools
import
multiprocessing
as
mp
import
os
import
sys
import
uuid
import
weakref
from
contextlib
import
contextmanager
import
multiprocessing
as
mp
import
itertools
from
six.moves
import
range
,
zip
,
queue
import
errno
import
uuid
import
os
import
zmq
import
atexit
from
six.moves
import
queue
,
range
,
zip
from
.base
import
DataFlow
,
ProxyDataFlow
,
DataFlowTerminated
,
DataFlowReentrantGuard
from
..utils.concurrency
import
(
ensure_proc_terminate
,
mask_sigint
,
start_proc_mask_signal
,
enable_death_signal
,
StoppableThread
)
from
..utils.serialize
import
loads
,
dumps
from
..utils
import
logger
from
..utils.gpu
import
change_gpu
from
..utils.concurrency
import
(
StoppableThread
,
enable_death_signal
,
ensure_proc_terminate
,
mask_sigint
,
start_proc_mask_signal
)
from
..utils.develop
import
log_deprecated
from
..utils.gpu
import
change_gpu
from
..utils.serialize
import
dumps
,
loads
from
.base
import
DataFlow
,
DataFlowReentrantGuard
,
DataFlowTerminated
,
ProxyDataFlow
__all__
=
[
'PrefetchData'
,
'MultiProcessPrefetchData'
,
'PrefetchDataZMQ'
,
'PrefetchOnGPUs'
,
'MultiThreadPrefetchData'
]
...
...
tensorpack/dataflow/parallel_map.py
View file @
ac9ac2a4
# -*- coding: utf-8 -*-
# File: parallel_map.py
import
numpy
as
np
import
ctypes
import
copy
import
threading
import
ctypes
import
multiprocessing
as
mp
from
six.moves
import
queue
import
numpy
as
np
import
threading
import
zmq
from
six.moves
import
queue
from
.base
import
DataFlow
,
ProxyDataFlow
,
DataFlowReentrantGuard
from
.common
import
RepeatedData
from
..utils.concurrency
import
StoppableThread
,
enable_death_signal
from
..utils.serialize
import
loads
,
dumps
from
.parallel
import
(
_MultiProcessZMQDataFlow
,
_repeat_iter
,
_get_pipe_name
,
_bind_guard
,
_zmq_catch_error
)
from
..utils.serialize
import
dumps
,
loads
from
.base
import
DataFlow
,
DataFlowReentrantGuard
,
ProxyDataFlow
from
.common
import
RepeatedData
from
.parallel
import
_bind_guard
,
_get_pipe_name
,
_MultiProcessZMQDataFlow
,
_repeat_iter
,
_zmq_catch_error
__all__
=
[
'ThreadedMapData'
,
'MultiThreadMapData'
,
'MultiProcessMapData'
,
'MultiProcessMapDataZMQ'
]
...
...
tensorpack/dataflow/raw.py
View file @
ac9ac2a4
...
...
@@ -2,10 +2,11 @@
# File: raw.py
import
numpy
as
np
import
copy
import
numpy
as
np
import
six
from
six.moves
import
range
from
.base
import
DataFlow
,
RNGDataFlow
__all__
=
[
'FakeData'
,
'DataFromQueue'
,
'DataFromList'
,
'DataFromGenerator'
,
'DataFromIterable'
]
...
...
tensorpack/dataflow/remote.py
View file @
ac9ac2a4
...
...
@@ -2,17 +2,17 @@
# File: remote.py
import
multiprocessing
as
mp
import
time
from
collections
import
deque
import
tqdm
import
multiprocessing
as
mp
from
six.moves
import
range
from
collections
import
deque
from
.base
import
DataFlow
,
DataFlowReentrantGuard
from
..utils
import
logger
from
..utils.utils
import
get_tqdm_kwargs
from
..utils.concurrency
import
DIE
from
..utils.serialize
import
dumps
,
loads
from
..utils.utils
import
get_tqdm_kwargs
from
.base
import
DataFlow
,
DataFlowReentrantGuard
try
:
import
zmq
...
...
tensorpack/dataflow/serialize.py
View file @
ac9ac2a4
# -*- coding: utf-8 -*-
# File: serialize.py
import
os
import
numpy
as
np
import
os
from
collections
import
defaultdict
from
..utils.utils
import
get_tqdm
from
..utils
import
logger
from
..utils.compatible_serialize
import
dumps
,
loads
from
..utils.develop
import
create_dummy_class
# noqa
from
..utils.utils
import
get_tqdm
from
.base
import
DataFlow
from
.
format
import
LMDBData
,
HDF5
Data
from
.
common
import
MapData
,
FixedSize
Data
from
.raw
import
DataFrom
List
,
DataFromGenerator
from
.
common
import
FixedSizeData
,
Map
Data
from
.
format
import
HDF5Data
,
LMDB
Data
from
.raw
import
DataFrom
Generator
,
DataFromList
__all__
=
[
'LMDBSerializer'
,
'NumpySerializer'
,
'TFRecordSerializer'
,
'HDF5Serializer'
]
...
...
@@ -195,7 +195,6 @@ class HDF5Serializer():
return
HDF5Data
(
path
,
data_paths
,
shuffle
)
from
..utils.develop
import
create_dummy_class
# noqa
try
:
import
lmdb
except
ImportError
:
...
...
tensorpack/graph_builder/distributed.py
View file @
ac9ac2a4
# -*- coding: utf-8 -*-
# File: distributed.py
import
tensorflow
as
tf
import
re
import
tensorflow
as
tf
from
six.moves
import
range
from
..tfutils.common
import
get_global_step_var
,
get_op_tensor_name
from
..utils
import
logger
from
..utils.argtools
import
memoized
from
..tfutils.common
import
get_op_tensor_name
,
get_global_step_var
from
.training
import
GraphBuilder
,
DataParallelBuilder
from
.utils
import
(
override_to_local_variable
,
aggregate_grads
,
OverrideCachingDevice
)
from
.training
import
DataParallelBuilder
,
GraphBuilder
from
.utils
import
OverrideCachingDevice
,
aggregate_grads
,
override_to_local_variable
__all__
=
[
'DistributedParameterServerBuilder'
,
'DistributedReplicatedBuilder'
]
...
...
tensorpack/graph_builder/model_desc.py
View file @
ac9ac2a4
...
...
@@ -5,11 +5,11 @@
from
collections
import
namedtuple
import
tensorflow
as
tf
from
..models.regularize
import
regularize_cost_from_collection
from
..tfutils.tower
import
get_current_tower_context
from
..utils
import
logger
from
..utils.argtools
import
memoized_method
from
..utils.develop
import
log_deprecated
from
..tfutils.tower
import
get_current_tower_context
from
..models.regularize
import
regularize_cost_from_collection
__all__
=
[
'InputDesc'
,
'ModelDesc'
,
'ModelDescBase'
]
...
...
tensorpack/graph_builder/predict.py
View file @
ac9ac2a4
...
...
@@ -3,9 +3,9 @@
import
tensorflow
as
tf
from
..tfutils.tower
import
PredictTowerContext
from
..utils
import
logger
from
..utils.develop
import
deprecated
from
..tfutils.tower
import
PredictTowerContext
from
.training
import
GraphBuilder
__all__
=
[
'SimplePredictBuilder'
]
...
...
tensorpack/graph_builder/training.py
View file @
ac9ac2a4
# -*- coding: utf-8 -*-
# File: training.py
from
abc
import
ABCMeta
,
abstractmethod
import
tensorflow
as
tf
import
copy
import
six
import
re
import
pprint
from
six.moves
import
zip
,
range
import
re
from
abc
import
ABCMeta
,
abstractmethod
from
contextlib
import
contextmanager
import
six
import
tensorflow
as
tf
from
six.moves
import
range
,
zip
from
..utils
import
logger
from
..tfutils.tower
import
TrainTowerContext
from
..tfutils.gradproc
import
ScaleGradient
from
..tfutils.common
import
get_tf_version_tuple
from
..tfutils.gradproc
import
ScaleGradient
from
..tfutils.tower
import
TrainTowerContext
from
..utils
import
logger
from
.utils
import
(
LeastLoadedDeviceSetter
,
override_to_local_variable
,
allreduce_grads
,
aggregate_grads
,
allreduce_grads_hierarchical
,
split_grad_list
,
merge_grad_list
,
GradientPacker
)
GradientPacker
,
LeastLoadedDeviceSetter
,
aggregate_grads
,
allreduce_grads
,
allreduce_grads_hierarchical
,
merge_grad_list
,
override_to_local_variable
,
split_grad_list
)
__all__
=
[
'GraphBuilder'
,
'SyncMultiGPUParameterServerBuilder'
,
'DataParallelBuilder'
,
...
...
tensorpack/graph_builder/utils.py
View file @
ac9ac2a4
...
...
@@ -2,16 +2,15 @@
# File: utils.py
from
contextlib
import
contextmanager
import
operator
from
contextlib
import
contextmanager
import
tensorflow
as
tf
from
..tfutils.varreplace
import
custom_getter_scope
from
..tfutils.scope_utils
import
under_name_scope
,
cached_name_scope
from
..tfutils.common
import
get_tf_version_tuple
from
..utils.argtools
import
call_only_once
from
..tfutils.scope_utils
import
cached_name_scope
,
under_name_scope
from
..tfutils.varreplace
import
custom_getter_scope
from
..utils
import
logger
from
..utils.argtools
import
call_only_once
__all__
=
[
'LeastLoadedDeviceSetter'
,
'OverrideCachingDevice'
,
...
...
tensorpack/input_source/input_source.py
View file @
ac9ac2a4
...
...
@@ -2,27 +2,28 @@
# File: input_source.py
import
tensorflow
as
tf
try
:
from
tensorflow.python.ops.data_flow_ops
import
StagingArea
except
ImportError
:
pass
import
threading
from
contextlib
import
contextmanager
from
itertools
import
chain
import
tensorflow
as
tf
from
six.moves
import
range
,
zip
import
threading
from
.input_source_base
import
InputSource
from
..callbacks.base
import
Callback
,
CallbackFactory
from
..callbacks.graph
import
RunOp
from
..dataflow
import
DataFlow
,
MapData
,
RepeatedData
from
..tfutils.summary
import
add_moving_summary
from
..tfutils.common
import
get_op_tensor_name
from
..tfutils.tower
import
get_current_tower_context
from
..tfutils.dependency
import
dependency_of_fetches
from
..tfutils.summary
import
add_moving_summary
from
..tfutils.tower
import
get_current_tower_context
from
..utils
import
logger
from
..utils.concurrency
import
ShareSessionThread
from
..callbacks.base
import
Callback
,
CallbackFactory
from
..callbacks.graph
import
RunOp
from
.input_source_base
import
InputSource
try
:
from
tensorflow.python.ops.data_flow_ops
import
StagingArea
except
ImportError
:
pass
__all__
=
[
'PlaceholderInput'
,
'FeedInput'
,
'FeedfreeInput'
,
'QueueInput'
,
'BatchQueueInput'
,
...
...
tensorpack/input_source/input_source_base.py
View file @
ac9ac2a4
# -*- coding: utf-8 -*-
# File: input_source_base.py
from
abc
import
ABCMeta
,
abstractmethod
import
copy
import
six
from
six.moves
import
zip
from
abc
import
ABCMeta
,
abstractmethod
from
contextlib
import
contextmanager
import
six
import
tensorflow
as
tf
from
six.moves
import
zip
from
..utils.argtools
import
memoized_method
,
call_only_once
from
..callbacks.base
import
CallbackFactory
from
..tfutils.common
import
get_op_tensor_name
from
..utils
import
logger
from
..utils.argtools
import
call_only_once
,
memoized_method
__all__
=
[
'InputSource'
,
'remap_input_source'
]
...
...
tensorpack/models/_old_batch_norm.py
View file @
ac9ac2a4
...
...
@@ -4,13 +4,15 @@
import
tensorflow
as
tf
from
tensorflow.contrib.framework
import
add_model_variable
from
tensorflow.python.training
import
moving_averages
from
..tfutils.common
import
get_tf_version_tuple
from
..tfutils.tower
import
get_current_tower_context
from
..utils
import
logger
from
..utils.argtools
import
get_data_format
from
..tfutils.tower
import
get_current_tower_context
from
..tfutils.common
import
get_tf_version_tuple
from
.common
import
layer_register
,
VariableHolder
from
.common
import
VariableHolder
,
layer_register
from
.tflayer
import
convert_to_tflayer_args
"""
Old Custom BN Implementation, Kept Here For Future Reference
"""
...
...
tensorpack/models/_test.py
View file @
ac9ac2a4
...
...
@@ -3,8 +3,8 @@
import
logging
import
tensorflow
as
tf
import
unittest
import
tensorflow
as
tf
class
TestModel
(
unittest
.
TestCase
):
...
...
tensorpack/models/batch_norm.py
View file @
ac9ac2a4
...
...
@@ -2,17 +2,17 @@
# File: batch_norm.py
import
tensorflow
as
tf
from
tensorflow.python.training
import
moving_averages
import
re
import
six
import
tensorflow
as
tf
from
tensorflow.python.training
import
moving_averages
from
..tfutils.collection
import
backup_collection
,
restore_collection
from
..tfutils.common
import
get_tf_version_tuple
from
..tfutils.tower
import
get_current_tower_context
from
..utils
import
logger
from
..utils.argtools
import
get_data_format
from
..tfutils.tower
import
get_current_tower_context
from
..tfutils.common
import
get_tf_version_tuple
from
..tfutils.collection
import
backup_collection
,
restore_collection
from
.common
import
layer_register
,
VariableHolder
from
.common
import
VariableHolder
,
layer_register
from
.tflayer
import
convert_to_tflayer_args
,
rename_get_variable
__all__
=
[
'BatchNorm'
,
'BatchRenorm'
]
...
...
tensorpack/models/common.py
View file @
ac9ac2a4
...
...
@@ -2,7 +2,7 @@
# File: common.py
from
.registry
import
layer_register
# noqa
from
.utils
import
VariableHolder
# noqa
from
.tflayer
import
rename_tflayer_get_variable
from
.utils
import
VariableHolder
# noqa
__all__
=
[
'layer_register'
,
'VariableHolder'
,
'rename_tflayer_get_variable'
]
tensorpack/models/conv2d.py
View file @
ac9ac2a4
...
...
@@ -3,10 +3,11 @@
import
tensorflow
as
tf
from
.common
import
layer_register
,
VariableHolder
from
..tfutils.common
import
get_tf_version_tuple
from
..utils.argtools
import
shape2d
,
shape4d
,
get_data_format
from
.tflayer
import
rename_get_variable
,
convert_to_tflayer_args
from
..utils.argtools
import
get_data_format
,
shape2d
,
shape4d
from
.common
import
VariableHolder
,
layer_register
from
.tflayer
import
convert_to_tflayer_args
,
rename_get_variable
__all__
=
[
'Conv2D'
,
'Deconv2D'
,
'Conv2DTranspose'
]
...
...
@@ -50,7 +51,7 @@ def Conv2D(
"""
if
kernel_initializer
is
None
:
if
get_tf_version_tuple
()
<=
(
1
,
12
):
kernel_initializer
=
tf
.
contrib
.
layers
.
variance_scaling_initializer
(
2.0
)
,
kernel_initializer
=
tf
.
contrib
.
layers
.
variance_scaling_initializer
(
2.0
)
else
:
kernel_initializer
=
tf
.
keras
.
initializers
.
VarianceScaling
(
2.0
,
distribution
=
'untruncated_normal'
)
if
split
==
1
:
...
...
@@ -158,7 +159,7 @@ def Conv2DTranspose(
"""
if
kernel_initializer
is
None
:
if
get_tf_version_tuple
()
<=
(
1
,
12
):
kernel_initializer
=
tf
.
contrib
.
layers
.
variance_scaling_initializer
(
2.0
)
,
kernel_initializer
=
tf
.
contrib
.
layers
.
variance_scaling_initializer
(
2.0
)
else
:
kernel_initializer
=
tf
.
keras
.
initializers
.
VarianceScaling
(
2.0
,
distribution
=
'untruncated_normal'
)
...
...
tensorpack/models/fc.py
View file @
ac9ac2a4
...
...
@@ -2,11 +2,11 @@
# File: fc.py
import
tensorflow
as
tf
import
numpy
as
np
import
tensorflow
as
tf
from
..tfutils.common
import
get_tf_version_tuple
from
.common
import
layer_register
,
VariableHold
er
from
.common
import
VariableHolder
,
layer_regist
er
from
.tflayer
import
convert_to_tflayer_args
,
rename_get_variable
__all__
=
[
'FullyConnected'
]
...
...
@@ -48,7 +48,7 @@ def FullyConnected(
"""
if
kernel_initializer
is
None
:
if
get_tf_version_tuple
()
<=
(
1
,
12
):
kernel_initializer
=
tf
.
contrib
.
layers
.
variance_scaling_initializer
(
2.0
)
,
kernel_initializer
=
tf
.
contrib
.
layers
.
variance_scaling_initializer
(
2.0
)
else
:
kernel_initializer
=
tf
.
keras
.
initializers
.
VarianceScaling
(
2.0
,
distribution
=
'untruncated_normal'
)
...
...
tensorpack/models/layer_norm.py
View file @
ac9ac2a4
...
...
@@ -3,8 +3,9 @@
import
tensorflow
as
tf
from
.common
import
layer_register
,
VariableHolder
from
..utils.argtools
import
get_data_format
from
.common
import
VariableHolder
,
layer_register
__all__
=
[
'LayerNorm'
,
'InstanceNorm'
]
...
...
tensorpack/models/linearwrap.py
View file @
ac9ac2a4
...
...
@@ -2,8 +2,9 @@
# File: linearwrap.py
import
six
from
types
import
ModuleType
import
six
from
.registry
import
get_registered_layer
__all__
=
[
'LinearWrap'
]
...
...
tensorpack/models/nonlin.py
View file @
ac9ac2a4
...
...
@@ -4,8 +4,8 @@
import
tensorflow
as
tf
from
.common
import
layer_register
,
VariableHolder
from
.batch_norm
import
BatchNorm
from
.common
import
VariableHolder
,
layer_register
__all__
=
[
'Maxout'
,
'PReLU'
,
'BNReLU'
]
...
...
tensorpack/models/pool.py
View file @
ac9ac2a4
# -*- coding: utf-8 -*-
# File: pool.py
import
tensorflow
as
tf
import
numpy
as
np
import
tensorflow
as
tf
from
.shape_utils
import
StaticDynamicShape
from
.common
import
layer_register
from
..utils.argtools
import
shape2d
,
get_data_format
from
..utils.argtools
import
get_data_format
,
shape2d
from
..utils.develop
import
log_deprecated
from
._test
import
TestModel
from
.common
import
layer_register
from
.shape_utils
import
StaticDynamicShape
from
.tflayer
import
convert_to_tflayer_args
__all__
=
[
'MaxPooling'
,
'FixedUnPooling'
,
'AvgPooling'
,
'GlobalAvgPooling'
,
'BilinearUpSample'
]
...
...
tensorpack/models/registry.py
View file @
ac9ac2a4
...
...
@@ -2,11 +2,11 @@
# File: registry.py
import
tensorflow
as
tf
import
copy
import
re
from
functools
import
wraps
import
six
import
re
import
copy
import
tensorflow
as
tf
from
..tfutils.argscope
import
get_arg_scope
from
..tfutils.model_utils
import
get_shape_str
...
...
tensorpack/models/regularize.py
View file @
ac9ac2a4
...
...
@@ -2,13 +2,13 @@
# File: regularize.py
import
tensorflow
as
tf
import
re
import
tensorflow
as
tf
from
..utils
import
logger
from
..utils.argtools
import
graph_memoized
from
..tfutils.common
import
get_tf_version_tuple
from
..tfutils.tower
import
get_current_tower_context
from
..utils
import
logger
from
..utils.argtools
import
graph_memoized
from
.common
import
layer_register
__all__
=
[
'regularize_cost'
,
'regularize_cost_from_collection'
,
...
...
tensorpack/models/shapes.py
View file @
ac9ac2a4
...
...
@@ -3,6 +3,7 @@
import
tensorflow
as
tf
from
.common
import
layer_register
__all__
=
[
'ConcatWith'
]
...
...
tensorpack/models/tflayer.py
View file @
ac9ac2a4
# -*- coding: utf-8 -*-
# File: tflayer.py
import
tensorflow
as
tf
import
six
import
functools
import
six
import
tensorflow
as
tf
from
..utils.argtools
import
get_data_format
from
..tfutils.common
import
get_tf_version_tuple
from
..tfutils.varreplace
import
custom_getter_scope
from
..utils.argtools
import
get_data_format
__all__
=
[]
...
...
tensorpack/predict/base.py
View file @
ac9ac2a4
...
...
@@ -2,13 +2,13 @@
# File: base.py
from
abc
import
abstractmethod
,
ABCMeta
import
tensorflow
as
tf
from
abc
import
ABCMeta
,
abstractmethod
import
six
import
tensorflow
as
tf
from
..input_source
import
PlaceholderInput
from
..tfutils.common
import
get_tensors_by_names
from
..tfutils.tower
import
PredictTowerContext
from
..input_source
import
PlaceholderInput
__all__
=
[
'PredictorBase'
,
'AsyncPredictorBase'
,
'OnlinePredictor'
,
'OfflinePredictor'
,
...
...
tensorpack/predict/concurrency.py
View file @
ac9ac2a4
...
...
@@ -2,16 +2,16 @@
# File: concurrency.py
import
numpy
as
np
import
multiprocessing
import
numpy
as
np
import
six
from
six.moves
import
queue
,
range
import
tensorflow
as
tf
from
six.moves
import
queue
,
range
from
..utils
import
logger
from
..utils.concurrency
import
DIE
,
StoppableThread
,
ShareSessionThread
from
..tfutils.model_utils
import
describe_trainable_vars
from
.base
import
OnlinePredictor
,
OfflinePredictor
,
AsyncPredictorBase
from
..utils
import
logger
from
..utils.concurrency
import
DIE
,
ShareSessionThread
,
StoppableThread
from
.base
import
AsyncPredictorBase
,
OfflinePredictor
,
OnlinePredictor
__all__
=
[
'MultiProcessPredictWorker'
,
'MultiProcessQueuePredictWorker'
,
'MultiThreadAsyncPredictor'
]
...
...
tensorpack/predict/config.py
View file @
ac9ac2a4
...
...
@@ -2,13 +2,13 @@
# File: config.py
import
tensorflow
as
tf
import
six
import
tensorflow
as
tf
from
..graph_builder
import
ModelDescBase
from
..tfutils
import
get_default_sess_config
from
..tfutils.sessinit
import
JustCurrentSession
,
SessionInit
from
..tfutils.tower
import
TowerFuncWrapper
from
..tfutils.sessinit
import
SessionInit
,
JustCurrentSession
from
..utils
import
logger
__all__
=
[
'PredictConfig'
]
...
...
tensorpack/predict/dataset.py
View file @
ac9ac2a4
...
...
@@ -2,22 +2,21 @@
# File: dataset.py
from
six.moves
import
range
,
zip
from
abc
import
ABCMeta
,
abstractmethod
import
multiprocessing
import
os
from
abc
import
ABCMeta
,
abstractmethod
import
six
from
six.moves
import
range
,
zip
from
..dataflow
import
DataFlow
from
..dataflow.remote
import
dump_dataflow_to_process_queue
from
..utils.concurrency
import
ensure_proc_terminate
,
OrderedResultGatherProc
,
DIE
from
..utils
import
logger
from
..utils.
utils
import
get_tqdm
from
..utils.
concurrency
import
DIE
,
OrderedResultGatherProc
,
ensure_proc_terminate
from
..utils.gpu
import
change_gpu
,
get_num_gpu
from
..utils.utils
import
get_tqdm
from
.base
import
OfflinePredictor
from
.concurrency
import
MultiProcessQueuePredictWorker
from
.config
import
PredictConfig
from
.base
import
OfflinePredictor
__all__
=
[
'DatasetPredictorBase'
,
'SimpleDatasetPredictor'
,
'MultiProcessDatasetPredictor'
]
...
...
tensorpack/predict/feedfree.py
View file @
ac9ac2a4
#!/usr/bin/env python
from
tensorflow.python.training.monitored_session
\
import
_HookedSession
as
HookedSession
from
tensorflow.python.training.monitored_session
import
_HookedSession
as
HookedSession
from
.base
import
PredictorBase
from
..tfutils.tower
import
PredictTowerContext
from
..callbacks
import
Callbacks
from
..tfutils.tower
import
PredictTowerContext
from
.base
import
PredictorBase
__all__
=
[
'FeedfreePredictor'
]
...
...
tensorpack/predict/multigpu.py
View file @
ac9ac2a4
...
...
@@ -3,10 +3,11 @@
import
tensorflow
as
tf
from
..utils
import
logger
from
..graph_builder.model_desc
import
InputDesc
from
..input_source
import
PlaceholderInput
from
..tfutils.tower
import
PredictTowerContext
from
..utils
import
logger
from
.base
import
OnlinePredictor
__all__
=
[
'MultiTowerOfflinePredictor'
,
...
...
tensorpack/tfutils/argscope.py
View file @
ac9ac2a4
# -*- coding: utf-8 -*-
# File: argscope.py
from
contextlib
import
contextmanager
from
collections
import
defaultdict
import
copy
from
collections
import
defaultdict
from
contextlib
import
contextmanager
from
functools
import
wraps
from
inspect
import
isfunction
,
getmembers
from
inspect
import
getmembers
,
isfunction
from
.tower
import
get_current_tower_context
from
..utils
import
logger
from
.tower
import
get_current_tower_context
__all__
=
[
'argscope'
,
'get_arg_scope'
,
'enable_argscope_for_module'
]
...
...
tensorpack/tfutils/collection.py
View file @
ac9ac2a4
...
...
@@ -2,10 +2,10 @@
# File: collection.py
import
tensorflow
as
tf
from
contextlib
import
contextmanager
from
copy
import
copy
import
six
from
contextlib
import
contextmanager
import
tensorflow
as
tf
from
..utils
import
logger
from
..utils.argtools
import
memoized
...
...
tensorpack/tfutils/common.py
View file @
ac9ac2a4
...
...
@@ -4,6 +4,7 @@
import
tensorflow
as
tf
from
six.moves
import
map
from
..utils.argtools
import
graph_memoized
from
..utils.develop
import
deprecated
...
...
tensorpack/tfutils/dependency.py
View file @
ac9ac2a4
import
tensorflow
as
tf
from
tensorflow.contrib.graph_editor
import
get_backward_walk_ops
from
..utils.argtools
import
graph_memoized
"""
...
...
tensorpack/tfutils/export.py
View file @
ac9ac2a4
...
...
@@ -12,10 +12,10 @@ from tensorflow.python.framework import graph_util
from
tensorflow.python.platform
import
gfile
from
tensorflow.python.tools
import
optimize_for_inference_lib
from
..
utils
import
logger
from
..
input_source
import
PlaceholderInput
from
..tfutils.common
import
get_tensors_by_names
from
..tfutils.tower
import
PredictTowerContext
from
..
input_source
import
PlaceholderInput
from
..
utils
import
logger
__all__
=
[
'ModelExporter'
]
...
...
tensorpack/tfutils/gradproc.py
View file @
ac9ac2a4
...
...
@@ -2,14 +2,15 @@
# File: gradproc.py
import
tensorflow
as
tf
from
abc
import
ABCMeta
,
abstractmethod
import
inspect
import
re
from
abc
import
ABCMeta
,
abstractmethod
import
six
import
inspect
import
tensorflow
as
tf
from
..utils
import
logger
from
.symbolic_functions
import
rms
,
print_stat
from
.summary
import
add_moving_summary
from
.symbolic_functions
import
print_stat
,
rms
__all__
=
[
'GradientProcessor'
,
'FilterNoneGrad'
,
'GlobalNormClip'
,
'MapGradient'
,
'SummaryGradient'
,
...
...
tensorpack/tfutils/model_utils.py
View file @
ac9ac2a4
...
...
@@ -3,8 +3,8 @@
# Author: tensorpack contributors
import
tensorflow
as
tf
from
termcolor
import
colored
from
tabulate
import
tabulate
from
termcolor
import
colored
from
..utils
import
logger
...
...
tensorpack/tfutils/optimizer.py
View file @
ac9ac2a4
...
...
@@ -2,11 +2,11 @@
# File: optimizer.py
import
tensorflow
as
tf
from
contextlib
import
contextmanager
import
tensorflow
as
tf
from
..utils.develop
import
HIDE_DOC
from
..tfutils.common
import
get_tf_version_tuple
from
..utils.develop
import
HIDE_DOC
from
.gradproc
import
FilterNoneGrad
,
GradientProcessor
__all__
=
[
'apply_grad_processors'
,
'ProxyOptimizer'
,
...
...
tensorpack/tfutils/scope_utils.py
View file @
ac9ac2a4
...
...
@@ -2,9 +2,9 @@
# File: scope_utils.py
import
tensorflow
as
tf
import
functools
from
contextlib
import
contextmanager
import
tensorflow
as
tf
from
..utils.argtools
import
graph_memoized
from
.common
import
get_tf_version_tuple
...
...
tensorpack/tfutils/sesscreate.py
View file @
ac9ac2a4
...
...
@@ -3,8 +3,9 @@
import
tensorflow
as
tf
from
.common
import
get_default_sess_config
from
..utils
import
logger
from
.common
import
get_default_sess_config
__all__
=
[
'NewSessionCreator'
,
'ReuseSessionCreator'
,
'SessionCreatorAdapter'
]
...
...
tensorpack/tfutils/sessinit.py
View file @
ac9ac2a4
...
...
@@ -3,13 +3,12 @@
import
numpy
as
np
import
tensorflow
as
tf
import
six
import
tensorflow
as
tf
from
..utils
import
logger
from
.common
import
get_op_tensor_name
from
.varmanip
import
(
SessionUpdate
,
get_savename_from_varname
,
is_training_name
,
get_checkpoint_path
)
from
.varmanip
import
SessionUpdate
,
get_checkpoint_path
,
get_savename_from_varname
,
is_training_name
__all__
=
[
'SessionInit'
,
'ChainInit'
,
'SaverRestore'
,
'SaverRestoreRelaxed'
,
'DictRestore'
,
...
...
tensorpack/tfutils/summary.py
View file @
ac9ac2a4
...
...
@@ -2,20 +2,19 @@
# File: summary.py
import
re
from
contextlib
import
contextmanager
import
six
import
tensorflow
as
tf
import
re
from
six.moves
import
range
from
contextlib
import
contextmanager
from
tensorflow.python.training
import
moving_averages
from
..utils
import
logger
from
..utils.argtools
import
graph_memoized
from
..utils.naming
import
MOVING_SUMMARY_OPS_KEY
from
.tower
import
get_current_tower_context
from
.symbolic_functions
import
rms
from
.scope_utils
import
cached_name_scope
from
.symbolic_functions
import
rms
from
.tower
import
get_current_tower_context
__all__
=
[
'add_tensor_summary'
,
'add_param_summary'
,
'add_activation_summary'
,
'add_moving_summary'
,
...
...
tensorpack/tfutils/tower.py
View file @
ac9ac2a4
...
...
@@ -2,15 +2,15 @@
# File: tower.py
import
tensorflow
as
tf
from
abc
import
ABCMeta
,
abstractmethod
,
abstractproperty
import
six
import
tensorflow
as
tf
from
six.moves
import
zip
from
abc
import
abstractproperty
,
abstractmethod
,
ABCMeta
from
..utils
import
logger
from
..utils.argtools
import
call_only_once
from
..utils.naming
import
MOVING_SUMMARY_OPS_KEY
from
..utils.develop
import
HIDE_DOC
from
..utils.naming
import
MOVING_SUMMARY_OPS_KEY
from
.collection
import
CollectionGuard
from
.common
import
get_op_or_tensor_by_name
,
get_op_tensor_name
...
...
tensorpack/tfutils/varmanip.py
View file @
ac9ac2a4
# -*- coding: utf-8 -*-
# File: varmanip.py
import
six
import
numpy
as
np
import
os
import
pprint
import
six
import
tensorflow
as
tf
import
numpy
as
np
from
..utils
import
logger
from
.common
import
get_op_tensor_name
...
...
tensorpack/tfutils/varreplace.py
View file @
ac9ac2a4
...
...
@@ -2,8 +2,8 @@
# File: varreplace.py
# Credit: Qinyao He
import
tensorflow
as
tf
from
contextlib
import
contextmanager
import
tensorflow
as
tf
from
.common
import
get_tf_version_tuple
...
...
tensorpack/train/base.py
View file @
ac9ac2a4
# -*- coding: utf-8 -*-
# File: base.py
import
tensorflow
as
tf
import
weakref
import
copy
import
time
from
six.moves
import
range
import
weakref
import
six
import
copy
import
tensorflow
as
tf
from
six.moves
import
range
from
..callbacks
import
(
Callback
,
Callbacks
,
Monitors
,
TrainingMonitor
)
from
..utils
import
logger
from
..utils.utils
import
humanize_time_delta
from
..utils.argtools
import
call_only_once
from
..callbacks
import
Callback
,
Callbacks
,
Monitors
,
TrainingMonitor
from
..callbacks.steps
import
MaintainStepCounter
from
..tfutils
import
get_global_step_value
from
..tfutils.model_utils
import
describe_trainable_vars
from
..tfutils.sessinit
import
SessionInit
,
JustCurrentSession
from
..tfutils.sesscreate
import
ReuseSessionCreator
,
NewSessionCreator
from
..callbacks.steps
import
MaintainStepCounter
from
.config
import
TrainConfig
,
DEFAULT_MONITORS
,
DEFAULT_CALLBACKS
from
..tfutils.sesscreate
import
NewSessionCreator
,
ReuseSessionCreator
from
..tfutils.sessinit
import
JustCurrentSession
,
SessionInit
from
..utils
import
logger
from
..utils.argtools
import
call_only_once
from
..utils.utils
import
humanize_time_delta
from
.config
import
DEFAULT_CALLBACKS
,
DEFAULT_MONITORS
,
TrainConfig
__all__
=
[
'StopTraining'
,
'Trainer'
]
...
...
tensorpack/train/config.py
View file @
ac9ac2a4
...
...
@@ -5,15 +5,13 @@ import os
import
tensorflow
as
tf
from
..callbacks
import
(
MovingAverageSummary
,
ProgressBar
,
MergeAllSummaries
,
TFEventWriter
,
JSONWriter
,
ScalarPrinter
,
RunUpdateOps
)
JSONWriter
,
MergeAllSummaries
,
MovingAverageSummary
,
ProgressBar
,
RunUpdateOps
,
ScalarPrinter
,
TFEventWriter
)
from
..dataflow.base
import
DataFlow
from
..graph_builder.model_desc
import
ModelDescBase
from
..utils
import
logger
from
..tfutils.sessinit
import
SessionInit
,
SaverRestore
from
..tfutils.sesscreate
import
NewSessionCreator
from
..input_source
import
InputSource
from
..tfutils.sesscreate
import
NewSessionCreator
from
..tfutils.sessinit
import
SaverRestore
,
SessionInit
from
..utils
import
logger
__all__
=
[
'TrainConfig'
,
'AutoResumeTrainConfig'
,
'DEFAULT_CALLBACKS'
,
'DEFAULT_MONITORS'
]
...
...
tensorpack/train/interface.py
View file @
ac9ac2a4
...
...
@@ -3,11 +3,8 @@
import
tensorflow
as
tf
from
..input_source
import
(
InputSource
,
FeedInput
,
FeedfreeInput
,
QueueInput
,
StagingInput
,
DummyConstantInput
)
from
..input_source
import
DummyConstantInput
,
FeedfreeInput
,
FeedInput
,
InputSource
,
QueueInput
,
StagingInput
from
..utils
import
logger
from
.config
import
TrainConfig
from
.tower
import
SingleCostTrainer
from
.trainers
import
SimpleTrainer
...
...
tensorpack/train/tower.py
View file @
ac9ac2a4
# -*- coding: utf-8 -*-
# File: tower.py
import
tensorflow
as
tf
from
abc
import
ABCMeta
,
abstractmethod
import
six
from
abc
import
abstractmethod
,
ABCMeta
import
tensorflow
as
tf
from
..utils.argtools
import
call_only_once
,
memoized
from
..utils.develop
import
HIDE_DOC
from
..utils
import
logger
from
..input_source
import
PlaceholderInput
from
..predict.base
import
OnlinePredictor
from
..tfutils.tower
import
TowerFuncWrapper
,
get_current_tower_context
,
PredictTowerContext
from
..tfutils.gradproc
import
FilterNoneGrad
from
..tfutils.tower
import
PredictTowerContext
,
TowerFuncWrapper
,
get_current_tower_context
from
..utils
import
logger
from
..utils.argtools
import
call_only_once
,
memoized
from
..utils.develop
import
HIDE_DOC
from
.base
import
Trainer
__all__
=
[
'SingleCostTrainer'
,
'TowerTrainer'
]
...
...
tensorpack/train/trainers.py
View file @
ac9ac2a4
# -*- coding: utf-8 -*-
# File: trainers.py
import
sys
import
multiprocessing
as
mp
import
os
import
sys
import
tensorflow
as
tf
import
multiprocessing
as
mp
from
..callbacks
import
RunOp
,
CallbackFactory
from
..callbacks
import
CallbackFactory
,
RunOp
from
..graph_builder.distributed
import
DistributedParameterServerBuilder
,
DistributedReplicatedBuilder
from
..graph_builder.training
import
(
AsyncMultiGPUBuilder
,
SyncMultiGPUParameterServerBuilder
,
SyncMultiGPUReplicatedBuilder
)
from
..graph_builder.utils
import
override_to_local_variable
from
..input_source
import
FeedfreeInput
,
QueueInput
from
..tfutils
import
get_global_step_var
from
..tfutils.distributed
import
get_distributed_session_creator
from
..tfutils.sesscreate
import
NewSessionCreator
from
..tfutils.tower
import
TrainTowerContext
from
..utils
import
logger
from
..utils.argtools
import
map_arg
from
..utils.develop
import
HIDE_DOC
,
log_deprecated
from
..tfutils
import
get_global_step_var
from
..tfutils.distributed
import
get_distributed_session_creator
from
..tfutils.tower
import
TrainTowerContext
from
..input_source
import
QueueInput
,
FeedfreeInput
from
..graph_builder.training
import
(
SyncMultiGPUParameterServerBuilder
,
SyncMultiGPUReplicatedBuilder
,
AsyncMultiGPUBuilder
)
from
..graph_builder.distributed
import
DistributedReplicatedBuilder
,
DistributedParameterServerBuilder
from
..graph_builder.utils
import
override_to_local_variable
from
.tower
import
SingleCostTrainer
__all__
=
[
'NoOpTrainer'
,
'SimpleTrainer'
,
...
...
tensorpack/train/utility.py
View file @
ac9ac2a4
...
...
@@ -2,6 +2,4 @@
# File: utility.py
# for backwards-compatibility
from
..graph_builder.utils
import
(
# noqa
OverrideToLocalVariable
,
override_to_local_variable
,
LeastLoadedDeviceSetter
)
from
..graph_builder.utils
import
LeastLoadedDeviceSetter
,
OverrideToLocalVariable
,
override_to_local_variable
# noqa
tensorpack/utils/argtools.py
View file @
ac9ac2a4
...
...
@@ -4,7 +4,9 @@
import
inspect
import
six
from
.
import
logger
if
six
.
PY2
:
import
functools32
as
functools
else
:
...
...
tensorpack/utils/compatible_serialize.py
View file @
ac9ac2a4
import
os
from
.serialize
import
loads_msgpack
,
loads_pyarrow
,
dumps_msgpack
,
dumps_pyarrow
from
.serialize
import
dumps_msgpack
,
dumps_pyarrow
,
loads_msgpack
,
loads_pyarrow
"""
Serialization that has compatibility guarantee (therefore is safe to store to disk).
...
...
tensorpack/utils/concurrency.py
View file @
ac9ac2a4
...
...
@@ -3,14 +3,14 @@
# Some code taken from zxytim
import
threading
import
platform
import
multiprocessing
import
atexit
import
bisect
from
contextlib
import
contextmanager
import
multiprocessing
import
platform
import
signal
import
threading
import
weakref
from
contextlib
import
contextmanager
import
six
from
six.moves
import
queue
...
...
tensorpack/utils/develop.py
View file @
ac9ac2a4
...
...
@@ -6,11 +6,11 @@
""" Utilities for developers only.
These are not visible to users (not automatically imported). And should not
appeared in docs."""
import
os
import
functools
from
datetime
import
datetime
import
importlib
import
os
import
types
from
datetime
import
datetime
import
six
from
.
import
logger
...
...
tensorpack/utils/fs.py
View file @
ac9ac2a4
...
...
@@ -2,10 +2,11 @@
# File: fs.py
import
os
from
six.moves
import
urllib
import
errno
import
os
import
tqdm
from
six.moves
import
urllib
from
.
import
logger
from
.utils
import
execute_only_once
...
...
tensorpack/utils/gpu.py
View file @
ac9ac2a4
...
...
@@ -3,10 +3,11 @@
import
os
from
.utils
import
change_env
from
.
import
logger
from
.nvml
import
NVMLContext
from
.concurrency
import
subproc_call
from
.nvml
import
NVMLContext
from
.utils
import
change_env
__all__
=
[
'change_gpu'
,
'get_nr_gpu'
,
'get_num_gpu'
]
...
...
tensorpack/utils/loadcaffe.py
View file @
ac9ac2a4
...
...
@@ -2,14 +2,14 @@
# File: loadcaffe.py
import
sys
import
numpy
as
np
import
os
import
sys
from
.utils
import
change_env
from
.fs
import
download
,
get_dataset_path
from
.concurrency
import
subproc_call
from
.
import
logger
from
.concurrency
import
subproc_call
from
.fs
import
download
,
get_dataset_path
from
.utils
import
change_env
__all__
=
[
'load_caffe'
,
'get_caffe_pb'
]
...
...
tensorpack/utils/logger.py
View file @
ac9ac2a4
...
...
@@ -16,12 +16,12 @@ The logger module itself has the common logging functions of Python's
import
logging
import
os
import
shutil
import
os.path
from
termcolor
import
colored
import
shutil
import
sys
from
datetime
import
datetime
from
six.moves
import
input
import
sys
from
termcolor
import
colored
__all__
=
[
'set_logger_dir'
,
'auto_set_dir'
,
'get_logger_dir'
]
...
...
tensorpack/utils/nvml.py
View file @
ac9ac2a4
# -*- coding: utf-8 -*-
# File: nvml.py
from
ctypes
import
(
byref
,
c_uint
,
c_ulonglong
,
CDLL
,
POINTER
,
Structure
)
import
threading
from
ctypes
import
CDLL
,
POINTER
,
Structure
,
byref
,
c_uint
,
c_ulonglong
__all__
=
[
'NVMLContext'
]
...
...
tensorpack/utils/rect.py
View file @
ac9ac2a4
...
...
@@ -3,6 +3,7 @@
import
numpy
as
np
from
.develop
import
log_deprecated
__all__
=
[
'IntBox'
,
'FloatBox'
]
...
...
tensorpack/utils/serialize.py
View file @
ac9ac2a4
# -*- coding: utf-8 -*-
# File: serialize.py
import
sys
import
os
from
.develop
import
create_dummy_func
import
sys
from
.
import
logger
from
.develop
import
create_dummy_func
__all__
=
[
'loads'
,
'dumps'
]
...
...
tensorpack/utils/timer.py
View file @
ac9ac2a4
...
...
@@ -2,14 +2,14 @@
# File: timer.py
from
contextlib
import
contextmanager
from
collections
import
defaultdict
import
six
import
atexit
from
collections
import
defaultdict
from
contextlib
import
contextmanager
from
time
import
time
as
timer
import
six
from
.stats
import
StatCounter
from
.
import
logger
from
.stats
import
StatCounter
if
six
.
PY3
:
from
time
import
perf_counter
as
timer
# noqa
...
...
tensorpack/utils/utils.py
View file @
ac9ac2a4
...
...
@@ -2,17 +2,16 @@
# File: utils.py
import
inspect
import
numpy
as
np
import
os
import
sys
from
contextlib
import
contextmanager
import
inspect
from
datetime
import
datetime
,
timedelta
from
tqdm
import
tqdm
import
numpy
as
np
from
.
import
logger
__all__
=
[
'change_env'
,
'get_rng'
,
'fix_rng_seed'
,
...
...
tensorpack/utils/viz.py
View file @
ac9ac2a4
...
...
@@ -5,8 +5,10 @@
import
numpy
as
np
import
os
import
sys
from
.fs
import
mkdir_p
from
..utils.develop
import
create_dummy_func
# noqa
from
.argtools
import
shape2d
from
.fs
import
mkdir_p
from
.palette
import
PALETTE_RGB
try
:
...
...
@@ -411,7 +413,6 @@ def draw_boxes(im, boxes, labels=None, color=None):
return
im
from
..utils.develop
import
create_dummy_func
# noqa
try
:
import
matplotlib.pyplot
as
plt
except
(
ImportError
,
RuntimeError
):
...
...
tests/run-tests.sh
View file @
ac9ac2a4
...
...
@@ -8,16 +8,11 @@ export TF_CPP_MIN_LOG_LEVEL=2
# test import (#471)
python
-c
'from tensorpack.dataflow.imgaug import transform'
#
python -m unittest discover -v
python
-m
unittest discover
-v
# python -m tensorpack.models._test
# segfault for no reason (https://travis-ci.org/ppwwyyxx/tensorpack/jobs/217702985)
# python ../tensorpack/user_ops/test-recv-op.py
python test_char_rnn.py
python test_infogan.py
python test_mnist.py
python test_mnist_similarity.py
TENSORPACK_SERIALIZE
=
pyarrow python test_serializer.py
TENSORPACK_SERIALIZE
=
msgpack python test_serializer.py
tests/test_char_rnn.py
View file @
ac9ac2a4
from
case_script
import
TestPythonScript
import
os
from
case_script
import
TestPythonScript
def
random_content
():
return
(
'Lorem ipsum dolor sit amet
\n
'
...
...
tests/test_infogan.py
View file @
ac9ac2a4
...
...
@@ -10,6 +10,7 @@ class InfoGANTest(TestPythonScript):
return
'../examples/GAN/InfoGAN-mnist.py'
def
test
(
self
):
return
True
# https://github.com/tensorflow/tensorflow/issues/24517
if
get_tf_version_tuple
()
<
(
1
,
4
):
return
True
# requires leaky_relu
self
.
assertSurvive
(
self
.
script
,
args
=
None
)
tests/test_serializer.py
View file @
ac9ac2a4
#! /usr/bin/env python
# -*- coding: utf-8 -*-
from
tensorpack.dataflow.base
import
DataFlow
from
tensorpack.dataflow
import
LMDBSerializer
,
TFRecordSerializer
,
NumpySerializer
,
HDF5Serializer
import
unittest
import
os
import
numpy
as
np
import
os
import
unittest
from
tensorpack.dataflow
import
HDF5Serializer
,
LMDBSerializer
,
NumpySerializer
,
TFRecordSerializer
from
tensorpack.dataflow.base
import
DataFlow
def
delete_file_if_exists
(
fn
):
...
...
tox.ini
View file @
ac9ac2a4
...
...
@@ -12,3 +12,13 @@ exclude = .git,
snippet,
examples-old,
_test.py,
[isort]
line_length
=
100
skip
=
docs/conf.py
multi_line_output
=
4
known_tensorpack
=
tensorpack
known_standard_library
=
numpy
known_third_party
=
bob,gym,matplotlib
no_lines_before
=
STDLIB,THIRDPARTY
sections
=
FUTURE,STDLIB,THIRDPARTY,tensorpack,FIRSTPARTY,LOCALFOLDER
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment