Commit 32ea8a29 authored by Yuxin Wu

better metagraph saving & multithreadasyncpredictor

parent 429d8a85
@@ -50,15 +50,19 @@ class ModelSaver(Callback):
     def _trigger_epoch(self):
         try:
+            if not self.meta_graph_written:
+                self.saver.export_meta_graph(
+                    os.path.join(logger.LOG_DIR,
+                        'graph-{}.meta'.format(logger.get_time_str())),
+                    collection_list=self.graph.get_all_collection_keys())
+                self.meta_graph_written = True
             self.saver.save(
                 tf.get_default_session(),
                 self.path,
                 global_step=self.global_step,
-                write_meta_graph=not self.meta_graph_written)
+                write_meta_graph=False)
         except Exception:   # disk error sometimes..
             logger.exception("Exception in ModelSaver.trigger_epoch!")
-        if not self.meta_graph_written:
-            self.meta_graph_written = True


 class MinSaver(Callback):
     def __init__(self, monitor_stat):
...
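Why this change helps: `tf.train.Saver.save` re-serializes the whole `MetaGraphDef` with every checkpoint unless `write_meta_graph=False`, which is wasteful since the graph structure never changes during training. The new code exports the meta graph once up front and writes only variable values afterwards. A minimal self-contained sketch of the same pattern, using the TF1-style API as in the diff (the `log_dir` setup is illustrative, not tensorpack's `logger.LOG_DIR`):

    import os
    import tensorflow as tf   # TF1-style API, as in the diff

    log_dir = '/tmp/train_log'   # stand-in for logger.LOG_DIR
    os.makedirs(log_dir, exist_ok=True)

    step = tf.Variable(0, name='step')
    inc = tf.assign_add(step, 1)
    saver = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # Export the graph structure exactly once per training run.
        saver.export_meta_graph(
            os.path.join(log_dir, 'graph.meta'),
            collection_list=tf.get_default_graph().get_all_collection_keys())
        for epoch in range(3):
            sess.run(inc)
            # Subsequent checkpoints write only variable values, not the graph.
            saver.save(sess, os.path.join(log_dir, 'model'),
                       global_step=epoch, write_meta_graph=False)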
@@ -69,6 +69,8 @@ class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):
         super(MultiProcessQueuePredictWorker, self).__init__(idx, gpuid, config)
         self.inqueue = inqueue
         self.outqueue = outqueue
+        assert isinstance(self.inqueue, multiprocessing.Queue)
+        assert isinstance(self.outqueue, multiprocessing.Queue)

     def run(self):
         self._init_runtime()
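One caveat with the new asserts: `multiprocessing.Queue` is a factory function, not a class, so `isinstance(..., multiprocessing.Queue)` raises `TypeError` at runtime. The actual class lives in `multiprocessing.queues`. A sketch of a check that works:

    import multiprocessing
    import multiprocessing.queues

    q = multiprocessing.Queue()
    # multiprocessing.Queue() is a factory; the class it returns
    # instances of is multiprocessing.queues.Queue.
    assert isinstance(q, multiprocessing.queues.Queue)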
@@ -91,13 +93,27 @@ class PredictorWorkerThread(threading.Thread):
         self.id = id

     def run(self):
-        def fetch():
+        #self.xxx = None
+        while True:
+            batched, futures = self.fetch_batch()
+            outputs = self.func(batched)
+            #print "batched size: ", len(batched), "queuesize: ", self.queue.qsize()
+            # debug, for speed testing
+            #if self.xxx is None:
+                #self.xxx = outputs = self.func([batched])
+            #else:
+                #outputs = [[self.xxx[0][0]] * len(batched), [self.xxx[1][0]] * len(batched)]
+            for idx, f in enumerate(futures):
+                f.set_result([k[idx] for k in outputs])
+
+    def fetch_batch(self):
+        """ Fetch a batch of data without waiting"""
         batched, futures = [[] for _ in range(self.nr_input_var)], []
         inp, f = self.queue.get()
         for k in range(self.nr_input_var):
             batched[k].append(inp[k])
             futures.append(f)
+        # fill a batch
         cnt = 1
         while cnt < self.batch_size:
             try:
@@ -109,37 +125,21 @@ class PredictorWorkerThread(threading.Thread):
                 break
             cnt += 1
         return batched, futures
-        #self.xxx = None
-        while True:
-            batched, futures = fetch()
-            #print "batched size: ", len(batched), "queuesize: ", self.queue.qsize()
-            outputs = self.func(batched)
-            # debug, for speed testing
-            #if self.xxx is None:
-                #self.xxx = outputs = self.func([batched])
-            #else:
-                #outputs = [[self.xxx[0][0]] * len(batched), [self.xxx[1][0]] * len(batched)]
-            for idx, f in enumerate(futures):
-                f.set_result([k[idx] for k in outputs])

 class MultiThreadAsyncPredictor(object):
     """
-    An online predictor (use the current active session) that works with
-    QueueInputTrainer. Use async interface, support multi-thread and multi-GPU.
-    """
-    def __init__(self, trainer, input_names, output_names, nr_thread, batch_size=5):
-        """
-        :param trainer: a `QueueInputTrainer` instance.
-        """
+    A multithread predictor which runs a list of predict funcs.
+    Uses an async interface; supports multi-thread and multi-GPU.
+    """
+    def __init__(self, funcs, batch_size=5):
+        """ :param funcs: a list of predict funcs """
         self.input_queue = queue.Queue(maxsize=nr_thread*10)
         self.threads = [
             PredictorWorkerThread(
                 self.input_queue, f, id,
                 len(input_names), batch_size=batch_size)
-            for id, f in enumerate(
-                trainer.get_predict_funcs(
-                    input_names, output_names, nr_thread))]
+            for id, f in enumerate(funcs)]
         # TODO XXX set logging here to avoid affecting TF logging
         import tornado.options as options
         options.parse_command_line(['--logging=debug'])
...
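The caller-facing half of the new API is not in these hunks, and two unchanged context lines (`maxsize=nr_thread*10` and `len(input_names)`) still reference parameters that the new `__init__(self, funcs, batch_size=5)` signature dropped, so as shown they would raise `NameError`. Presumably callers enqueue `(input, future)` pairs on `input_queue`, and each `PredictorWorkerThread` greedily batches them before a single `func` call, as `fetch_batch` above does. A self-contained sketch of that pattern under those assumptions, with `worker` and `predict` as made-up names and `concurrent.futures.Future` standing in for whatever future type tensorpack uses:

    import threading
    import queue
    from concurrent.futures import Future

    def worker(inq, func, batch_size=5):
        # Consume (input, future) pairs; batch greedily, then fulfil the futures.
        while True:
            inp, f = inq.get()                  # block for the first item
            inputs, futures = [inp], [f]
            while len(inputs) < batch_size:
                try:
                    inp, f = inq.get_nowait()   # fill the batch without waiting
                except queue.Empty:
                    break
                inputs.append(inp)
                futures.append(f)
            outputs = func(inputs)              # one call for the whole batch
            for idx, f in enumerate(futures):
                f.set_result(outputs[idx])

    inq = queue.Queue(maxsize=50)
    predict = lambda batch: [x * 2 for x in batch]   # stand-in predict func
    threading.Thread(target=worker, args=(inq, predict), daemon=True).start()

    fut = Future()
    inq.put((3, fut))
    print(fut.result())   # -> 6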