Shashank Suhas / seminar-breakout / Commits

Commit c9107add, authored Apr 22, 2016 by Yuxin Wu

    multigpu predictor

Parent: 1dcc0e72

Showing 5 changed files, with 163 additions and 17 deletions (+163 / -17).
tensorpack/callbacks/dump.py        +1   -0
tensorpack/dataflow/prefetch.py     +4   -3
tensorpack/predict.py               +89  -9
tensorpack/train/trainer.py         +2   -0
tensorpack/utils/concurrency.py     +67  -5
tensorpack/callbacks/dump.py

@@ -40,6 +40,7 @@ class DumpParamAsImage(Callback):
         self.clip = clip
 
     def _before_train(self):
+        # TODO might not work for multiGPU?
         self.var = self.graph.get_tensor_by_name(self.var_name)
 
     def _trigger_epoch(self):
tensorpack/dataflow/prefetch.py

@@ -6,7 +6,7 @@ import multiprocessing
 from six.moves import range
 
 from .base import ProxyDataFlow
-from ..utils.concurrency import ensure_procs_terminate
+from ..utils.concurrency import ensure_proc_terminate
 from ..utils import logger
 
 __all__ = ['PrefetchData']
@@ -36,7 +36,8 @@ class PrefetchData(ProxyDataFlow):
         """
         :param ds: a `DataFlow` instance.
         :param nr_prefetch: size of the queue to hold prefetched datapoints.
-        :param nr_proc: number of processes to use.
+        :param nr_proc: number of processes to use. When larger than 1, order
+            of data points will be random.
         """
         super(PrefetchData, self).__init__(ds)
         self._size = self.size()
@@ -45,7 +46,7 @@ class PrefetchData(ProxyDataFlow):
         self.queue = multiprocessing.Queue(self.nr_prefetch)
         self.procs = [PrefetchProcess(self.ds, self.queue)
                       for _ in range(self.nr_proc)]
-        ensure_procs_terminate(self.procs)
+        ensure_proc_terminate(self.procs)
         for x in self.procs:
            x.start()
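For context on how the renamed ensure_proc_terminate and the multi-process PrefetchData are meant to be driven, here is a minimal usage sketch. It is an illustration, not part of the commit; the Mnist dataflow is assumed as a stand-in data source.

# Hypothetical usage sketch (not from this commit). Wrapping a DataFlow in
# PrefetchData pulls datapoints in background processes; per the new
# docstring, nr_proc > 1 makes the datapoint order nondeterministic.
from tensorpack.dataflow import PrefetchData
from tensorpack.dataflow.dataset import Mnist   # stand-in source, an assumption

ds = Mnist('train')
ds = PrefetchData(ds, nr_prefetch=256, nr_proc=4)   # queue of 256, 4 workers
for dp in ds.get_data():
    pass   # consume prefetched datapoints; order is randomized across workers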
tensorpack/predict.py

@@ -7,9 +7,13 @@ from itertools import count
 import argparse
 from collections import namedtuple
 import numpy as np
+import bisect
 from tqdm import tqdm
 from six.moves import zip
+import multiprocessing
 
+from .utils.concurrency import ensure_proc_terminate, OrderedResultGatherProc, DIE
 from .tfutils import *
 from .utils import logger
 from .tfutils.modelutils import describe_model
@@ -50,6 +54,7 @@ class PredictConfig(object):
         :param output_var_names: a list of names of the output variables to predict, the
             variables can be any computable tensor in the graph.
             Predict specific output might not require all input variables.
+        :param nr_gpu: default to 1. Use CUDA_VISIBLE_DEVICES to control which GPU to use specifically.
         """
         def assert_type(v, tp):
             assert isinstance(v, tp), v.__class__
@@ -59,6 +64,7 @@ class PredictConfig(object):
         self.model = kwargs.pop('model')
         self.input_data_mapping = kwargs.pop('input_data_mapping', None)
         self.output_var_names = kwargs.pop('output_var_names')
+        self.nr_gpu = kwargs.pop('nr_gpu', 1)
         assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))
 
 def get_predict_func(config):
@@ -81,8 +87,6 @@ def get_predict_func(config):
     output_vars = [tf.get_default_graph().get_tensor_by_name(get_op_var_name(n)[1])
                    for n in output_var_names]
-    describe_model()
-
     sess = tf.Session(config=config.session_config)
     config.session_init.init(sess)
@@ -101,27 +105,103 @@ def get_predict_func(config):
 PredictResult = namedtuple('PredictResult', ['input', 'output'])
 
-# TODO mutligpu predictor
+class PredictWorker(multiprocessing.Process):
+    def __init__(self, idx, gpuid, inqueue, outqueue, config):
+        super(PredictWorker, self).__init__()
+        self.idx = idx
+        self.gpuid = gpuid
+        self.inqueue = inqueue
+        self.outqueue = outqueue
+        self.config = config
+
+    def run(self):
+        os.environ['CUDA_VISIBLE_DEVICES'] = self.gpuid
+        G = tf.Graph()  # build a graph for each process, because they don't need to share anything
+        with G.as_default(), tf.device('/gpu:{}'.format(self.idx)):
+            self.func = get_predict_func(self.config)
+            if self.idx == 0:
+                describe_model()
+        while True:
+            tid, dp = self.inqueue.get()
+            if tid == DIE:
+                self.outqueue.put((DIE, None))
+                return
+            else:
+                res = PredictResult(dp, self.func(dp))
+                self.outqueue.put((tid, res))
+
+def DFtoQueue(ds, size, nr_consumer):
+    q = multiprocessing.Queue(size)
+    class EnqueProc(multiprocessing.Process):
+        def __init__(self, ds, q, nr_consumer):
+            super(EnqueProc, self).__init__()
+            self.ds = ds
+            self.q = q
+
+        def run(self):
+            for idx, dp in enumerate(self.ds.get_data()):
+                self.q.put((idx, dp))
+            print "Enqueue ends"
+            for _ in range(nr_consumer):
+                self.q.put((DIE, None))
+
+    proc = EnqueProc(ds, q, nr_consumer)
+    return q, proc
 
 class DatasetPredictor(object):
     """
     Run the predict_config on a given `DataFlow`.
     """
-    def __init__(self, predict_config, dataset):
+    def __init__(self, config, dataset):
         """
-        :param predict_config: a `PredictConfig` instance.
+        :param config: a `PredictConfig` instance.
         :param dataset: a `DataFlow` instance.
         """
         assert isinstance(dataset, DataFlow)
         self.ds = dataset
-        self.predict_func = get_predict_func(predict_config)
+        self.nr_gpu = config.nr_gpu
+        if self.nr_gpu > 1:
+            self.inqueue, self.inqueue_proc = DFtoQueue(self.ds, 10, self.nr_gpu)
+            self.outqueue = multiprocessing.Queue()
+            try:
+                gpus = os.environ['CUDA_VISIBLE_DEVICES'].split(',')
+            except KeyError:
+                gpus = range(self.nr_gpu)
+            self.workers = [PredictWorker(i, gpus[i], self.inqueue, self.outqueue, config)
+                            for i in range(self.nr_gpu)]
+            self.result_queue = OrderedResultGatherProc(self.outqueue)
+
+            # run the procs
+            self.inqueue_proc.start()
+            for p in self.workers: p.start()
+            self.result_queue.start()
+
+            ensure_proc_terminate(self.workers)
+            ensure_proc_terminate([self.result_queue, self.inqueue_proc])
+        else:
+            self.func = get_predict_func(config)
 
     def get_result(self):
         """ A generator to produce prediction for each data"""
         with tqdm(total=self.ds.size()) as pbar:
-            for dp in self.ds.get_data():
-                yield PredictResult(dp, self.predict_func(dp))
-                pbar.update()
+            if self.nr_gpu == 1:
+                for dp in self.ds.get_data():
+                    yield PredictResult(dp, self.func(dp))
+                    pbar.update()
+            else:
+                while True:
+                    res = self.result_queue.get()
+                    if res[0] != DIE:
+                        yield res[1]
+                    else:
+                        break
+                    pbar.update()
+                self.inqueue_proc.join()
+                self.inqueue_proc.terminate()
+                for p in self.workers:
+                    p.join(); p.terminate()
 
     def get_all_result(self):
         """
...
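Taken together, the new multi-GPU path works like this: DFtoQueue feeds (index, datapoint) pairs into a shared input queue, one PredictWorker per GPU pulls from it and pushes (index, PredictResult) pairs out, and OrderedResultGatherProc reassembles the results in input order. A hedged driver sketch follows; MyModel, my_dataflow, and consume are placeholders not defined in this commit, and other required PredictConfig fields (e.g. session_init) are elided.

# Hypothetical driver, assuming a ModelDesc subclass and a DataFlow exist.
config = PredictConfig(
    model=MyModel(),                 # placeholder model
    output_var_names=['prob:0'],     # any computable tensors in the graph
    nr_gpu=2)                        # new in this commit; defaults to 1
pred = DatasetPredictor(config, my_dataflow)
for result in pred.get_result():     # yields PredictResult(input, output)
    consume(result.output)           # placeholder consumer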
tensorpack/train/trainer.py

@@ -116,9 +116,11 @@ class QueueInputTrainer(Trainer):
         # get gradients to update:
         if self.config.nr_tower > 1:
             logger.info("Training a model of {} tower".format(self.config.nr_tower))
 
             # to avoid repeated summary from each device
             coll_keys = [tf.GraphKeys.SUMMARIES, MOVING_SUMMARY_VARS_KEY]
             kept_summaries = {}
             grad_list = []
             for i in range(self.config.nr_tower):
                 with tf.device('/gpu:{}'.format(i)), \
...
tensorpack/utils/concurrency.py

 # -*- coding: UTF-8 -*-
 # File: concurrency.py
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>
+# Credit belongs to Xinyu Zhou
 
 import threading
-import multiprocessing
+import multiprocessing, multiprocess
 from contextlib import contextmanager
 import tensorflow as tf
 import atexit
@@ -12,6 +13,9 @@ from six.moves import zip
 from .naming import *
 
+__all__ = ['StoppableThread', 'ensure_proc_terminate',
+           'OrderedResultGatherProc', 'OrderedContainer', 'DIE']
+
 class StoppableThread(threading.Thread):
     def __init__(self):
         super(StoppableThread, self).__init__()
@@ -24,7 +28,16 @@ class StoppableThread(threading.Thread):
         return self._stop.isSet()
 
+class DIE(object):
+    pass
+
 def ensure_proc_terminate(proc):
+    if isinstance(proc, list):
+        for p in proc:
+            ensure_proc_terminate(p)
+        return
+
     def stop_proc_by_weak_ref(ref):
         proc = ref()
         if proc is None:
@@ -34,9 +47,58 @@ def ensure_proc_terminate(proc):
         proc.terminate()
         proc.join()
 
-    assert isinstance(proc, multiprocessing.Process)
+    assert isinstance(proc, (multiprocessing.Process, multiprocess.Process))
     atexit.register(stop_proc_by_weak_ref, weakref.ref(proc))
 
-def ensure_procs_terminate(procs):
-    for p in procs:
-        ensure_proc_terminate(p)
 
+class OrderedContainer(object):
+    def __init__(self, start=0):
+        self.ranks = []
+        self.data = []
+        self.wait_for = start
+
+    def put(self, rank, val):
+        idx = bisect.bisect(self.ranks, rank)
+        self.ranks.insert(idx, rank)
+        self.data.insert(idx, val)
+
+    def has_next(self):
+        if len(self.ranks) == 0:
+            return False
+        return self.ranks[0] == self.wait_for
+
+    def get(self):
+        assert self.has_next()
+        ret = self.data[0]
+        rank = self.ranks[0]
+        del self.ranks[0]
+        del self.data[0]
+        self.wait_for += 1
+        return rank, ret
+
+class OrderedResultGatherProc(multiprocessing.Process):
+    def __init__(self, data_queue, start=0):
+        super(self.__class__, self).__init__()
+        self.data_queue = data_queue
+        self.ordered_container = OrderedContainer(start=start)
+        self.result_queue = multiprocessing.Queue()
+
+    def run(self):
+        try:
+            while True:
+                task_id, data = self.data_queue.get()
+                if task_id == DIE:
+                    self.result_queue.put((task_id, data))
+                else:
+                    self.ordered_container.put(task_id, data)
+                    while self.ordered_container.has_next():
+                        self.result_queue.put(self.ordered_container.get())
+        except Exception as e:
+            import traceback
+            traceback.print_exc()
+            raise e
+
+    def get(self):
+        return self.result_queue.get()
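The reordering logic is easiest to see in isolation. Below is a small standalone demo of OrderedContainer, not part of the commit; note that put() relies on bisect, which this commit imports in predict.py rather than in concurrency.py itself.

# Standalone demo (assumes tensorpack is importable). Results with ranks 2 and
# 0 arrive before 1; get() releases items only once the next expected rank
# (wait_for) sits at the head of the sorted buffer.
from tensorpack.utils.concurrency import OrderedContainer

c = OrderedContainer(start=0)
c.put(2, 'c')
c.put(0, 'a')
print(c.has_next())    # True: rank 0 is the next expected rank
print(c.get())         # (0, 'a'); wait_for advances to 1
print(c.has_next())    # False: rank 1 has not arrived yet
c.put(1, 'b')
while c.has_next():
    print(c.get())     # (1, 'b'), then (2, 'c')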