Commit 40e6a223 authored by Yuxin Wu's avatar Yuxin Wu

major refactory in predict

parent ff0a4528
...@@ -170,7 +170,7 @@ def eval_model_multiprocess(model_path): ...@@ -170,7 +170,7 @@ def eval_model_multiprocess(model_path):
session_init=SaverRestore(model_path), session_init=SaverRestore(model_path),
output_var_names=['fct/output:0']) output_var_names=['fct/output:0'])
class Worker(ParallelPredictWorker): class Worker(MultiProcessPredictWorker):
def __init__(self, idx, gpuid, config, outqueue): def __init__(self, idx, gpuid, config, outqueue):
super(Worker, self).__init__(idx, gpuid, config) super(Worker, self).__init__(idx, gpuid, config)
self.outq = outqueue self.outq = outqueue
......
...@@ -12,7 +12,7 @@ import imp ...@@ -12,7 +12,7 @@ import imp
from tensorpack.utils import * from tensorpack.utils import *
from tensorpack.utils import sessinit from tensorpack.utils import sessinit
from tensorpack.dataflow import * from tensorpack.dataflow import *
from tensorpack.predict import PredictConfig, DatasetPredictor from tensorpack.predict import PredictConfig, SimpleDatasetPredictor
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
...@@ -26,6 +26,8 @@ args = parser.parse_args() ...@@ -26,6 +26,8 @@ args = parser.parse_args()
get_config_func = imp.load_source('config_script', args.config).get_config get_config_func = imp.load_source('config_script', args.config).get_config
# TODO not sure if it this script is still working
with tf.Graph().as_default() as G: with tf.Graph().as_default() as G:
train_config = get_config_func() train_config = get_config_func()
config = PredictConfig( config = PredictConfig(
...@@ -37,9 +39,9 @@ with tf.Graph().as_default() as G: ...@@ -37,9 +39,9 @@ with tf.Graph().as_default() as G:
) )
ds = ImageFromFile(args.images, 3, resize=(227, 227)) ds = ImageFromFile(args.images, 3, resize=(227, 227))
predictor = DatasetPredictor(config, ds, batch=128) ds = BatchData(ds, 128, remainder=True)
predictor = SimpleDatasetPredictor(config, ds)
res = predictor.get_all_result() res = predictor.get_all_result()
res = [k.output for k in res]
if args.output_type == 'label': if args.output_type == 'label':
for r in res: for r in res:
......
...@@ -39,7 +39,7 @@ def dataflow_to_process_queue(ds, size, nr_consumer): ...@@ -39,7 +39,7 @@ def dataflow_to_process_queue(ds, size, nr_consumer):
:param nr_consumer: number of consumer of the queue. :param nr_consumer: number of consumer of the queue.
will add this many of `DIE` sentinel to the end of the queue. will add this many of `DIE` sentinel to the end of the queue.
:returns: (queue, process). The process will take data from `ds` to fill :returns: (queue, process). The process will take data from `ds` to fill
the queue once you start it. the queue once you start it. Each element is (task_id, dp).
""" """
q = multiprocessing.Queue(size) q = multiprocessing.Queue(size)
class EnqueProc(multiprocessing.Process): class EnqueProc(multiprocessing.Process):
......
...@@ -45,7 +45,7 @@ class PredictConfig(object): ...@@ -45,7 +45,7 @@ class PredictConfig(object):
:param output_var_names: a list of names of the output variables to predict, the :param output_var_names: a list of names of the output variables to predict, the
variables can be any computable tensor in the graph. variables can be any computable tensor in the graph.
Predict specific output might not require all input variables. Predict specific output might not require all input variables.
:param nr_gpu: default to 1. Use CUDA_VISIBLE_DEVICES to control which GPU to use sepcifically. :param return_input: whether to produce (input, output) pair or just output. default to False.
""" """
def assert_type(v, tp): def assert_type(v, tp):
assert isinstance(v, tp), v.__class__ assert isinstance(v, tp), v.__class__
...@@ -54,7 +54,7 @@ class PredictConfig(object): ...@@ -54,7 +54,7 @@ class PredictConfig(object):
self.model = kwargs.pop('model') self.model = kwargs.pop('model')
self.input_data_mapping = kwargs.pop('input_data_mapping', None) self.input_data_mapping = kwargs.pop('input_data_mapping', None)
self.output_var_names = kwargs.pop('output_var_names') self.output_var_names = kwargs.pop('output_var_names')
self.nr_gpu = kwargs.pop('nr_gpu', 1) self.return_input = kwargs.pop('return_input', False)
assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys())) assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))
def get_predict_func(config): def get_predict_func(config):
......
...@@ -2,7 +2,8 @@ ...@@ -2,7 +2,8 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: concurrency.py # File: concurrency.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import multiprocessing
import multiprocessing, threading
import tensorflow as tf import tensorflow as tf
from ..utils.concurrency import DIE from ..utils.concurrency import DIE
from ..tfutils.modelutils import describe_model from ..tfutils.modelutils import describe_model
...@@ -11,16 +12,17 @@ from ..tfutils import * ...@@ -11,16 +12,17 @@ from ..tfutils import *
from .common import * from .common import *
__all__ = ['ParallelPredictWorker', 'QueuePredictWorker'] __all__ = ['MultiProcessPredictWorker', 'MultiProcessQueuePredictWorker']
class ParallelPredictWorker(multiprocessing.Process): class MultiProcessPredictWorker(multiprocessing.Process):
""" Base class for predict worker that runs in multiprocess"""
def __init__(self, idx, gpuid, config): def __init__(self, idx, gpuid, config):
""" """
:param idx: index of the worker. the 0th worker will print log. :param idx: index of the worker. the 0th worker will print log.
:param gpuid: absolute id of the GPU to be used. set to -1 to use CPU. :param gpuid: absolute id of the GPU to be used. set to -1 to use CPU.
:param config: a `PredictConfig` :param config: a `PredictConfig`
""" """
super(ParallelPredictWorker, self).__init__() super(MultiProcessPredictWorker, self).__init__()
self.idx = idx self.idx = idx
self.gpuid = gpuid self.gpuid = gpuid
self.config = config self.config = config
...@@ -41,17 +43,17 @@ class ParallelPredictWorker(multiprocessing.Process): ...@@ -41,17 +43,17 @@ class ParallelPredictWorker(multiprocessing.Process):
if self.idx == 0: if self.idx == 0:
describe_model() describe_model()
class QueuePredictWorker(ParallelPredictWorker): class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):
""" A worker process to run predictor on one GPU """ """ A worker process to run predictor on one GPU """
def __init__(self, idx, gpuid, inqueue, outqueue, config): def __init__(self, idx, gpuid, inqueue, outqueue, config):
""" """
:param idx: index of the worker. the 0th worker will print log. :param idx: index of the worker. the 0th worker will print log.
:param gpuid: id of the GPU to be used. set to -1 to use CPU. :param gpuid: id of the GPU to be used. set to -1 to use CPU.
:param inqueue: input queue to get data point :param inqueue: input queue to get data point. elements are (task_id, dp)
:param outqueue: output queue put result :param outqueue: output queue put result. elements are (task_id, output)
:param config: a `PredictConfig` :param config: a `PredictConfig`
""" """
super(QueuePredictWorker, self).__init__(idx, gpuid, config) super(MultiProcessQueuePredictWorker, self).__init__(idx, gpuid, config)
self.inqueue = inqueue self.inqueue = inqueue
self.outqueue = outqueue self.outqueue = outqueue
...@@ -63,5 +65,19 @@ class QueuePredictWorker(ParallelPredictWorker): ...@@ -63,5 +65,19 @@ class QueuePredictWorker(ParallelPredictWorker):
self.outqueue.put((DIE, None)) self.outqueue.put((DIE, None))
return return
else: else:
res = PredictResult(dp, self.func(dp)) self.outqueue.put((tid, self.func(dp)))
self.outqueue.put((tid, res))
class MultiThreadPredictWorker(threading.Thread):
def __init__(self, idx, gpuid, config):
"""
:param idx: index of the worker. the 0th worker will print log.
:param gpuid: absolute id of the GPU to be used. set to -1 to use CPU.
:param config: a `PredictConfig`
"""
super(MultiProcessPredictWorker, self).__init__()
self.idx = idx
self.gpuid = gpuid
self.config = config
def run(self):
pass
...@@ -5,72 +5,117 @@ ...@@ -5,72 +5,117 @@
from six.moves import range from six.moves import range
from tqdm import tqdm from tqdm import tqdm
from abc import ABCMeta, abstractmethod
from ..dataflow import DataFlow, BatchData from ..dataflow import DataFlow, BatchData
from ..dataflow.dftools import dataflow_to_process_queue from ..dataflow.dftools import dataflow_to_process_queue
from ..utils.concurrency import ensure_proc_terminate, OrderedResultGatherProc, DIE from ..utils.concurrency import ensure_proc_terminate, OrderedResultGatherProc, DIE
from .concurrency import * from .concurrency import MultiProcessQueuePredictWorker
from .common import *
__all__ = ['DatasetPredictor'] __all__ = ['DatasetPredictorBase', 'SimpleDatasetPredictor',
'MultiProcessDatasetPredictor']
class DatasetPredictorBase(object):
__metaclass__ = ABCMeta
class DatasetPredictor(object):
"""
Run the predict_config on a given `DataFlow`.
"""
def __init__(self, config, dataset): def __init__(self, config, dataset):
""" """
:param config: a `PredictConfig` instance. :param config: a `PredictConfig` instance.
:param dataset: a `DataFlow` instance. :param dataset: a `DataFlow` instance.
""" """
assert isinstance(dataset, DataFlow) assert isinstance(dataset, DataFlow)
self.ds = dataset assert isinstance(config, PredictConfig)
self.nr_gpu = config.nr_gpu self.config = config
if self.nr_gpu > 1: self.dataset = dataset
self.inqueue, self.inqueue_proc = dataflow_to_process_queue(self.ds, 10, self.nr_gpu)
self.outqueue = multiprocessing.Queue()
try:
gpus = os.environ['CUDA_VISIBLE_DEVICES'].split(',')
except KeyError:
gpus = list(range(self.nr_gpu))
self.workers = [QueuePredictWorker(i, gpus[i], self.inqueue, self.outqueue, config)
for i in range(self.nr_gpu)]
self.result_queue = OrderedResultGatherProc(self.outqueue)
# setup all the procs
self.inqueue_proc.start()
for p in self.workers: p.start()
self.result_queue.start()
ensure_proc_terminate(self.workers)
ensure_proc_terminate([self.result_queue, self.inqueue_proc])
else:
self.func = get_predict_func(config)
@abstractmethod
def get_result(self): def get_result(self):
""" A generator to produce prediction for each data""" """ Generate (inpupt, output) pair of output, for each input in dataset"""
with tqdm(total=self.ds.size()) as pbar: pass
if self.nr_gpu == 1:
for dp in self.ds.get_data():
yield PredictResult(dp, self.func(dp))
pbar.update()
else:
die_cnt = 0
while True:
res = self.result_queue.get()
pbar.update()
if res[0] != DIE:
yield res[1]
else:
die_cnt += 1
if die_cnt == self.nr_gpu:
break
self.inqueue_proc.join()
self.inqueue_proc.terminate()
for p in self.workers:
p.join(); p.terminate()
def get_all_result(self): def get_all_result(self):
""" """
Run over the dataset and return a list of all predictions. Run over the dataset and return a list of all predictions.
""" """
return list(self.get_result()) return list(self.get_result())
class SimpleDatasetPredictor(DatasetPredictorBase):
"""
Run the predict_config on a given `DataFlow`.
"""
def __init__(self, config, dataset):
super(SimpleDatasetPredictor, self).__init__(config, dataset)
self.func = get_predict_func(config)
def get_result(self):
""" A generator to produce prediction for each data"""
with tqdm(total=self.dataset.size()) as pbar:
for dp in self.dataset.get_data():
res = self.func(dp)
if self.config.return_input:
yield (dp, res)
else:
yield res
pbar.update()
class MultiProcessDatasetPredictor(DatasetPredictorBase):
def __init__(self, config, dataset, nr_proc, use_gpu=True):
"""
Run prediction in multiprocesses, on either CPU or GPU. Mix mode not supported.
:param nr_proc: number of processes to use
:param use_gpu: use GPU or CPU.
nr_proc cannot be larger than the total number of GPUs available
in CUDA_VISIBLE_DEVICES or in the system.
"""
assert config.return_input == False, "return_input not supported for MultiProcessDatasetPredictor"
assert nr_proc > 1
super(MultiProcessDatasetPredictor, self).__init__(config, dataset)
self.nr_proc = nr_proc
self.inqueue, self.inqueue_proc = dataflow_to_process_queue(
self.dataset, nr_proc * 2, self.nr_proc)
self.outqueue = multiprocessing.Queue()
if use_gpu:
try:
gpus = os.environ['CUDA_VISIBLE_DEVICES'].split(',')
assert len(gpus) >= self.nr_proc, \
"nr_proc={} while only {} gpus available".format(
self.nr_proc, len(gpus))
except KeyError:
# TODO number of GPUs not checked
gpus = list(range(self.nr_gpu))
else:
gpus = [-1] * self.nr_proc
self.workers = [MultiProcessQueuePredictWorker(
i, gpus[i], self.inqueue, self.outqueue, self.config)
for i in range(self.nr_proc)]
self.result_queue = OrderedResultGatherProc(
self.outqueue, nr_producer=self.nr_proc)
# setup all the procs
self.inqueue_proc.start()
for p in self.workers: p.start()
self.result_queue.start()
ensure_proc_terminate(self.workers + [self.result_queue, self.inqueue_proc])
def get_result(self):
with tqdm(total=self.dataset.size()) as pbar:
die_cnt = 0
while True:
res = self.result_queue.get()
pbar.update()
if res[0] != DIE:
yield res[1]
else:
die_cnt += 1
if die_cnt == self.nr_proc:
break
self.inqueue_proc.join()
self.inqueue_proc.terminate()
self.result_queue.join()
self.result_queue.terminate()
for p in self.workers:
p.join(); p.terminate()
...@@ -123,18 +123,28 @@ class OrderedResultGatherProc(multiprocessing.Process): ...@@ -123,18 +123,28 @@ class OrderedResultGatherProc(multiprocessing.Process):
Gather indexed data from a data queue, and produce results with the Gather indexed data from a data queue, and produce results with the
original index-based order. original index-based order.
""" """
def __init__(self, data_queue, start=0): def __init__(self, data_queue, nr_producer, start=0):
super(self.__class__, self).__init__() """
:param data_queue: a multiprocessing.Queue to produce input dp
:param nr_producer: number of producer processes. Will terminate after receiving this many of DIE sentinel.
:param start: the first task index
"""
super(OrderedResultGatherProc, self).__init__()
self.data_queue = data_queue self.data_queue = data_queue
self.ordered_container = OrderedContainer(start=start) self.ordered_container = OrderedContainer(start=start)
self.result_queue = multiprocessing.Queue() self.result_queue = multiprocessing.Queue()
self.nr_producer = nr_producer
def run(self): def run(self):
nr_end = 0
try: try:
while True: while True:
task_id, data = self.data_queue.get() task_id, data = self.data_queue.get()
if task_id == DIE: if task_id == DIE:
self.result_queue.put((task_id, data)) self.result_queue.put((task_id, data))
nr_end += 1
if nr_end == self.nr_producer:
return
else: else:
self.ordered_container.put(task_id, data) self.ordered_container.put(task_id, data)
while self.ordered_container.has_next(): while self.ordered_container.has_next():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment