Commit 67f37f29 authored by Yuxin Wu's avatar Yuxin Wu

fix some concurrency bug

parent dceac084
...@@ -7,7 +7,6 @@ from itertools import count ...@@ -7,7 +7,6 @@ from itertools import count
import argparse import argparse
from collections import namedtuple from collections import namedtuple
import numpy as np import numpy as np
import bisect
from tqdm import tqdm from tqdm import tqdm
from six.moves import zip from six.moves import zip
...@@ -21,23 +20,24 @@ from .dataflow import DataFlow, BatchData ...@@ -21,23 +20,24 @@ from .dataflow import DataFlow, BatchData
__all__ = ['PredictConfig', 'DatasetPredictor', 'get_predict_func'] __all__ = ['PredictConfig', 'DatasetPredictor', 'get_predict_func']
PredictResult = namedtuple('PredictResult', ['input', 'output'])
class PredictConfig(object): class PredictConfig(object):
def __init__(self, **kwargs): def __init__(self, **kwargs):
""" """
The config used by `get_predict_func`. The config used by `get_predict_func`.
:param session_config: a `tf.ConfigProto` instance to instantiate the :param session_config: a `tf.ConfigProto` instance to instantiate the session.
session. default to a session running 1 GPU.
:param session_init: a `utils.sessinit.SessionInit` instance to :param session_init: a `utils.sessinit.SessionInit` instance to
initialize variables of a session. initialize variables of a session.
:param input_data_mapping: Decide the mapping from each component in data :param input_data_mapping: Decide the mapping from each component in data
to the input tensor, since you may not need all input variables to the input tensor, since you may not need all input variables
of the graph to run the graph for prediction (for example of the Model to run the graph for prediction (for example
the `label` input is not used if you only need probability the `label` input is not used if you only need probability distribution).
distribution).
It should be a list with size=len(data_point), It should be a list of int with length equal to `len(data_point)`,
where each element is an index of the input variables each where each element in the list defines which input variables each
component of the data point should be fed into. component in the data point should be fed into.
If not given, defaults to range(len(input_vars)) If not given, defaults to range(len(input_vars))
For example, in image classification task, the testing For example, in image classification task, the testing
...@@ -46,7 +46,7 @@ class PredictConfig(object): ...@@ -46,7 +46,7 @@ class PredictConfig(object):
input_vars: [image_var, label_var] input_vars: [image_var, label_var]
the mapping should look like: :: the mapping should then look like: ::
input_data_mapping: [0] # the first component in a datapoint should map to `image_var` input_data_mapping: [0] # the first component in a datapoint should map to `image_var`
...@@ -95,19 +95,19 @@ def get_predict_func(config): ...@@ -95,19 +95,19 @@ def get_predict_func(config):
"Graph has {} inputs but dataset only gives {} components!".format( "Graph has {} inputs but dataset only gives {} components!".format(
len(input_map), len(dp)) len(input_map), len(dp))
feed = dict(zip(input_map, dp)) feed = dict(zip(input_map, dp))
return sess.run(output_vars, feed_dict=feed)
results = sess.run(output_vars, feed_dict=feed)
if len(output_vars) == 1:
return results[0]
else:
return results
return run_input return run_input
PredictResult = namedtuple('PredictResult', ['input', 'output'])
class PredictWorker(multiprocessing.Process): class PredictWorker(multiprocessing.Process):
""" A worker process to run predictor on one GPU """
def __init__(self, idx, gpuid, inqueue, outqueue, config): def __init__(self, idx, gpuid, inqueue, outqueue, config):
"""
:param idx: index of the worker
:param gpuid: id of the GPU to be used
:param inqueue: input queue to get data point
:param outqueue: output queue to put results into
:param config: a `PredictConfig`
"""
super(PredictWorker, self).__init__() super(PredictWorker, self).__init__()
self.idx = idx self.idx = idx
self.gpuid = gpuid self.gpuid = gpuid
...@@ -132,6 +132,15 @@ class PredictWorker(multiprocessing.Process): ...@@ -132,6 +132,15 @@ class PredictWorker(multiprocessing.Process):
self.outqueue.put((tid, res)) self.outqueue.put((tid, res))
def DFtoQueue(ds, size, nr_consumer): def DFtoQueue(ds, size, nr_consumer):
"""
Build a queue that produce data from `DataFlow`, and a process
that fills the queue.
:param ds: a `DataFlow`
:param size: size of the queue
:param nr_consumer: number of consumers of the queue.
this many `DIE` sentinels will be appended to the end of the queue.
:returns: (queue, process)
"""
q = multiprocessing.Queue(size) q = multiprocessing.Queue(size)
class EnqueProc(multiprocessing.Process): class EnqueProc(multiprocessing.Process):
def __init__(self, ds, q, nr_consumer): def __init__(self, ds, q, nr_consumer):
...@@ -172,17 +181,15 @@ class DatasetPredictor(object): ...@@ -172,17 +181,15 @@ class DatasetPredictor(object):
for i in range(self.nr_gpu)] for i in range(self.nr_gpu)]
self.result_queue = OrderedResultGatherProc(self.outqueue) self.result_queue = OrderedResultGatherProc(self.outqueue)
# run the procs # setup all the procs
self.inqueue_proc.start() self.inqueue_proc.start()
for p in self.workers: p.start() for p in self.workers: p.start()
self.result_queue.start() self.result_queue.start()
ensure_proc_terminate(self.workers) ensure_proc_terminate(self.workers)
ensure_proc_terminate([self.result_queue, self.inqueue_proc]) ensure_proc_terminate([self.result_queue, self.inqueue_proc])
else: else:
self.func = get_predict_func(config) self.func = get_predict_func(config)
def get_result(self): def get_result(self):
""" A generator to produce prediction for each data""" """ A generator to produce prediction for each data"""
with tqdm(total=self.ds.size()) as pbar: with tqdm(total=self.ds.size()) as pbar:
...@@ -191,12 +198,15 @@ class DatasetPredictor(object): ...@@ -191,12 +198,15 @@ class DatasetPredictor(object):
yield PredictResult(dp, self.func(dp)) yield PredictResult(dp, self.func(dp))
pbar.update() pbar.update()
else: else:
die_cnt = 0
while True: while True:
res = self.result_queue.get() res = self.result_queue.get()
if res[0] != DIE: if res[0] != DIE:
yield res[1] yield res[1]
else: else:
break die_cnt += 1
if die_cnt == self.nr_gpu:
break
pbar.update() pbar.update()
self.inqueue_proc.join() self.inqueue_proc.join()
self.inqueue_proc.terminate() self.inqueue_proc.terminate()
......
...@@ -4,14 +4,10 @@ ...@@ -4,14 +4,10 @@
# Credit belongs to Xinyu Zhou # Credit belongs to Xinyu Zhou
import threading import threading
import multiprocessing, multiprocess import multiprocessing
from contextlib import contextmanager
import tensorflow as tf
import atexit import atexit
import bisect
import weakref import weakref
from six.moves import zip
from .naming import *
__all__ = ['StoppableThread', 'ensure_proc_terminate', __all__ = ['StoppableThread', 'ensure_proc_terminate',
'OrderedResultGatherProc', 'OrderedContainer', 'DIE'] 'OrderedResultGatherProc', 'OrderedContainer', 'DIE']
...@@ -29,9 +25,9 @@ class StoppableThread(threading.Thread): ...@@ -29,9 +25,9 @@ class StoppableThread(threading.Thread):
class DIE(object): class DIE(object):
""" A placeholder class indicating end of queue """
pass pass
def ensure_proc_terminate(proc): def ensure_proc_terminate(proc):
if isinstance(proc, list): if isinstance(proc, list):
for p in proc: for p in proc:
...@@ -47,11 +43,14 @@ def ensure_proc_terminate(proc): ...@@ -47,11 +43,14 @@ def ensure_proc_terminate(proc):
proc.terminate() proc.terminate()
proc.join() proc.join()
assert isinstance(proc, (multiprocessing.Process, multiprocess.Process)) assert isinstance(proc, multiprocessing.Process)
atexit.register(stop_proc_by_weak_ref, weakref.ref(proc)) atexit.register(stop_proc_by_weak_ref, weakref.ref(proc))
class OrderedContainer(object): class OrderedContainer(object):
"""
Like a priority queue, but will always wait for item with index (x+1) before producing (x+2).
"""
def __init__(self, start=0): def __init__(self, start=0):
self.ranks = [] self.ranks = []
self.data = [] self.data = []
...@@ -78,9 +77,12 @@ class OrderedContainer(object): ...@@ -78,9 +77,12 @@ class OrderedContainer(object):
class OrderedResultGatherProc(multiprocessing.Process): class OrderedResultGatherProc(multiprocessing.Process):
"""
Gather indexed data from a data queue, and produce results with the
original index-based order.
"""
def __init__(self, data_queue, start=0): def __init__(self, data_queue, start=0):
super(self.__class__, self).__init__() super(self.__class__, self).__init__()
self.data_queue = data_queue self.data_queue = data_queue
self.ordered_container = OrderedContainer(start=start) self.ordered_container = OrderedContainer(start=start)
self.result_queue = multiprocessing.Queue() self.result_queue = multiprocessing.Queue()
......
...@@ -57,17 +57,25 @@ def _set_file(path): ...@@ -57,17 +57,25 @@ def _set_file(path):
filename=path, encoding='utf-8', mode='w') filename=path, encoding='utf-8', mode='w')
logger.addHandler(hdl) logger.addHandler(hdl)
def set_logger_dir(dirname): def set_logger_dir(dirname, action=None):
"""
Set the directory for global logging.
:param dirname: log directory
:param action: an action (k/b/d/n) to be performed. Will ask the user by default.
"""
global LOG_FILE, LOG_DIR global LOG_FILE, LOG_DIR
if os.path.isdir(dirname): if os.path.isdir(dirname):
logger.warn("""\ logger.warn("""\
Directory {} exists! Please either backup/delete it, or use a new directory \ Directory {} exists! Please either backup/delete it, or use a new directory \
unless you're resuming from a previous task.""".format(dirname)) unless you're resuming from a previous task.""".format(dirname))
logger.info("Select Action: k (keep) / b (backup) / d (delete) / n (new):") logger.info("Select Action: k (keep) / b (backup) / d (delete) / n (new):")
while True: if not action:
act = input().lower().strip() while True:
if act: act = input().lower().strip()
break if act:
break
else:
act = action
if act == 'b': if act == 'b':
backup_name = dirname + get_time_str() backup_name = dirname + get_time_str()
shutil.move(dirname, backup_name) shutil.move(dirname, backup_name)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment