Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
57fb68fa
Commit
57fb68fa
authored
Nov 20, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
unordered datasetpredictor & more tqdm
parent
a2f4f439
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
35 additions
and
23 deletions
+35
-23
tensorpack/callbacks/inference.py
tensorpack/callbacks/inference.py
+2
-3
tensorpack/dataflow/common.py
tensorpack/dataflow/common.py
+2
-3
tensorpack/dataflow/format.py
tensorpack/dataflow/format.py
+2
-3
tensorpack/predict/dataset.py
tensorpack/predict/dataset.py
+24
-14
tensorpack/utils/utils.py
tensorpack/utils/utils.py
+5
-0
No files found.
tensorpack/callbacks/inference.py
View file @
57fb68fa
...
...
@@ -4,14 +4,13 @@
import
tensorflow
as
tf
import
numpy
as
np
from
tqdm
import
tqdm
from
abc
import
ABCMeta
,
abstractmethod
from
collections
import
namedtuple
import
six
from
six.moves
import
zip
,
map
from
..dataflow
import
DataFlow
from
..utils
import
get_tqdm
_kwargs
,
logger
,
execute_only_once
from
..utils
import
get_tqdm
,
logger
,
execute_only_once
from
..utils.stat
import
RatioCounter
,
BinaryStatistics
from
..tfutils
import
get_op_tensor_name
,
get_op_var_name
from
.base
import
Callback
...
...
@@ -124,7 +123,7 @@ class InferenceRunner(Callback):
sess
=
tf
.
get_default_session
()
self
.
ds
.
reset_state
()
with
tqdm
(
total
=
self
.
ds
.
size
(),
**
get_tqdm_kwargs
())
as
pbar
:
with
get_tqdm
(
total
=
self
.
ds
.
size
())
as
pbar
:
for
dp
in
self
.
ds
.
get_data
():
outputs
=
self
.
pred_func
(
dp
)
for
inf
,
tensormap
in
zip
(
self
.
infs
,
self
.
inf_to_tensors
):
...
...
tensorpack/dataflow/common.py
View file @
57fb68fa
...
...
@@ -8,7 +8,7 @@ import numpy as np
from
collections
import
deque
,
defaultdict
from
six.moves
import
range
,
map
from
.base
import
DataFlow
,
ProxyDataFlow
,
RNGDataFlow
from
..utils
import
*
from
..utils
import
logger
,
get_tqdm
__all__
=
[
'BatchData'
,
'FixedSizeData'
,
'MapData'
,
'RepeatedData'
,
'MapDataComponent'
,
'RandomChooseData'
,
...
...
@@ -21,8 +21,7 @@ class TestDataSpeed(ProxyDataFlow):
self
.
test_size
=
size
def
get_data
(
self
):
from
tqdm
import
tqdm
with
tqdm
(
range
(
self
.
test_size
),
**
get_tqdm_kwargs
())
as
pbar
:
with
get_tqdm
(
total
=
range
(
self
.
test_size
))
as
pbar
:
for
dp
in
self
.
ds
.
get_data
():
pbar
.
update
()
for
dp
in
self
.
ds
.
get_data
():
...
...
tensorpack/dataflow/format.py
View file @
57fb68fa
...
...
@@ -3,10 +3,9 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import
numpy
as
np
from
tqdm
import
tqdm
from
six.moves
import
range
from
..utils
import
logger
,
get_rng
,
get_tqdm
_kwargs
from
..utils
import
logger
,
get_rng
,
get_tqdm
from
..utils.timer
import
timed_operation
from
..utils.loadcaffe
import
get_caffe_pb
from
.base
import
RNGDataFlow
...
...
@@ -82,7 +81,7 @@ class LMDBData(RNGDataFlow):
if
not
self
.
keys
:
self
.
keys
=
[]
with
timed_operation
(
"Loading LMDB keys ..."
,
log_start
=
True
),
\
tqdm
(
get_tqdm_kwargs
(
total
=
self
.
_size
)
)
as
pbar
:
get_tqdm
(
total
=
self
.
_size
)
as
pbar
:
for
k
in
self
.
_txn
.
cursor
():
if
k
!=
'__keys__'
:
self
.
keys
.
append
(
k
)
...
...
tensorpack/predict/dataset.py
View file @
57fb68fa
...
...
@@ -4,7 +4,6 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from
six.moves
import
range
,
zip
from
tqdm
import
tqdm
from
abc
import
ABCMeta
,
abstractmethod
import
multiprocessing
import
os
...
...
@@ -12,7 +11,7 @@ import os
from
..dataflow
import
DataFlow
,
BatchData
from
..dataflow.dftools
import
dataflow_to_process_queue
from
..utils.concurrency
import
ensure_proc_terminate
,
OrderedResultGatherProc
,
DIE
from
..utils
import
logger
from
..utils
import
logger
,
get_tqdm
from
..utils.gpu
import
change_gpu
from
.concurrency
import
MultiProcessQueuePredictWorker
...
...
@@ -60,7 +59,7 @@ class SimpleDatasetPredictor(DatasetPredictorBase):
sz
=
self
.
dataset
.
size
()
except
NotImplementedError
:
sz
=
0
with
tqdm
(
total
=
sz
,
disable
=
(
sz
==
0
))
as
pbar
:
with
get_
tqdm
(
total
=
sz
,
disable
=
(
sz
==
0
))
as
pbar
:
for
dp
in
self
.
dataset
.
get_data
():
res
=
self
.
predictor
(
dp
)
yield
res
...
...
@@ -68,13 +67,15 @@ class SimpleDatasetPredictor(DatasetPredictorBase):
# TODO allow unordered
class
MultiProcessDatasetPredictor
(
DatasetPredictorBase
):
def
__init__
(
self
,
config
,
dataset
,
nr_proc
,
use_gpu
=
True
):
def
__init__
(
self
,
config
,
dataset
,
nr_proc
,
use_gpu
=
True
,
ordered
=
True
):
"""
Run prediction in multiprocesses, on either CPU or GPU. Mix mode not supported.
:param nr_proc: number of processes to use
:param use_gpu: use GPU or CPU.
If GPU, then nr_proc cannot be more than what's in CUDA_VISIBLE_DEVICES
:param ordered: produce results with the original order of the
dataflow. a bit slower.
"""
if
config
.
return_input
:
logger
.
warn
(
"Using the option `return_input` in MultiProcessDatasetPredictor might be slow"
)
...
...
@@ -82,10 +83,11 @@ class MultiProcessDatasetPredictor(DatasetPredictorBase):
super
(
MultiProcessDatasetPredictor
,
self
)
.
__init__
(
config
,
dataset
)
self
.
nr_proc
=
nr_proc
self
.
ordered
=
ordered
self
.
inqueue
,
self
.
inqueue_proc
=
dataflow_to_process_queue
(
self
.
dataset
,
nr_proc
*
2
,
self
.
nr_proc
)
self
.
outqueue
=
multiprocessing
.
Queue
()
self
.
dataset
,
nr_proc
*
2
,
self
.
nr_proc
)
# put (idx, dp) to inqueue
if
use_gpu
:
try
:
gpus
=
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
.
split
(
','
)
...
...
@@ -97,13 +99,13 @@ class MultiProcessDatasetPredictor(DatasetPredictorBase):
gpus
=
list
(
range
(
self
.
nr_proc
))
else
:
gpus
=
[
'-1'
]
*
self
.
nr_proc
# worker produces (idx, result) to outqueue
self
.
outqueue
=
multiprocessing
.
Queue
()
self
.
workers
=
[
MultiProcessQueuePredictWorker
(
i
,
self
.
inqueue
,
self
.
outqueue
,
self
.
config
)
for
i
in
range
(
self
.
nr_proc
)]
self
.
result_queue
=
OrderedResultGatherProc
(
self
.
outqueue
,
nr_producer
=
self
.
nr_proc
)
# s
etup all the proc
s
# s
tart inqueue and worker
s
self
.
inqueue_proc
.
start
()
for
p
,
gpuid
in
zip
(
self
.
workers
,
gpus
):
if
gpuid
==
'-1'
:
...
...
@@ -112,15 +114,22 @@ class MultiProcessDatasetPredictor(DatasetPredictorBase):
logger
.
info
(
"Worker {} uses GPU {}"
.
format
(
p
.
idx
,
gpuid
))
with
change_gpu
(
gpuid
):
p
.
start
()
if
ordered
:
self
.
result_queue
=
OrderedResultGatherProc
(
self
.
outqueue
,
nr_producer
=
self
.
nr_proc
)
self
.
result_queue
.
start
()
ensure_proc_terminate
(
self
.
workers
+
[
self
.
result_queue
,
self
.
inqueue_proc
])
ensure_proc_terminate
(
self
.
result_queue
)
else
:
self
.
result_queue
=
self
.
outqueue
ensure_proc_terminate
(
self
.
workers
+
[
self
.
inqueue_proc
])
def
get_result
(
self
):
try
:
sz
=
self
.
dataset
.
size
()
except
NotImplementedError
:
sz
=
0
with
tqdm
(
total
=
sz
,
disable
=
(
sz
==
0
))
as
pbar
:
with
get_
tqdm
(
total
=
sz
,
disable
=
(
sz
==
0
))
as
pbar
:
die_cnt
=
0
while
True
:
res
=
self
.
result_queue
.
get
()
...
...
@@ -133,6 +142,7 @@ class MultiProcessDatasetPredictor(DatasetPredictorBase):
break
self
.
inqueue_proc
.
join
()
self
.
inqueue_proc
.
terminate
()
if
self
.
ordered
:
# if ordered, than result_queue is a Process
self
.
result_queue
.
join
()
self
.
result_queue
.
terminate
()
for
p
in
self
.
workers
:
...
...
tensorpack/utils/utils.py
View file @
57fb68fa
...
...
@@ -6,6 +6,7 @@ import os, sys
from
contextlib
import
contextmanager
import
inspect
from
datetime
import
datetime
from
tqdm
import
tqdm
import
time
import
numpy
as
np
...
...
@@ -13,6 +14,7 @@ __all__ = ['change_env',
'get_rng'
,
'get_dataset_path'
,
'get_tqdm_kwargs'
,
'get_tqdm'
,
'execute_only_once'
]
...
...
@@ -73,3 +75,6 @@ def get_tqdm_kwargs(**kwargs):
default
[
'mininterval'
]
=
60
default
.
update
(
kwargs
)
return
default
def
get_tqdm
(
**
kwargs
):
return
tqdm
(
**
get_tqdm_kwargs
(
**
kwargs
))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment