Commit 57fb68fa authored by Yuxin Wu's avatar Yuxin Wu

unordered datasetpredictor & more tqdm

parent a2f4f439
...@@ -4,14 +4,13 @@ ...@@ -4,14 +4,13 @@
import tensorflow as tf import tensorflow as tf
import numpy as np import numpy as np
from tqdm import tqdm
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from collections import namedtuple from collections import namedtuple
import six import six
from six.moves import zip, map from six.moves import zip, map
from ..dataflow import DataFlow from ..dataflow import DataFlow
from ..utils import get_tqdm_kwargs, logger, execute_only_once from ..utils import get_tqdm, logger, execute_only_once
from ..utils.stat import RatioCounter, BinaryStatistics from ..utils.stat import RatioCounter, BinaryStatistics
from ..tfutils import get_op_tensor_name, get_op_var_name from ..tfutils import get_op_tensor_name, get_op_var_name
from .base import Callback from .base import Callback
...@@ -124,7 +123,7 @@ class InferenceRunner(Callback): ...@@ -124,7 +123,7 @@ class InferenceRunner(Callback):
sess = tf.get_default_session() sess = tf.get_default_session()
self.ds.reset_state() self.ds.reset_state()
with tqdm(total=self.ds.size(), **get_tqdm_kwargs()) as pbar: with get_tqdm(total=self.ds.size()) as pbar:
for dp in self.ds.get_data(): for dp in self.ds.get_data():
outputs = self.pred_func(dp) outputs = self.pred_func(dp)
for inf, tensormap in zip(self.infs, self.inf_to_tensors): for inf, tensormap in zip(self.infs, self.inf_to_tensors):
......
...@@ -8,7 +8,7 @@ import numpy as np ...@@ -8,7 +8,7 @@ import numpy as np
from collections import deque, defaultdict from collections import deque, defaultdict
from six.moves import range, map from six.moves import range, map
from .base import DataFlow, ProxyDataFlow, RNGDataFlow from .base import DataFlow, ProxyDataFlow, RNGDataFlow
from ..utils import * from ..utils import logger, get_tqdm
__all__ = ['BatchData', 'FixedSizeData', 'MapData', __all__ = ['BatchData', 'FixedSizeData', 'MapData',
'RepeatedData', 'MapDataComponent', 'RandomChooseData', 'RepeatedData', 'MapDataComponent', 'RandomChooseData',
...@@ -21,8 +21,7 @@ class TestDataSpeed(ProxyDataFlow): ...@@ -21,8 +21,7 @@ class TestDataSpeed(ProxyDataFlow):
self.test_size = size self.test_size = size
def get_data(self): def get_data(self):
from tqdm import tqdm with get_tqdm(total=range(self.test_size)) as pbar:
with tqdm(range(self.test_size), **get_tqdm_kwargs()) as pbar:
for dp in self.ds.get_data(): for dp in self.ds.get_data():
pbar.update() pbar.update()
for dp in self.ds.get_data(): for dp in self.ds.get_data():
......
...@@ -3,10 +3,9 @@ ...@@ -3,10 +3,9 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import numpy as np import numpy as np
from tqdm import tqdm
from six.moves import range from six.moves import range
from ..utils import logger, get_rng, get_tqdm_kwargs from ..utils import logger, get_rng, get_tqdm
from ..utils.timer import timed_operation from ..utils.timer import timed_operation
from ..utils.loadcaffe import get_caffe_pb from ..utils.loadcaffe import get_caffe_pb
from .base import RNGDataFlow from .base import RNGDataFlow
...@@ -82,7 +81,7 @@ class LMDBData(RNGDataFlow): ...@@ -82,7 +81,7 @@ class LMDBData(RNGDataFlow):
if not self.keys: if not self.keys:
self.keys = [] self.keys = []
with timed_operation("Loading LMDB keys ...", log_start=True), \ with timed_operation("Loading LMDB keys ...", log_start=True), \
tqdm(get_tqdm_kwargs(total=self._size)) as pbar: get_tqdm(total=self._size) as pbar:
for k in self._txn.cursor(): for k in self._txn.cursor():
if k != '__keys__': if k != '__keys__':
self.keys.append(k) self.keys.append(k)
......
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from six.moves import range, zip from six.moves import range, zip
from tqdm import tqdm
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
import multiprocessing import multiprocessing
import os import os
...@@ -12,7 +11,7 @@ import os ...@@ -12,7 +11,7 @@ import os
from ..dataflow import DataFlow, BatchData from ..dataflow import DataFlow, BatchData
from ..dataflow.dftools import dataflow_to_process_queue from ..dataflow.dftools import dataflow_to_process_queue
from ..utils.concurrency import ensure_proc_terminate, OrderedResultGatherProc, DIE from ..utils.concurrency import ensure_proc_terminate, OrderedResultGatherProc, DIE
from ..utils import logger from ..utils import logger, get_tqdm
from ..utils.gpu import change_gpu from ..utils.gpu import change_gpu
from .concurrency import MultiProcessQueuePredictWorker from .concurrency import MultiProcessQueuePredictWorker
...@@ -60,7 +59,7 @@ class SimpleDatasetPredictor(DatasetPredictorBase): ...@@ -60,7 +59,7 @@ class SimpleDatasetPredictor(DatasetPredictorBase):
sz = self.dataset.size() sz = self.dataset.size()
except NotImplementedError: except NotImplementedError:
sz = 0 sz = 0
with tqdm(total=sz, disable=(sz==0)) as pbar: with get_tqdm(total=sz, disable=(sz==0)) as pbar:
for dp in self.dataset.get_data(): for dp in self.dataset.get_data():
res = self.predictor(dp) res = self.predictor(dp)
yield res yield res
...@@ -68,13 +67,15 @@ class SimpleDatasetPredictor(DatasetPredictorBase): ...@@ -68,13 +67,15 @@ class SimpleDatasetPredictor(DatasetPredictorBase):
# TODO allow unordered # TODO allow unordered
class MultiProcessDatasetPredictor(DatasetPredictorBase): class MultiProcessDatasetPredictor(DatasetPredictorBase):
def __init__(self, config, dataset, nr_proc, use_gpu=True): def __init__(self, config, dataset, nr_proc, use_gpu=True, ordered=True):
""" """
Run prediction in multiprocesses, on either CPU or GPU. Mix mode not supported. Run prediction in multiprocesses, on either CPU or GPU. Mix mode not supported.
:param nr_proc: number of processes to use :param nr_proc: number of processes to use
:param use_gpu: use GPU or CPU. :param use_gpu: use GPU or CPU.
If GPU, then nr_proc cannot be more than what's in CUDA_VISIBLE_DEVICES If GPU, then nr_proc cannot be more than what's in CUDA_VISIBLE_DEVICES
:param ordered: produce results with the original order of the
dataflow. a bit slower.
""" """
if config.return_input: if config.return_input:
logger.warn("Using the option `return_input` in MultiProcessDatasetPredictor might be slow") logger.warn("Using the option `return_input` in MultiProcessDatasetPredictor might be slow")
...@@ -82,10 +83,11 @@ class MultiProcessDatasetPredictor(DatasetPredictorBase): ...@@ -82,10 +83,11 @@ class MultiProcessDatasetPredictor(DatasetPredictorBase):
super(MultiProcessDatasetPredictor, self).__init__(config, dataset) super(MultiProcessDatasetPredictor, self).__init__(config, dataset)
self.nr_proc = nr_proc self.nr_proc = nr_proc
self.ordered = ordered
self.inqueue, self.inqueue_proc = dataflow_to_process_queue( self.inqueue, self.inqueue_proc = dataflow_to_process_queue(
self.dataset, nr_proc * 2, self.nr_proc) self.dataset, nr_proc * 2, self.nr_proc) # put (idx, dp) to inqueue
self.outqueue = multiprocessing.Queue()
if use_gpu: if use_gpu:
try: try:
gpus = os.environ['CUDA_VISIBLE_DEVICES'].split(',') gpus = os.environ['CUDA_VISIBLE_DEVICES'].split(',')
...@@ -97,13 +99,13 @@ class MultiProcessDatasetPredictor(DatasetPredictorBase): ...@@ -97,13 +99,13 @@ class MultiProcessDatasetPredictor(DatasetPredictorBase):
gpus = list(range(self.nr_proc)) gpus = list(range(self.nr_proc))
else: else:
gpus = ['-1'] * self.nr_proc gpus = ['-1'] * self.nr_proc
# worker produces (idx, result) to outqueue
self.outqueue = multiprocessing.Queue()
self.workers = [MultiProcessQueuePredictWorker( self.workers = [MultiProcessQueuePredictWorker(
i, self.inqueue, self.outqueue, self.config) i, self.inqueue, self.outqueue, self.config)
for i in range(self.nr_proc)] for i in range(self.nr_proc)]
self.result_queue = OrderedResultGatherProc(
self.outqueue, nr_producer=self.nr_proc)
# setup all the procs # start inqueue and workers
self.inqueue_proc.start() self.inqueue_proc.start()
for p, gpuid in zip(self.workers, gpus): for p, gpuid in zip(self.workers, gpus):
if gpuid == '-1': if gpuid == '-1':
...@@ -112,15 +114,22 @@ class MultiProcessDatasetPredictor(DatasetPredictorBase): ...@@ -112,15 +114,22 @@ class MultiProcessDatasetPredictor(DatasetPredictorBase):
logger.info("Worker {} uses GPU {}".format(p.idx, gpuid)) logger.info("Worker {} uses GPU {}".format(p.idx, gpuid))
with change_gpu(gpuid): with change_gpu(gpuid):
p.start() p.start()
self.result_queue.start()
ensure_proc_terminate(self.workers + [self.result_queue, self.inqueue_proc]) if ordered:
self.result_queue = OrderedResultGatherProc(
self.outqueue, nr_producer=self.nr_proc)
self.result_queue.start()
ensure_proc_terminate(self.result_queue)
else:
self.result_queue = self.outqueue
ensure_proc_terminate(self.workers + [self.inqueue_proc])
def get_result(self): def get_result(self):
try: try:
sz = self.dataset.size() sz = self.dataset.size()
except NotImplementedError: except NotImplementedError:
sz = 0 sz = 0
with tqdm(total=sz, disable=(sz==0)) as pbar: with get_tqdm(total=sz, disable=(sz==0)) as pbar:
die_cnt = 0 die_cnt = 0
while True: while True:
res = self.result_queue.get() res = self.result_queue.get()
...@@ -133,7 +142,8 @@ class MultiProcessDatasetPredictor(DatasetPredictorBase): ...@@ -133,7 +142,8 @@ class MultiProcessDatasetPredictor(DatasetPredictorBase):
break break
self.inqueue_proc.join() self.inqueue_proc.join()
self.inqueue_proc.terminate() self.inqueue_proc.terminate()
self.result_queue.join() if self.ordered: # if ordered, than result_queue is a Process
self.result_queue.terminate() self.result_queue.join()
self.result_queue.terminate()
for p in self.workers: for p in self.workers:
p.join(); p.terminate() p.join(); p.terminate()
...@@ -6,6 +6,7 @@ import os, sys ...@@ -6,6 +6,7 @@ import os, sys
from contextlib import contextmanager from contextlib import contextmanager
import inspect import inspect
from datetime import datetime from datetime import datetime
from tqdm import tqdm
import time import time
import numpy as np import numpy as np
...@@ -13,6 +14,7 @@ __all__ = ['change_env', ...@@ -13,6 +14,7 @@ __all__ = ['change_env',
'get_rng', 'get_rng',
'get_dataset_path', 'get_dataset_path',
'get_tqdm_kwargs', 'get_tqdm_kwargs',
'get_tqdm',
'execute_only_once' 'execute_only_once'
] ]
...@@ -73,3 +75,6 @@ def get_tqdm_kwargs(**kwargs): ...@@ -73,3 +75,6 @@ def get_tqdm_kwargs(**kwargs):
default['mininterval'] = 60 default['mininterval'] = 60
default.update(kwargs) default.update(kwargs)
return default return default
def get_tqdm(**kwargs):
return tqdm(**get_tqdm_kwargs(**kwargs))
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment