Commit 67f37f29 authored by Yuxin Wu's avatar Yuxin Wu

fix some concurrency bugs

parent dceac084
......@@ -7,7 +7,6 @@ from itertools import count
import argparse
from collections import namedtuple
import numpy as np
import bisect
from tqdm import tqdm
from six.moves import zip
......@@ -21,23 +20,24 @@ from .dataflow import DataFlow, BatchData
__all__ = ['PredictConfig', 'DatasetPredictor', 'get_predict_func']
PredictResult = namedtuple('PredictResult', ['input', 'output'])
class PredictConfig(object):
def __init__(self, **kwargs):
"""
The config used by `get_predict_func`.
:param session_config: a `tf.ConfigProto` instance to instantiate the
session. default to a session running 1 GPU.
:param session_config: a `tf.ConfigProto` instance to instantiate the session.
:param session_init: a `utils.sessinit.SessionInit` instance to
initialize variables of a session.
:param input_data_mapping: Decide the mapping from each component in data
to the input tensor, since you may not need all input variables
of the graph to run the graph for prediction (for example
the `label` input is not used if you only need probability
distribution).
It should be a list with size=len(data_point),
where each element is an index of the input variables each
component of the data point should be fed into.
of the Model to run the graph for prediction (for example
the `label` input is not used if you only need probability distribution).
It should be a list of int with length equal to `len(data_point)`,
where each element in the list defines which input variables each
component in the data point should be fed into.
If not given, defaults to range(len(input_vars))
For example, in image classification task, the testing
......@@ -46,7 +46,7 @@ class PredictConfig(object):
input_vars: [image_var, label_var]
the mapping should look like: ::
the mapping should then look like: ::
input_data_mapping: [0] # the first component in a datapoint should map to `image_var`
......@@ -95,19 +95,19 @@ def get_predict_func(config):
"Graph has {} inputs but dataset only gives {} components!".format(
len(input_map), len(dp))
feed = dict(zip(input_map, dp))
results = sess.run(output_vars, feed_dict=feed)
if len(output_vars) == 1:
return results[0]
else:
return results
return sess.run(output_vars, feed_dict=feed)
return run_input
PredictResult = namedtuple('PredictResult', ['input', 'output'])
class PredictWorker(multiprocessing.Process):
""" A worker process to run predictor on one GPU """
def __init__(self, idx, gpuid, inqueue, outqueue, config):
"""
:param idx: index of the worker
:param gpuid: id of the GPU to be used
:param inqueue: input queue to get data point
:param outqueue: output queue to put results into
:param config: a `PredictConfig`
"""
super(PredictWorker, self).__init__()
self.idx = idx
self.gpuid = gpuid
......@@ -132,6 +132,15 @@ class PredictWorker(multiprocessing.Process):
self.outqueue.put((tid, res))
def DFtoQueue(ds, size, nr_consumer):
"""
Build a queue that produce data from `DataFlow`, and a process
that fills the queue.
:param ds: a `DataFlow`
:param size: size of the queue
:param nr_consumer: number of consumers of the queue.
this many `DIE` sentinels will be added to the end of the queue.
:returns: (queue, process)
"""
q = multiprocessing.Queue(size)
class EnqueProc(multiprocessing.Process):
def __init__(self, ds, q, nr_consumer):
......@@ -172,17 +181,15 @@ class DatasetPredictor(object):
for i in range(self.nr_gpu)]
self.result_queue = OrderedResultGatherProc(self.outqueue)
# run the procs
# setup all the procs
self.inqueue_proc.start()
for p in self.workers: p.start()
self.result_queue.start()
ensure_proc_terminate(self.workers)
ensure_proc_terminate([self.result_queue, self.inqueue_proc])
else:
self.func = get_predict_func(config)
def get_result(self):
""" A generator to produce prediction for each data"""
with tqdm(total=self.ds.size()) as pbar:
......@@ -191,11 +198,14 @@ class DatasetPredictor(object):
yield PredictResult(dp, self.func(dp))
pbar.update()
else:
die_cnt = 0
while True:
res = self.result_queue.get()
if res[0] != DIE:
yield res[1]
else:
die_cnt += 1
if die_cnt == self.nr_gpu:
break
pbar.update()
self.inqueue_proc.join()
......
......@@ -4,14 +4,10 @@
# Credit belongs to Xinyu Zhou
import threading
import multiprocessing, multiprocess
from contextlib import contextmanager
import tensorflow as tf
import multiprocessing
import atexit
import bisect
import weakref
from six.moves import zip
from .naming import *
__all__ = ['StoppableThread', 'ensure_proc_terminate',
'OrderedResultGatherProc', 'OrderedContainer', 'DIE']
......@@ -29,9 +25,9 @@ class StoppableThread(threading.Thread):
class DIE(object):
    """Sentinel type signalling the end of a queue (no more items will follow)."""
def ensure_proc_terminate(proc):
if isinstance(proc, list):
for p in proc:
......@@ -47,11 +43,14 @@ def ensure_proc_terminate(proc):
proc.terminate()
proc.join()
assert isinstance(proc, (multiprocessing.Process, multiprocess.Process))
assert isinstance(proc, multiprocessing.Process)
atexit.register(stop_proc_by_weak_ref, weakref.ref(proc))
class OrderedContainer(object):
"""
Like a priority queue, but will always wait for item with index (x+1) before producing (x+2).
"""
def __init__(self, start=0):
self.ranks = []
self.data = []
......@@ -78,9 +77,12 @@ class OrderedContainer(object):
class OrderedResultGatherProc(multiprocessing.Process):
"""
Gather indexed data from a data queue, and produce results with the
original index-based order.
"""
def __init__(self, data_queue, start=0):
super(self.__class__, self).__init__()
self.data_queue = data_queue
self.ordered_container = OrderedContainer(start=start)
self.result_queue = multiprocessing.Queue()
......
......@@ -57,17 +57,25 @@ def _set_file(path):
filename=path, encoding='utf-8', mode='w')
logger.addHandler(hdl)
def set_logger_dir(dirname):
def set_logger_dir(dirname, action=None):
"""
Set the directory for global logging.
:param dirname: log directory
:param action: an action (k/b/d/n) to be performed. Will ask user by default.
"""
global LOG_FILE, LOG_DIR
if os.path.isdir(dirname):
logger.warn("""\
Directory {} exists! Please either backup/delete it, or use a new directory \
unless you're resuming from a previous task.""".format(dirname))
logger.info("Select Action: k (keep) / b (backup) / d (delete) / n (new):")
if not action:
while True:
act = input().lower().strip()
if act:
break
else:
act = action
if act == 'b':
backup_name = dirname + get_time_str()
shutil.move(dirname, backup_name)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment