Shashank Suhas / seminar-breakout / Commits

Commit 32ea8a29, authored Jun 23, 2016 by Yuxin Wu
Parent: 429d8a85

    better metagraph saving & multithreadasyncpredictor

Showing 2 changed files, with 36 additions and 32 deletions:

    tensorpack/callbacks/common.py       +7   -3
    tensorpack/predict/concurrency.py    +29  -29
tensorpack/callbacks/common.py

@@ -50,15 +50,19 @@ class ModelSaver(Callback):
     def _trigger_epoch(self):
         try:
+            if not self.meta_graph_written:
+                self.saver.export_meta_graph(
+                    os.path.join(logger.LOG_DIR,
+                                 'graph-{}.meta'.format(logger.get_time_str())),
+                    collection_list=self.graph.get_all_collection_keys())
+                self.meta_graph_written = True
             self.saver.save(
                 tf.get_default_session(),
                 self.path,
                 global_step=self.global_step,
-                write_meta_graph=not self.meta_graph_written)
+                write_meta_graph=False)
         except Exception:   # disk error sometimes..
             logger.exception("Exception in ModelSaver.trigger_epoch!")
-        if not self.meta_graph_written:
-            self.meta_graph_written = True
 
 class MinSaver(Callback):
     def __init__(self, monitor_stat):
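The hunk changes when the MetaGraph is written. Previously the first checkpoint carried the graph via write_meta_graph=not self.meta_graph_written, and the flag was flipped outside the try block, so a failed save would still mark the graph as written. Now the graph is exported exactly once to a separate graph-<timestamp>.meta file (with the flag set only after the export succeeds), and every checkpoint is saved with write_meta_graph=False, keeping checkpoints small. A consumer would import the graph once and then restore any checkpoint into it; a minimal sketch against the contemporary TF API, with hypothetical file names:

    import tensorflow as tf

    # rebuild the graph structure from the one-time MetaGraph export ...
    saver = tf.train.import_meta_graph('train_log/graph-0623-143000.meta')

    with tf.Session() as sess:
        # ... then load weights from any checkpoint written afterwards;
        # none of them needs to carry its own copy of the graph
        saver.restore(sess, 'train_log/model-10000')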
tensorpack/predict/concurrency.py

@@ -69,6 +69,8 @@ class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):
         super(MultiProcessQueuePredictWorker, self).__init__(idx, gpuid, config)
         self.inqueue = inqueue
         self.outqueue = outqueue
+        assert isinstance(self.inqueue, multiprocessing.Queue)
+        assert isinstance(self.outqueue, multiprocessing.Queue)
 
     def run(self):
         self._init_runtime()
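One caveat on the two new asserts: multiprocessing.Queue is a factory function, not a class (the class lives in multiprocessing.queues), so isinstance(x, multiprocessing.Queue) raises TypeError instead of validating anything. A sketch of the check that was presumably intended:

    import multiprocessing
    import multiprocessing.queues

    q = multiprocessing.Queue()
    # the multiprocessing.Queue() factory returns an instance of this class:
    assert isinstance(q, multiprocessing.queues.Queue)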
@@ -91,13 +93,27 @@ class PredictorWorkerThread(threading.Thread):
         self.id = id
 
     def run(self):
-        def fetch():
+        #self.xxx = None
+        while True:
+            batched, futures = self.fetch_batch()
+            outputs = self.func(batched)
+            #print "batched size: ", len(batched), "queuesize: ", self.queue.qsize()
+            # debug, for speed testing
+            #if self.xxx is None:
+                #self.xxx = outputs = self.func([batched])
+            #else:
+                #outputs = [[self.xxx[0][0]] * len(batched), [self.xxx[1][0]] * len(batched)]
+            for idx, f in enumerate(futures):
+                f.set_result([k[idx] for k in outputs])
+
+    def fetch_batch(self):
+        """ Fetch a batch of data without waiting"""
         batched, futures = [[] for _ in range(self.nr_input_var)], []
         inp, f = self.queue.get()
         for k in range(self.nr_input_var):
             batched[k].append(inp[k])
         futures.append(f)
-            # fill a batch
         cnt = 1
         while cnt < self.batch_size:
             try:
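The restructuring promotes the nested fetch() closure to a fetch_batch() method: block on the queue for the first (input, Future) pair, then drain whatever else is already queued, up to batch_size items, without ever waiting, so the batch size adapts to load. run() then does a single batched forward pass and scatters the idx-th slice of each output back to the idx-th caller's Future via set_result. The same batching pattern in isolation, with illustrative names rather than tensorpack's own:

    import queue
    from concurrent.futures import Future

    def fetch_batch(q, batch_size):
        # block for the first request, then greedily take whatever is
        # already waiting; never sleep just to fill the batch
        inp, fut = q.get()
        inputs, futures = [inp], [fut]
        while len(inputs) < batch_size:
            try:
                inp, fut = q.get_nowait()
            except queue.Empty:
                break
            inputs.append(inp)
            futures.append(fut)
        return inputs, futures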
@@ -109,37 +125,21 @@ class PredictorWorkerThread(threading.Thread):
                 break
             cnt += 1
         return batched, futures
-        #self.xxx = None
-        while True:
-            batched, futures = fetch()
-            #print "batched size: ", len(batched), "queuesize: ", self.queue.qsize()
-            outputs = self.func(batched)
-            # debug, for speed testing
-            #if self.xxx is None:
-                #self.xxx = outputs = self.func([batched])
-            #else:
-                #outputs = [[self.xxx[0][0]] * len(batched), [self.xxx[1][0]] * len(batched)]
-            for idx, f in enumerate(futures):
-                f.set_result([k[idx] for k in outputs])
 
 class MultiThreadAsyncPredictor(object):
     """
-    An online predictor (use the current active session) that works with
-    QueueInputTrainer. Use async interface, support multi-thread and multi-GPU.
-    """
-    def __init__(self, trainer, input_names, output_names, nr_thread, batch_size=5):
-        """
-        :param trainer: a `QueueInputTrainer` instance.
-        """
+    An multithread predictor which run a list of predict func.
+    Use async interface, support multi-thread and multi-GPU.
+    """
+    def __init__(self, funcs, batch_size=5):
+        """ :param funcs: a list of predict func"""
         self.input_queue = queue.Queue(maxsize=nr_thread*10)
         self.threads = [
             PredictorWorkerThread(self.input_queue, f, id,
                                   len(input_names), batch_size=batch_size)
-            for id, f in enumerate(
-                trainer.get_predict_funcs(input_names, output_names, nr_thread))]
+            for id, f in enumerate(funcs)]
 
         # TODO XXX set logging here to avoid affecting TF logging
         import tornado.options as options
         options.parse_command_line(['--logging=debug'])
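With this change the constructor no longer reaches into a trainer; the caller builds the list of predict functions and hands it over. Assuming the helper from the removed signature still exists on the trainer, the call-site migration would look roughly like:

    # before: the predictor pulled its funcs out of the trainer itself
    pred = MultiThreadAsyncPredictor(trainer, input_names, output_names, nr_thread)

    # after: the caller supplies one predict func per worker thread
    funcs = trainer.get_predict_funcs(input_names, output_names, nr_thread)
    pred = MultiThreadAsyncPredictor(funcs, batch_size=5)

Note that the hunk appears to leave nr_thread and input_names referenced in unchanged lines (maxsize=nr_thread*10 and len(input_names)) even though neither is a parameter anymore, which would raise NameError as committed; deriving both from funcs (e.g. len(funcs)) would be the natural follow-up fix.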