Shashank Suhas / seminar-breakout · Commits

Commit c9107add
authored Apr 22, 2016 by Yuxin Wu
parent 1dcc0e72

    multigpu predictor
Showing 5 changed files with 163 additions and 17 deletions (+163 −17)
tensorpack/callbacks/dump.py       +1  −0
tensorpack/dataflow/prefetch.py    +4  −3
tensorpack/predict.py              +89 −9
tensorpack/train/trainer.py        +2  −0
tensorpack/utils/concurrency.py    +67 −5
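Taken together, the changes add a multi-process, multi-GPU prediction path: a producer process feeds indexed datapoints into a queue, one PredictWorker per GPU runs the prediction function, and an order-restoring gatherer re-serializes the results. A minimal usage sketch based only on this diff; MyModel, my_dataflow and process() are hypothetical placeholders, and the other required PredictConfig fields (session configuration and initialization) are elided:

    from tensorpack.predict import PredictConfig, DatasetPredictor

    config = PredictConfig(
        model=MyModel(),              # hypothetical ModelDesc subclass
        output_var_names=['output'],  # any computable tensors in the graph
        nr_gpu=2)                     # new in this commit; defaults to 1
    pred = DatasetPredictor(config, my_dataflow)  # my_dataflow: a DataFlow instance
    for r in pred.get_result():       # yields PredictResult(input, output) in input order
        process(r.output)             # hypothetical consumer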
tensorpack/callbacks/dump.py

@@ -40,6 +40,7 @@ class DumpParamAsImage(Callback):
         self.clip = clip
 
     def _before_train(self):
+        # TODO might not work for multiGPU?
         self.var = self.graph.get_tensor_by_name(self.var_name)
 
     def _trigger_epoch(self):
tensorpack/dataflow/prefetch.py

@@ -6,7 +6,7 @@ import multiprocessing
 from six.moves import range
 
 from .base import ProxyDataFlow
-from ..utils.concurrency import ensure_procs_terminate
+from ..utils.concurrency import ensure_proc_terminate
 from ..utils import logger
 
 __all__ = ['PrefetchData']
@@ -36,7 +36,8 @@ class PrefetchData(ProxyDataFlow):
         """
         :param ds: a `DataFlow` instance.
         :param nr_prefetch: size of the queue to hold prefetched datapoints.
-        :param nr_proc: number of processes to use.
+        :param nr_proc: number of processes to use. When larger than 1, order
+            of data points will be random.
         """
         super(PrefetchData, self).__init__(ds)
         self._size = self.size()
@@ -45,7 +46,7 @@ class PrefetchData(ProxyDataFlow):
         self.queue = multiprocessing.Queue(self.nr_prefetch)
         self.procs = [PrefetchProcess(self.ds, self.queue)
                       for _ in range(self.nr_proc)]
-        ensure_procs_terminate(self.procs)
+        ensure_proc_terminate(self.procs)
         for x in self.procs:
             x.start()
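The rename above mirrors the change in tensorpack/utils/concurrency.py further down: ensure_proc_terminate now accepts either a single process or a list and recurses, so the plural ensure_procs_terminate helper is removed. A small sketch of the new call pattern (the worker body is a trivial stand-in, not part of the commit):

    import multiprocessing
    from tensorpack.utils.concurrency import ensure_proc_terminate

    def worker():                       # stand-in for a real prefetch process
        pass

    if __name__ == '__main__':
        procs = [multiprocessing.Process(target=worker) for _ in range(4)]
        ensure_proc_terminate(procs)    # one call now handles the whole list
        for p in procs:
            p.start()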
tensorpack/predict.py

@@ -7,9 +7,13 @@ from itertools import count
 import argparse
 from collections import namedtuple
 import numpy as np
+import bisect
 from tqdm import tqdm
 from six.moves import zip
+import multiprocessing
+
+from .utils.concurrency import ensure_proc_terminate, OrderedResultGatherProc, DIE
 from .tfutils import *
 from .utils import logger
 from .tfutils.modelutils import describe_model
@@ -50,6 +54,7 @@ class PredictConfig(object):
         :param output_var_names: a list of names of the output variables to predict, the
             variables can be any computable tensor in the graph.
             Predict specific output might not require all input variables.
+        :param nr_gpu: default to 1. Use CUDA_VISIBLE_DEVICES to control which GPU to use specifically.
         """
         def assert_type(v, tp):
             assert isinstance(v, tp), v.__class__
@@ -59,6 +64,7 @@ class PredictConfig(object):
         self.model = kwargs.pop('model')
         self.input_data_mapping = kwargs.pop('input_data_mapping', None)
         self.output_var_names = kwargs.pop('output_var_names')
+        self.nr_gpu = kwargs.pop('nr_gpu', 1)
         assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))
 
 def get_predict_func(config):
@@ -81,8 +87,6 @@ def get_predict_func(config):
     output_vars = [tf.get_default_graph().get_tensor_by_name(get_op_var_name(n)[1])
                    for n in output_var_names]
-    describe_model()
-
     sess = tf.Session(config=config.session_config)
     config.session_init.init(sess)
@@ -101,27 +105,103 @@ def get_predict_func(config):
 
 PredictResult = namedtuple('PredictResult', ['input', 'output'])
 
-# TODO mutligpu predictor
+class PredictWorker(multiprocessing.Process):
+    def __init__(self, idx, gpuid, inqueue, outqueue, config):
+        super(PredictWorker, self).__init__()
+        self.idx = idx
+        self.gpuid = gpuid
+        self.inqueue = inqueue
+        self.outqueue = outqueue
+        self.config = config
+
+    def run(self):
+        os.environ['CUDA_VISIBLE_DEVICES'] = self.gpuid
+        G = tf.Graph()  # build a graph for each process, because they don't need to share anything
+        with G.as_default(), tf.device('/gpu:{}'.format(self.idx)):
+            self.func = get_predict_func(self.config)
+            if self.idx == 0:
+                describe_model()
+            while True:
+                tid, dp = self.inqueue.get()
+                if tid == DIE:
+                    self.outqueue.put((DIE, None))
+                    return
+                else:
+                    res = PredictResult(dp, self.func(dp))
+                    self.outqueue.put((tid, res))
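A note on the class above: each PredictWorker pins itself to one physical GPU by setting CUDA_VISIBLE_DEVICES inside run() and builds a private graph, so the workers share no TensorFlow state; this is also why describe_model() was dropped from get_predict_func() in the earlier hunk and is now printed by worker 0 only. A standalone sketch of that pinning pattern, with illustrative names that are not part of the commit:

    import os
    import multiprocessing

    def gpu_worker(gpuid):
        # Must be set before this process creates its first CUDA context
        # (i.e. before any tf.Session is constructed here).
        os.environ['CUDA_VISIBLE_DEVICES'] = gpuid
        import tensorflow as tf          # late import: nothing has touched CUDA yet
        with tf.Graph().as_default():    # a private graph per process
            sess = tf.Session()          # this session sees exactly one GPU
            print('worker on GPU %s started' % gpuid)

    if __name__ == '__main__':
        for gpuid in ['0', '1']:         # illustrative device ids
            multiprocessing.Process(target=gpu_worker, args=(gpuid,)).start()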
+def DFtoQueue(ds, size, nr_consumer):
+    q = multiprocessing.Queue(size)
+    class EnqueProc(multiprocessing.Process):
+        def __init__(self, ds, q, nr_consumer):
+            super(EnqueProc, self).__init__()
+            self.ds = ds
+            self.q = q
+
+        def run(self):
+            for idx, dp in enumerate(self.ds.get_data()):
+                self.q.put((idx, dp))
+            print "Enqueue ends"
+            for _ in range(nr_consumer):
+                self.q.put((DIE, None))
+
+    proc = EnqueProc(ds, q, nr_consumer)
+    return q, proc
 
 class DatasetPredictor(object):
     """
     Run the predict_config on a given `DataFlow`.
     """
-    def __init__(self, predict_config, dataset):
+    def __init__(self, config, dataset):
         """
-        :param predict_config: a `PredictConfig` instance.
+        :param config: a `PredictConfig` instance.
         :param dataset: a `DataFlow` instance.
         """
         assert isinstance(dataset, DataFlow)
         self.ds = dataset
-        self.predict_func = get_predict_func(predict_config)
+        self.nr_gpu = config.nr_gpu
+        if self.nr_gpu > 1:
+            self.inqueue, self.inqueue_proc = DFtoQueue(self.ds, 10, self.nr_gpu)
+            self.outqueue = multiprocessing.Queue()
+            try:
+                gpus = os.environ['CUDA_VISIBLE_DEVICES'].split(',')
+            except KeyError:
+                gpus = range(self.nr_gpu)
+            self.workers = [PredictWorker(i, gpus[i], self.inqueue, self.outqueue, config)
+                            for i in range(self.nr_gpu)]
+            self.result_queue = OrderedResultGatherProc(self.outqueue)
+
+            # run the procs
+            self.inqueue_proc.start()
+            for p in self.workers: p.start()
+            self.result_queue.start()
+
+            ensure_proc_terminate(self.workers)
+            ensure_proc_terminate([self.result_queue, self.inqueue_proc])
+        else:
+            self.func = get_predict_func(config)
 
     def get_result(self):
         """ A generator to produce prediction for each data"""
         with tqdm(total=self.ds.size()) as pbar:
-            for dp in self.ds.get_data():
-                yield PredictResult(dp, self.predict_func(dp))
-                pbar.update()
+            if self.nr_gpu == 1:
+                for dp in self.ds.get_data():
+                    yield PredictResult(dp, self.func(dp))
+                    pbar.update()
+            else:
+                while True:
+                    res = self.result_queue.get()
+                    if res[0] != DIE:
+                        yield res[1]
+                    else:
+                        break
+                    pbar.update()
+                self.inqueue_proc.join()
+                self.inqueue_proc.terminate()
+                for p in self.workers:
+                    p.join(); p.terminate()
 
     def get_all_result(self):
         """
tensorpack/train/trainer.py

@@ -116,9 +116,11 @@ class QueueInputTrainer(Trainer):
 
         # get gradients to update:
         if self.config.nr_tower > 1:
             logger.info("Training a model of {} tower".format(self.config.nr_tower))
+
+            # to avoid repeated summary from each device
             coll_keys = [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_VARS_KEY]
             kept_summaries = {}
             grad_list = []
             for i in range(self.config.nr_tower):
                 with tf.device('/gpu:{}'.format(i)), \
tensorpack/utils/concurrency.py

 # -*- coding: UTF-8 -*-
 # File: concurrency.py
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>
 # Credit belongs to Xinyu Zhou
 
 import threading
-import multiprocessing
+import multiprocessing, multiprocess
 from contextlib import contextmanager
 import tensorflow as tf
 import atexit
@@ -12,6 +13,9 @@ from six.moves import zip
 from .naming import *
 
+__all__ = ['StoppableThread', 'ensure_proc_terminate',
+           'OrderedResultGatherProc', 'OrderedContainer', 'DIE']
 
 class StoppableThread(threading.Thread):
     def __init__(self):
         super(StoppableThread, self).__init__()
@@ -24,7 +28,16 @@ class StoppableThread(threading.Thread):
         return self._stop.isSet()
 
+class DIE(object):
+    pass
+
 def ensure_proc_terminate(proc):
+    if isinstance(proc, list):
+        for p in proc:
+            ensure_proc_terminate(p)
+        return
+
     def stop_proc_by_weak_ref(ref):
         proc = ref()
         if proc is None:
@@ -34,9 +47,58 @@ def ensure_proc_terminate(proc):
             proc.terminate()
             proc.join()
 
-    assert isinstance(proc, multiprocessing.Process)
+    assert isinstance(proc, (multiprocessing.Process, multiprocess.Process))
     atexit.register(stop_proc_by_weak_ref, weakref.ref(proc))
 
-def ensure_procs_terminate(procs):
-    for p in procs:
-        ensure_proc_terminate(p)
-
+class OrderedContainer(object):
+    def __init__(self, start=0):
+        self.ranks = []
+        self.data = []
+        self.wait_for = start
+
+    def put(self, rank, val):
+        idx = bisect.bisect(self.ranks, rank)
+        self.ranks.insert(idx, rank)
+        self.data.insert(idx, val)
+
+    def has_next(self):
+        if len(self.ranks) == 0:
+            return False
+        return self.ranks[0] == self.wait_for
+
+    def get(self):
+        assert self.has_next()
+        ret = self.data[0]
+        rank = self.ranks[0]
+        del self.ranks[0]
+        del self.data[0]
+        self.wait_for += 1
+        return rank, ret
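OrderedContainer is the reordering buffer behind the gatherer that follows: put() keeps (rank, value) pairs sorted with bisect, and get() only releases a value once its rank equals the running wait_for counter. A quick sketch of the contract, assuming the module is importable:

    from tensorpack.utils.concurrency import OrderedContainer

    c = OrderedContainer(start=0)
    c.put(2, 'c')
    c.put(0, 'a')
    print(c.has_next())   # True: rank 0 is the one being waited for
    print(c.get())        # (0, 'a')
    print(c.has_next())   # False: rank 1 has not arrived yet
    c.put(1, 'b')
    print(c.get())        # (1, 'b')
    print(c.get())        # (2, 'c')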
+class OrderedResultGatherProc(multiprocessing.Process):
+    def __init__(self, data_queue, start=0):
+        super(self.__class__, self).__init__()
+        self.data_queue = data_queue
+        self.ordered_container = OrderedContainer(start=start)
+        self.result_queue = multiprocessing.Queue()
+
+    def run(self):
+        try:
+            while True:
+                task_id, data = self.data_queue.get()
+                if task_id == DIE:
+                    self.result_queue.put((task_id, data))
+                else:
+                    self.ordered_container.put(task_id, data)
+                    while self.ordered_container.has_next():
+                        self.result_queue.put(self.ordered_container.get())
+        except Exception as e:
+            import traceback
+            traceback.print_exc()
+            raise e
+
+    def get(self):
+        return self.result_queue.get()
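Putting the pieces together, the new classes implement an order-restoring gather: a producer tags each work item with an index, workers return (index, result) pairs in arbitrary order, and OrderedResultGatherProc re-emits them sequentially. A runnable toy of that flow, assuming tensorpack at this commit is importable; terminate() at the end is for brevity, where the real pipeline sends the DIE sentinel instead:

    import multiprocessing
    from tensorpack.utils.concurrency import OrderedResultGatherProc

    if __name__ == '__main__':
        q = multiprocessing.Queue()
        gather = OrderedResultGatherProc(q)
        gather.start()
        for tid in [3, 1, 0, 2]:              # results arriving out of order
            q.put((tid, 'result-%d' % tid))
        for _ in range(4):
            print(gather.get())               # (0, 'result-0') ... (3, 'result-3')
        gather.terminate()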