Project: seminar-breakout (Shashank Suhas) · Commits

Commit 7bdaf8ec, authored Sep 01, 2019 by Yuxin Wu
Parent: f17d16da

    docs cleanup

Showing 21 changed files, with 61 additions and 73 deletions (+61 -73).
Changed files:

    docs/conf.py                              +4  -4
    docs/modules/utils.rst                    +0  -9
    examples/CTC-TIMIT/create-lmdb.py         +1  -2
    examples/DeepQNetwork/atari.py            +1  -2
    examples/DeepQNetwork/common.py           +2  -4
    examples/DeepQNetwork/expreplay.py        +1  -2
    examples/FasterRCNN/eval.py               +1  -2
    examples/FasterRCNN/train.py              +2  -0
    examples/GAN/GAN.py                       +1  -1
    tensorpack/graph_builder/distributed.py   +1  -1
    tensorpack/graph_builder/training.py      +3  -3
    tensorpack/graph_builder/utils.py         +1  -7
    tensorpack/predict/base.py                +3  -4
    tensorpack/predict/concurrency.py         +8  -3
    tensorpack/predict/dataset.py             +3  -0
    tensorpack/tfutils/varmanip.py            +21 -21
    tensorpack/utils/argtools.py              +1  -1
    tensorpack/utils/concurrency.py           +1  -2
    tensorpack/utils/loadcaffe.py             +2  -0
    tensorpack/utils/serialize.py             +1  -1
    tensorpack/utils/timer.py                 +3  -4
docs/conf.py

@@ -399,11 +399,11 @@ _DEPRECATED_NAMES = set([
     'l2_regularizer', 'l1_regularizer',

     # internal only
-    'execute_only_once',
-    'humanize_time_delta',
     'SessionUpdate',
-    'average_grads',
-    'get_checkpoint_path',
+    'aggregate_grads',
+    'IterSpeedCounter',
+    'allreduce_grads',
+    'get_checkpoint_path'
 ])

 def autodoc_skip_member(app, what, name, obj, skip, options):
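Note: the diff does not show how this set is consumed, but the `autodoc_skip_member(app, what, name, obj, skip, options)` signature visible above is the standard Sphinx `autodoc-skip-member` hook. A minimal sketch of that wiring, with an illustrative set rather than tensorpack's full one:

    # Sketch of a Sphinx conf.py skip hook; the set contents here are illustrative.
    _DEPRECATED_NAMES = set(['average_grads', 'get_checkpoint_path'])

    def autodoc_skip_member(app, what, name, obj, skip, options):
        if name in _DEPRECATED_NAMES:
            return True          # returning True hides the member from the docs
        return skip

    def setup(app):
        app.connect('autodoc-skip-member', autodoc_skip_member)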
docs/modules/utils.rst

@@ -56,15 +56,6 @@ tensorpack.utils.serialize module
     :undoc-members:
     :show-inheritance:

-tensorpack.utils.compatible_serialize module
---------------------------------------------
-
-.. automodule:: tensorpack.utils.compatible_serialize
-    :members:
-    :undoc-members:
-    :show-inheritance:
-
 tensorpack.utils.stats module
 -----------------------------
examples/CTC-TIMIT/create-lmdb.py

@@ -10,10 +10,9 @@ import bob.ap
 import scipy.io.wavfile as wavfile

 from tensorpack.dataflow import DataFlow, LMDBSerializer
-from tensorpack.utils import fs, logger, serialize
+from tensorpack.utils import fs, logger, serialize, get_tqdm
 from tensorpack.utils.argtools import memoized
 from tensorpack.utils.stats import OnlineMoments
-from tensorpack.utils.utils import get_tqdm

 CHARSET = set(string.ascii_lowercase + ' ')
 PHONEME_LIST = [
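Note: this hunk switches a deep import (`tensorpack.utils.utils`) to the same name re-exported at package level (`tensorpack.utils`); the atari.py, common.py, expreplay.py, and eval.py hunks below apply the same cleanup. Assuming tensorpack re-exports these helpers in `tensorpack/utils/__init__.py`, as the new import implies:

    # old: from tensorpack.utils.utils import get_tqdm
    from tensorpack.utils import get_tqdm   # new, shallower import path

    for _ in get_tqdm(range(100)):
        pass   # tqdm preconfigured with tensorpack's default bar options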
examples/DeepQNetwork/atari.py

@@ -13,9 +13,8 @@ from gym import spaces
 from gym.envs.atari.atari_env import ACTION_MEANING
 from six.moves import range

-from tensorpack.utils import logger
+from tensorpack.utils import logger, execute_only_once, get_rng
 from tensorpack.utils.fs import get_dataset_path
-from tensorpack.utils.utils import execute_only_once, get_rng

 __all__ = ['AtariPlayer']
examples/DeepQNetwork/common.py

@@ -7,13 +7,11 @@ import numpy as np
 import random
 import time
 from six.moves import queue
-from tqdm import tqdm

 from tensorpack.callbacks import Callback
-from tensorpack.utils import logger
+from tensorpack.utils import logger, get_tqdm
 from tensorpack.utils.concurrency import ShareSessionThread, StoppableThread
 from tensorpack.utils.stats import StatCounter
-from tensorpack.utils.utils import get_tqdm_kwargs

 def play_one_episode(env, func, render=False):

@@ -87,7 +85,7 @@ def eval_with_funcs(predictors, nr_eval, get_player_fn, verbose=False):
             if verbose:
                 logger.info("Score: {}".format(r))

-    for _ in tqdm(range(nr_eval), **get_tqdm_kwargs()):
+    for _ in get_tqdm(range(nr_eval)):
         fetch()
     # waiting is necessary, otherwise the estimated mean score is biased
     logger.info("Waiting for all the workers to finish the last run...")
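Note: besides the import move, the eval loop swaps `tqdm(..., **get_tqdm_kwargs())` for `get_tqdm(...)`. A before/after sketch, assuming `get_tqdm(x)` is equivalent to `tqdm(x, **get_tqdm_kwargs())`:

    from tqdm import tqdm
    from tensorpack.utils import get_tqdm
    from tensorpack.utils.utils import get_tqdm_kwargs

    for _ in tqdm(range(100), **get_tqdm_kwargs()):   # old spelling
        pass
    for _ in get_tqdm(range(100)):                    # new spelling
        pass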
examples/DeepQNetwork/expreplay.py

@@ -12,9 +12,8 @@ from six.moves import queue, range
 from tensorpack.utils.concurrency import LoopThread, ShareSessionThread
 from tensorpack.callbacks.base import Callback
 from tensorpack.dataflow import DataFlow
-from tensorpack.utils import logger
+from tensorpack.utils import logger, get_rng, get_tqdm
 from tensorpack.utils.stats import StatCounter
-from tensorpack.utils.utils import get_rng, get_tqdm

 __all__ = ['ExpReplay']
examples/FasterRCNN/eval.py

@@ -17,8 +17,7 @@ from scipy import interpolate
 from tensorpack.callbacks import Callback
 from tensorpack.tfutils.common import get_tf_version_tuple
-from tensorpack.utils import logger
-from tensorpack.utils.utils import get_tqdm
+from tensorpack.utils import logger, get_tqdm

 from common import CustomResize, clip_boxes
 from config import config as cfg
examples/FasterRCNN/train.py

@@ -47,6 +47,8 @@ if __name__ == '__main__':
     # Setup logging ...
     is_horovod = cfg.TRAINER == 'horovod'
+    if is_horovod:
+        hvd.init()
     if not is_horovod or hvd.rank() == 0:
         logger.set_logger_dir(args.logdir, 'd')
     logger.info("Environment Information:\n" + collect_env_info())
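Note: the two added lines matter for ordering: `hvd.rank()` on the following line is only valid after `hvd.init()` has run. A minimal sketch of the pattern outside tensorpack:

    import horovod.tensorflow as hvd

    is_horovod = True      # stands in for cfg.TRAINER == 'horovod'
    if is_horovod:
        hvd.init()         # must precede any other hvd.* call, e.g. hvd.rank()
    if not is_horovod or hvd.rank() == 0:
        print("only the chief process (rank 0) sets up the log directory")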
examples/GAN/GAN.py

@@ -129,7 +129,7 @@ class GANTrainer(TowerTrainer):
         self.tower_func = TowerFunc(get_cost, model.get_input_signature())
         devices = [LeastLoadedDeviceSetter(d, raw_devices) for d in raw_devices]
-        cost_list = DataParallelBuilder.build_on_towers(
+        cost_list = DataParallelBuilder.call_for_each_tower(
             list(range(num_gpu)),
             lambda: self.tower_func(*input.get_input_tensors()),
             devices)
tensorpack/graph_builder/distributed.py

@@ -11,7 +11,7 @@ from ..utils.argtools import memoized
 from .training import DataParallelBuilder, GraphBuilder
 from .utils import OverrideCachingDevice, aggregate_grads, override_to_local_variable

-__all__ = ['DistributedParameterServerBuilder', 'DistributedReplicatedBuilder']
+__all__ = []

 class DistributedBuilderBase(GraphBuilder):
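Note: emptying `__all__` makes the distributed builders internal without deleting them: star-imports and `__all__`-driven doc generation no longer see them, while explicit imports keep working. A hypothetical module showing the effect:

    # internal_mod.py -- mirrors the __all__ = [] change above
    __all__ = []                 # no public names

    class SomeBuilder:           # still defined, still importable explicitly
        pass

`from internal_mod import *` now binds nothing, but `from internal_mod import SomeBuilder` is unaffected.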
tensorpack/graph_builder/training.py

@@ -15,13 +15,12 @@ from ..tfutils.common import get_tf_version_tuple
 from ..tfutils.gradproc import ScaleGradient
 from ..tfutils.tower import TrainTowerContext
 from ..utils import logger
+from ..utils.develop import HIDE_DOC
 from .utils import (
     GradientPacker, LeastLoadedDeviceSetter, aggregate_grads, allreduce_grads, allreduce_grads_hierarchical,
     merge_grad_list, override_to_local_variable, split_grad_list)

-__all__ = ['GraphBuilder',
-           'SyncMultiGPUParameterServerBuilder', 'DataParallelBuilder',
-           'SyncMultiGPUReplicatedBuilder', 'AsyncMultiGPUBuilder']
+__all__ = ["DataParallelBuilder"]

 @six.add_metaclass(ABCMeta)

@@ -117,6 +116,7 @@ class DataParallelBuilder(GraphBuilder):
             ret.append(func())
         return ret

+    @HIDE_DOC
     @staticmethod
     def build_on_towers(*args, **kwargs):
         return DataParallelBuilder.call_for_each_tower(*args, **kwargs)
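Note: together with the GAN.py hunk above, this makes `call_for_each_tower` the documented entry point while keeping `build_on_towers` as a backward-compatible alias that `@HIDE_DOC` keeps out of the generated docs. A runnable sketch of the alias pattern (simplified body, not tensorpack's actual build logic):

    class DataParallelBuilder:
        @staticmethod
        def call_for_each_tower(towers, func):
            return [func() for _ in towers]    # simplified: run func once per tower

        @staticmethod
        def build_on_towers(*args, **kwargs):
            # deprecated spelling kept for old call sites; hidden from the docs
            return DataParallelBuilder.call_for_each_tower(*args, **kwargs)

    print(DataParallelBuilder.build_on_towers(range(2), lambda: 0))   # [0, 0]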
tensorpack/graph_builder/utils.py

@@ -13,13 +13,7 @@ from ..tfutils.varreplace import custom_getter_scope
 from ..utils import logger
 from ..utils.argtools import call_only_once

-__all__ = ['LeastLoadedDeviceSetter',
-           'OverrideCachingDevice',
-           'override_to_local_variable',
-           'allreduce_grads',
-           'average_grads',
-           'aggregate_grads',
-           ]
+__all__ = ["LeastLoadedDeviceSetter"]

 """
tensorpack/predict/base.py

@@ -10,9 +10,8 @@ from ..input_source import PlaceholderInput
 from ..tfutils.common import get_tensors_by_names
 from ..tfutils.tower import PredictTowerContext

-__all__ = ['PredictorBase', 'AsyncPredictorBase',
-           'OnlinePredictor', 'OfflinePredictor',
-           ]
+__all__ = ['PredictorBase',
+           'OnlinePredictor', 'OfflinePredictor']

 @six.add_metaclass(ABCMeta)

@@ -62,7 +61,7 @@ class AsyncPredictorBase(PredictorBase):
             dp (list): A datapoint as inputs. It could be either batched or not
                 batched depending on the predictor implementation).
             callback: a thread-safe callback to get called with
-                either outputs or (inputs, outputs).
+                either outputs or (inputs, outputs), if `return_input` is True.

         Returns:
             concurrent.futures.Future: a Future of results
         """
tensorpack/predict/concurrency.py

@@ -14,8 +14,7 @@ from ..utils import logger
 from ..utils.concurrency import DIE, ShareSessionThread, StoppableThread
 from .base import AsyncPredictorBase, OfflinePredictor, OnlinePredictor

-__all__ = ['MultiProcessPredictWorker', 'MultiProcessQueuePredictWorker',
-           'MultiThreadAsyncPredictor']
+__all__ = ['MultiThreadAsyncPredictor']

 class MultiProcessPredictWorker(multiprocessing.Process):

@@ -171,7 +170,13 @@ class MultiThreadAsyncPredictor(AsyncPredictorBase):
     def put_task(self, dp, callback=None):
         """
-        Same as in :meth:`AsyncPredictorBase.put_task`.
+        Args:
+            dp (list): A datapoint as inputs. It could be either batched or not
+                batched depending on the predictor implementation).
+            callback: a thread-safe callback. When the results are ready, it will be called
+                with the "future" object.
+
+        Returns:
+            concurrent.futures.Future: a Future of results.
         """
         f = Future()
         if callback is not None:
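Note: the expanded docstring pins down the contract: `put_task` returns a `concurrent.futures.Future`, and the callback, when given, fires with that future. The hunk ends before showing how the callback is attached; the standard Future mechanics it describes look like this:

    from concurrent.futures import Future

    def on_done(future):
        print("prediction:", future.result())

    f = Future()                   # what put_task creates
    f.add_done_callback(on_done)   # one way a callback can be attached
    f.set_result([0.1, 0.9])       # a worker fulfills it later; on_done fires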
tensorpack/predict/dataset.py

@@ -11,6 +11,7 @@ from six.moves import range, zip
 from ..dataflow import DataFlow
 from ..dataflow.remote import dump_dataflow_to_process_queue
 from ..utils import logger
+from ..utils.develop import HIDE_DOC
 from ..utils.concurrency import DIE, OrderedResultGatherProc, ensure_proc_terminate
 from ..utils.gpu import change_gpu, get_num_gpu
 from ..utils.utils import get_tqdm

@@ -63,6 +64,7 @@ class SimpleDatasetPredictor(DatasetPredictorBase):
         super(SimpleDatasetPredictor, self).__init__(config, dataset)
         self.predictor = OfflinePredictor(config)

+    @HIDE_DOC
     def get_result(self):
         self.dataset.reset_state()
         try:

@@ -142,6 +144,7 @@ class MultiProcessDatasetPredictor(DatasetPredictorBase):
             self.result_queue = self.outqueue
         ensure_proc_terminate(self.workers + [self.inqueue_proc])

+    @HIDE_DOC
     def get_result(self):
         try:
             sz = len(self.dataset)
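Note: `@HIDE_DOC` here hides the `get_result` overrides from the generated docs while the base-class documentation stays. For context, a paraphrase of the structure visible in the SimpleDatasetPredictor hunk (not the exact implementation, which also drives a progress bar):

    class SimpleDatasetPredictorSketch:
        def __init__(self, predictor, dataset):
            self.predictor, self.dataset = predictor, dataset

        def get_result(self):
            self.dataset.reset_state()       # as in the hunk above
            for dp in self.dataset:
                yield self.predictor(*dp)    # one output per datapoint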
tensorpack/tfutils/varmanip.py

@@ -184,47 +184,47 @@ def save_chkpt_vars(dic, path):
         saver.save(sess, path, write_meta_graph=False)

-def get_checkpoint_path(model_path):
+def get_checkpoint_path(path):
     """
     Work around TF problems in checkpoint path handling.

     Args:
-        model_path: a user-input path
+        path: a user-input path
     Returns:
         str: the argument that can be passed to NewCheckpointReader
     """
-    if os.path.basename(model_path) == model_path:
-        model_path = os.path.join('.', model_path)  # avoid #4921 and #6142
-    if os.path.basename(model_path) == 'checkpoint':
-        assert tfv1.gfile.Exists(model_path), model_path
-        model_path = tf.train.latest_checkpoint(os.path.dirname(model_path))
+    if os.path.basename(path) == path:
+        path = os.path.join('.', path)  # avoid #4921 and #6142
+    if os.path.basename(path) == 'checkpoint':
+        assert tfv1.gfile.Exists(path), path
+        path = tf.train.latest_checkpoint(os.path.dirname(path))
         # to be consistent with either v1 or v2

     # fix paths if provided a wrong one
-    new_path = model_path
-    if '00000-of-00001' in model_path:
-        new_path = model_path.split('.data')[0]
-    elif model_path.endswith('.index'):
-        new_path = model_path.split('.index')[0]
-    if new_path != model_path:
+    new_path = path
+    if '00000-of-00001' in path:
+        new_path = path.split('.data')[0]
+    elif path.endswith('.index'):
+        new_path = path.split('.index')[0]
+    if new_path != path:
         logger.info(
-            "Checkpoint path {} is auto-corrected to {}.".format(model_path, new_path))
-        model_path = new_path
-    assert tfv1.gfile.Exists(model_path) or tfv1.gfile.Exists(model_path + '.index'), model_path
-    return model_path
+            "Checkpoint path {} is auto-corrected to {}.".format(path, new_path))
+        path = new_path
+    assert tfv1.gfile.Exists(path) or tfv1.gfile.Exists(path + '.index'), path
+    return path

-def load_chkpt_vars(model_path):
+def load_chkpt_vars(path):
     """ Load all variables from a checkpoint to a dict.

     Args:
-        model_path(str): path to a checkpoint.
+        path(str): path to a checkpoint.

     Returns:
         dict: a name:value dict
     """
-    model_path = get_checkpoint_path(model_path)
-    reader = tfv1.train.NewCheckpointReader(model_path)
+    path = get_checkpoint_path(path)
+    reader = tfv1.train.NewCheckpointReader(path)
     var_names = reader.get_variable_to_shape_map().keys()
     result = {}
     for n in var_names:
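Note: the rename from `model_path` to `path` is mechanical, but the behavior it carries is worth spelling out. Behaviors readable off the function body above (all paths hypothetical):

    from tensorpack.tfutils.varmanip import get_checkpoint_path, load_chkpt_vars

    # get_checkpoint_path('model-1000.index')               -> './model-1000'
    # get_checkpoint_path('model-1000.data-00000-of-00001') -> './model-1000'
    # get_checkpoint_path('train_log/checkpoint')           -> latest checkpoint named in that file
    # params = load_chkpt_vars('train_log/model-1000')      # dict: variable name -> value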
tensorpack/utils/argtools.py

@@ -13,7 +13,7 @@ else:
 import functools

 __all__ = ['map_arg', 'memoized', 'memoized_method', 'graph_memoized', 'shape2d', 'shape4d',
-           'memoized_ignoreargs', 'log_once', 'call_only_once']
+           'memoized_ignoreargs', 'log_once']

 def map_arg(**maps):
tensorpack/utils/concurrency.py

@@ -26,8 +26,7 @@ else:
 __all__ = ['StoppableThread', 'LoopThread', 'ShareSessionThread',
            'ensure_proc_terminate',
-           'OrderedResultGatherProc', 'OrderedContainer', 'DIE',
-           'mask_sigint', 'start_proc_mask_signal']
+           'start_proc_mask_signal']

 class StoppableThread(threading.Thread):
tensorpack/utils/loadcaffe.py

@@ -97,6 +97,7 @@ def load_caffe(model_desc, model_file):
     """
     Load a caffe model. You must be able to ``import caffe`` to use this
     function.
+
     Args:
         model_desc (str): path to caffe model description file (.prototxt).
         model_file (str): path to caffe model parameter file (.caffemodel).

@@ -116,6 +117,7 @@ def load_caffe(model_desc, model_file):
 def get_caffe_pb():
     """
     Get caffe protobuf.
+
     Returns:
         The imported caffe protobuf module.
     """
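Note: both hunks appear to add only a blank line before the `Args:`/`Returns:` sections, which helps Sphinx-style docstring parsers render those sections. Usage per the docstring (hypothetical paths, and `import caffe` must work):

    from tensorpack.utils.loadcaffe import load_caffe

    params = load_caffe('resnet.prototxt', 'resnet.caffemodel')
    print(sorted(params)[:5])   # parameter names found in the .caffemodel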
tensorpack/utils/serialize.py

@@ -13,7 +13,7 @@ from .develop import create_dummy_func
 msgpack_numpy.patch()
 assert msgpack.version >= (0, 5, 2)

-__all__ = ['loads', 'dumps', 'NonPicklableWrapper']
+__all__ = ['loads', 'dumps']

 MAX_MSGPACK_LEN = 1000000000
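Note: `NonPicklableWrapper` leaves the public API; `loads`/`dumps` stay. A round-trip sketch, assuming the default msgpack backend with `msgpack_numpy.patch()` applied as above:

    import numpy as np
    from tensorpack.utils.serialize import dumps, loads

    buf = dumps({'step': 3, 'weights': np.arange(4, dtype=np.float32)})
    obj = loads(buf)
    assert obj['step'] == 3 and obj['weights'].dtype == np.float32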
tensorpack/utils/timer.py

@@ -15,8 +15,7 @@ if six.PY3:
     from time import perf_counter as timer  # noqa

-__all__ = ['total_timer', 'timed_operation',
-           'print_total_timer', 'IterSpeedCounter', 'Timer']
+__all__ = ['timed_operation', 'IterSpeedCounter', 'Timer']

 @contextmanager

@@ -55,7 +54,7 @@ _TOTAL_TIMER_DATA = defaultdict(StatCounter)
 @contextmanager
 def total_timer(msg):
-    """ A context which add the time spent inside to TotalTimer. """
+    """ A context which add the time spent inside to the global TotalTimer. """
     start = timer()
     yield
     t = timer() - start

@@ -64,7 +63,7 @@ def total_timer(msg):
 def print_total_timer():
     """
-    Print the content of the TotalTimer, if it's not empty. This function will automatically get
+    Print the content of the global TotalTimer, if it's not empty. This function will automatically get
     called when program exits.
     """
     if len(_TOTAL_TIMER_DATA) == 0:
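Note: `total_timer` and `print_total_timer` drop out of `__all__` but remain defined, so existing explicit imports keep working. A small usage sketch of the pair this hunk de-publicizes:

    from tensorpack.utils.timer import total_timer, print_total_timer

    for _ in range(3):
        with total_timer('square-sum'):          # accumulates into the global TotalTimer
            sum(i * i for i in range(10000))

    print_total_timer()   # prints the accumulated stats, if any (also runs at exit)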