Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
4000f5d5
Commit
4000f5d5
authored
Jun 23, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
fix undefined names in multithreadasyncpredictor
parent
32ea8a29
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
9 additions
and
9 deletions
+9
-9
tensorpack/predict/concurrency.py
tensorpack/predict/concurrency.py
+7
-8
tensorpack/train/trainer.py
tensorpack/train/trainer.py
+2
-1
No files found.
tensorpack/predict/concurrency.py
View file @
4000f5d5
...
@@ -83,13 +83,12 @@ class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):
...
@@ -83,13 +83,12 @@ class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):
self
.
outqueue
.
put
((
tid
,
self
.
func
(
dp
)))
self
.
outqueue
.
put
((
tid
,
self
.
func
(
dp
)))
class
PredictorWorkerThread
(
threading
.
Thread
):
class
PredictorWorkerThread
(
threading
.
Thread
):
def
__init__
(
self
,
queue
,
pred_func
,
id
,
nr_input_var
,
batch_size
=
5
):
def
__init__
(
self
,
queue
,
pred_func
,
id
,
batch_size
=
5
):
super
(
PredictorWorkerThread
,
self
)
.
__init__
()
super
(
PredictorWorkerThread
,
self
)
.
__init__
()
self
.
queue
=
queue
self
.
queue
=
queue
self
.
func
=
pred_func
self
.
func
=
pred_func
self
.
daemon
=
True
self
.
daemon
=
True
self
.
batch_size
=
batch_size
self
.
batch_size
=
batch_size
self
.
nr_input_var
=
nr_input_var
self
.
id
=
id
self
.
id
=
id
def
run
(
self
):
def
run
(
self
):
...
@@ -109,16 +108,17 @@ class PredictorWorkerThread(threading.Thread):
...
@@ -109,16 +108,17 @@ class PredictorWorkerThread(threading.Thread):
def
fetch_batch
(
self
):
def
fetch_batch
(
self
):
""" Fetch a batch of data without waiting"""
""" Fetch a batch of data without waiting"""
batched
,
futures
=
[[]
for
_
in
range
(
self
.
nr_input_var
)],
[]
inp
,
f
=
self
.
queue
.
get
()
inp
,
f
=
self
.
queue
.
get
()
for
k
in
range
(
self
.
nr_input_var
):
nr_input_var
=
len
(
inp
)
batched
,
futures
=
[[]
for
_
in
range
(
nr_input_var
)],
[]
for
k
in
range
(
nr_input_var
):
batched
[
k
]
.
append
(
inp
[
k
])
batched
[
k
]
.
append
(
inp
[
k
])
futures
.
append
(
f
)
futures
.
append
(
f
)
cnt
=
1
cnt
=
1
while
cnt
<
self
.
batch_size
:
while
cnt
<
self
.
batch_size
:
try
:
try
:
inp
,
f
=
self
.
queue
.
get_nowait
()
inp
,
f
=
self
.
queue
.
get_nowait
()
for
k
in
range
(
self
.
nr_input_var
):
for
k
in
range
(
nr_input_var
):
batched
[
k
]
.
append
(
inp
[
k
])
batched
[
k
]
.
append
(
inp
[
k
])
futures
.
append
(
f
)
futures
.
append
(
f
)
except
queue
.
Empty
:
except
queue
.
Empty
:
...
@@ -133,11 +133,10 @@ class MultiThreadAsyncPredictor(object):
...
@@ -133,11 +133,10 @@ class MultiThreadAsyncPredictor(object):
"""
"""
def
__init__
(
self
,
funcs
,
batch_size
=
5
):
def
__init__
(
self
,
funcs
,
batch_size
=
5
):
""" :param funcs: a list of predict func"""
""" :param funcs: a list of predict func"""
self
.
input_queue
=
queue
.
Queue
(
maxsize
=
nr_thread
*
10
)
self
.
input_queue
=
queue
.
Queue
(
maxsize
=
len
(
funcs
)
*
10
)
self
.
threads
=
[
self
.
threads
=
[
PredictorWorkerThread
(
PredictorWorkerThread
(
self
.
input_queue
,
f
,
id
,
self
.
input_queue
,
f
,
id
,
batch_size
=
batch_size
)
len
(
input_names
),
batch_size
=
batch_size
)
for
id
,
f
in
enumerate
(
funcs
)]
for
id
,
f
in
enumerate
(
funcs
)]
# TODO XXX set logging here to avoid affecting TF logging
# TODO XXX set logging here to avoid affecting TF logging
...
...
tensorpack/train/trainer.py
View file @
4000f5d5
...
@@ -113,9 +113,10 @@ class QueueInputTrainer(Trainer):
...
@@ -113,9 +113,10 @@ class QueueInputTrainer(Trainer):
"""
"""
super
(
QueueInputTrainer
,
self
)
.
__init__
(
config
)
super
(
QueueInputTrainer
,
self
)
.
__init__
(
config
)
self
.
input_vars
=
self
.
model
.
get_input_vars
()
self
.
input_vars
=
self
.
model
.
get_input_vars
()
# use a smaller queue size for now, to avoid https://github.com/tensorflow/tensorflow/issues/2942
if
input_queue
is
None
:
if
input_queue
is
None
:
self
.
input_queue
=
tf
.
FIFOQueue
(
self
.
input_queue
=
tf
.
FIFOQueue
(
10
0
,
[
x
.
dtype
for
x
in
self
.
input_vars
],
name
=
'input_queue'
)
3
0
,
[
x
.
dtype
for
x
in
self
.
input_vars
],
name
=
'input_queue'
)
else
:
else
:
self
.
input_queue
=
input_queue
self
.
input_queue
=
input_queue
if
predict_tower
is
None
:
if
predict_tower
is
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment