Commit 40e6a223 authored by Yuxin Wu

major refactor in predict

parent ff0a4528
......@@ -170,7 +170,7 @@ def eval_model_multiprocess(model_path):
session_init=SaverRestore(model_path),
output_var_names=['fct/output:0'])
class Worker(ParallelPredictWorker):
class Worker(MultiProcessPredictWorker):
def __init__(self, idx, gpuid, config, outqueue):
super(Worker, self).__init__(idx, gpuid, config)
self.outq = outqueue
......
......@@ -12,7 +12,7 @@ import imp
from tensorpack.utils import *
from tensorpack.utils import sessinit
from tensorpack.dataflow import *
from tensorpack.predict import PredictConfig, DatasetPredictor
from tensorpack.predict import PredictConfig, SimpleDatasetPredictor
parser = argparse.ArgumentParser()
......@@ -26,6 +26,8 @@ args = parser.parse_args()
get_config_func = imp.load_source('config_script', args.config).get_config
# TODO not sure if this script is still working
with tf.Graph().as_default() as G:
train_config = get_config_func()
config = PredictConfig(
......@@ -37,9 +39,9 @@ with tf.Graph().as_default() as G:
)
ds = ImageFromFile(args.images, 3, resize=(227, 227))
predictor = DatasetPredictor(config, ds, batch=128)
ds = BatchData(ds, 128, remainder=True)
predictor = SimpleDatasetPredictor(config, ds)
res = predictor.get_all_result()
res = [k.output for k in res]
if args.output_type == 'label':
for r in res:
......
......@@ -39,7 +39,7 @@ def dataflow_to_process_queue(ds, size, nr_consumer):
:param nr_consumer: number of consumer of the queue.
will add this many of `DIE` sentinel to the end of the queue.
:returns: (queue, process). The process will take data from `ds` to fill
the queue once you start it.
the queue once you start it. Each element is (task_id, dp).
"""
q = multiprocessing.Queue(size)
class EnqueProc(multiprocessing.Process):
......
......@@ -45,7 +45,7 @@ class PredictConfig(object):
:param output_var_names: a list of names of the output variables to predict, the
variables can be any computable tensor in the graph.
Predict specific output might not require all input variables.
:param nr_gpu: default to 1. Use CUDA_VISIBLE_DEVICES to control which GPU to use sepcifically.
:param return_input: whether to produce (input, output) pair or just output. default to False.
"""
def assert_type(v, tp):
assert isinstance(v, tp), v.__class__
......@@ -54,7 +54,7 @@ class PredictConfig(object):
self.model = kwargs.pop('model')
self.input_data_mapping = kwargs.pop('input_data_mapping', None)
self.output_var_names = kwargs.pop('output_var_names')
self.nr_gpu = kwargs.pop('nr_gpu', 1)
self.return_input = kwargs.pop('return_input', False)
assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))
def get_predict_func(config):
......
......@@ -2,7 +2,8 @@
# -*- coding: utf-8 -*-
# File: concurrency.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import multiprocessing
import multiprocessing, threading
import tensorflow as tf
from ..utils.concurrency import DIE
from ..tfutils.modelutils import describe_model
......@@ -11,16 +12,17 @@ from ..tfutils import *
from .common import *
__all__ = ['ParallelPredictWorker', 'QueuePredictWorker']
__all__ = ['MultiProcessPredictWorker', 'MultiProcessQueuePredictWorker']
class ParallelPredictWorker(multiprocessing.Process):
class MultiProcessPredictWorker(multiprocessing.Process):
""" Base class for predict worker that runs in multiprocess"""
def __init__(self, idx, gpuid, config):
"""
:param idx: index of the worker. the 0th worker will print log.
:param gpuid: absolute id of the GPU to be used. set to -1 to use CPU.
:param config: a `PredictConfig`
"""
super(ParallelPredictWorker, self).__init__()
super(MultiProcessPredictWorker, self).__init__()
self.idx = idx
self.gpuid = gpuid
self.config = config
......@@ -41,17 +43,17 @@ class ParallelPredictWorker(multiprocessing.Process):
if self.idx == 0:
describe_model()
class QueuePredictWorker(ParallelPredictWorker):
class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):
""" A worker process to run predictor on one GPU """
def __init__(self, idx, gpuid, inqueue, outqueue, config):
"""
:param idx: index of the worker. the 0th worker will print log.
:param gpuid: id of the GPU to be used. set to -1 to use CPU.
:param inqueue: input queue to get data point
:param outqueue: output queue put result
:param inqueue: input queue to get data point. elements are (task_id, dp)
:param outqueue: output queue put result. elements are (task_id, output)
:param config: a `PredictConfig`
"""
super(QueuePredictWorker, self).__init__(idx, gpuid, config)
super(MultiProcessQueuePredictWorker, self).__init__(idx, gpuid, config)
self.inqueue = inqueue
self.outqueue = outqueue
......@@ -63,5 +65,19 @@ class QueuePredictWorker(ParallelPredictWorker):
self.outqueue.put((DIE, None))
return
else:
res = PredictResult(dp, self.func(dp))
self.outqueue.put((tid, res))
self.outqueue.put((tid, self.func(dp)))
class MultiThreadPredictWorker(threading.Thread):
    """ Base class for predict worker that runs in a thread.

    Only stores its construction arguments for now; `run` is a no-op
    placeholder.
    """
    def __init__(self, idx, gpuid, config):
        """
        :param idx: index of the worker. the 0th worker will print log.
        :param gpuid: absolute id of the GPU to be used. set to -1 to use CPU.
        :param config: a `PredictConfig`
        """
        # super() must name this class. Naming MultiProcessPredictWorker (a
        # multiprocessing.Process subclass that is not in this class's MRO)
        # raises TypeError and leaves threading.Thread uninitialised.
        super(MultiThreadPredictWorker, self).__init__()
        self.idx = idx
        self.gpuid = gpuid
        self.config = config

    def run(self):
        """ Thread body. Intentionally does nothing yet. """
        pass
......@@ -5,55 +5,104 @@
from six.moves import range
from tqdm import tqdm
from abc import ABCMeta, abstractmethod
from ..dataflow import DataFlow, BatchData
from ..dataflow.dftools import dataflow_to_process_queue
from ..utils.concurrency import ensure_proc_terminate, OrderedResultGatherProc, DIE
from .concurrency import *
from .concurrency import MultiProcessQueuePredictWorker
from .common import *
__all__ = ['DatasetPredictor']
__all__ = ['DatasetPredictorBase', 'SimpleDatasetPredictor',
'MultiProcessDatasetPredictor']
class DatasetPredictorBase(object):
__metaclass__ = ABCMeta
class DatasetPredictor(object):
"""
Run the predict_config on a given `DataFlow`.
"""
def __init__(self, config, dataset):
"""
:param config: a `PredictConfig` instance.
:param dataset: a `DataFlow` instance.
"""
assert isinstance(dataset, DataFlow)
self.ds = dataset
self.nr_gpu = config.nr_gpu
if self.nr_gpu > 1:
self.inqueue, self.inqueue_proc = dataflow_to_process_queue(self.ds, 10, self.nr_gpu)
assert isinstance(config, PredictConfig)
self.config = config
self.dataset = dataset
@abstractmethod
def get_result(self):
""" Generate (inpupt, output) pair of output, for each input in dataset"""
pass
def get_all_result(self):
"""
Run over the dataset and return a list of all predictions.
"""
return list(self.get_result())
class SimpleDatasetPredictor(DatasetPredictorBase):
    """
    Run the predict_config on a given `DataFlow`, calling the predict
    function on each datapoint in turn.
    """
    def __init__(self, config, dataset):
        """
        :param config: a `PredictConfig` instance.
        :param dataset: a `DataFlow` instance.
        """
        super(SimpleDatasetPredictor, self).__init__(config, dataset)
        # Build the predict function once; it is reused for every datapoint.
        self.func = get_predict_func(config)

    def get_result(self):
        """ A generator to produce prediction for each data.

        Yields `(datapoint, prediction)` when `config.return_input` is true,
        otherwise just the prediction. The progress bar advances after each
        yielded item has been consumed.
        """
        with tqdm(total=self.dataset.size()) as progress:
            for datapoint in self.dataset.get_data():
                prediction = self.func(datapoint)
                yield (datapoint, prediction) if self.config.return_input else prediction
                progress.update()
class MultiProcessDatasetPredictor(DatasetPredictorBase):
def __init__(self, config, dataset, nr_proc, use_gpu=True):
"""
Run prediction in multiprocesses, on either CPU or GPU. Mix mode not supported.
:param nr_proc: number of processes to use
:param use_gpu: use GPU or CPU.
nr_proc cannot be larger than the total number of GPUs available
in CUDA_VISIBLE_DEVICES or in the system.
"""
assert config.return_input == False, "return_input not supported for MultiProcessDatasetPredictor"
assert nr_proc > 1
super(MultiProcessDatasetPredictor, self).__init__(config, dataset)
self.nr_proc = nr_proc
self.inqueue, self.inqueue_proc = dataflow_to_process_queue(
self.dataset, nr_proc * 2, self.nr_proc)
self.outqueue = multiprocessing.Queue()
if use_gpu:
try:
gpus = os.environ['CUDA_VISIBLE_DEVICES'].split(',')
assert len(gpus) >= self.nr_proc, \
"nr_proc={} while only {} gpus available".format(
self.nr_proc, len(gpus))
except KeyError:
# TODO number of GPUs not checked
gpus = list(range(self.nr_gpu))
self.workers = [QueuePredictWorker(i, gpus[i], self.inqueue, self.outqueue, config)
for i in range(self.nr_gpu)]
self.result_queue = OrderedResultGatherProc(self.outqueue)
else:
gpus = [-1] * self.nr_proc
self.workers = [MultiProcessQueuePredictWorker(
i, gpus[i], self.inqueue, self.outqueue, self.config)
for i in range(self.nr_proc)]
self.result_queue = OrderedResultGatherProc(
self.outqueue, nr_producer=self.nr_proc)
# setup all the procs
self.inqueue_proc.start()
for p in self.workers: p.start()
self.result_queue.start()
ensure_proc_terminate(self.workers)
ensure_proc_terminate([self.result_queue, self.inqueue_proc])
else:
self.func = get_predict_func(config)
ensure_proc_terminate(self.workers + [self.result_queue, self.inqueue_proc])
def get_result(self):
""" A generator to produce prediction for each data"""
with tqdm(total=self.ds.size()) as pbar:
if self.nr_gpu == 1:
for dp in self.ds.get_data():
yield PredictResult(dp, self.func(dp))
pbar.update()
else:
with tqdm(total=self.dataset.size()) as pbar:
die_cnt = 0
while True:
res = self.result_queue.get()
......@@ -62,15 +111,11 @@ class DatasetPredictor(object):
yield res[1]
else:
die_cnt += 1
if die_cnt == self.nr_gpu:
if die_cnt == self.nr_proc:
break
self.inqueue_proc.join()
self.inqueue_proc.terminate()
self.result_queue.join()
self.result_queue.terminate()
for p in self.workers:
p.join(); p.terminate()
def get_all_result(self):
"""
Run over the dataset and return a list of all predictions.
"""
return list(self.get_result())
......@@ -123,18 +123,28 @@ class OrderedResultGatherProc(multiprocessing.Process):
Gather indexed data from a data queue, and produce results with the
original index-based order.
"""
def __init__(self, data_queue, start=0):
super(self.__class__, self).__init__()
def __init__(self, data_queue, nr_producer, start=0):
"""
:param data_queue: a multiprocessing.Queue to produce input dp
:param nr_producer: number of producer processes. Will terminate after receiving this many of DIE sentinel.
:param start: the first task index
"""
super(OrderedResultGatherProc, self).__init__()
self.data_queue = data_queue
self.ordered_container = OrderedContainer(start=start)
self.result_queue = multiprocessing.Queue()
self.nr_producer = nr_producer
def run(self):
nr_end = 0
try:
while True:
task_id, data = self.data_queue.get()
if task_id == DIE:
self.result_queue.put((task_id, data))
nr_end += 1
if nr_end == self.nr_producer:
return
else:
self.ordered_container.put(task_id, data)
while self.ordered_container.has_next():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment