Commit 57fb68fa authored by Yuxin Wu

unordered datasetpredictor & more tqdm

parent a2f4f439
@@ -4,14 +4,13 @@
import tensorflow as tf
import numpy as np
from tqdm import tqdm
from abc import ABCMeta, abstractmethod
from collections import namedtuple
import six
from six.moves import zip, map
from ..dataflow import DataFlow
from ..utils import get_tqdm_kwargs, logger, execute_only_once
from ..utils import get_tqdm, logger, execute_only_once
from ..utils.stat import RatioCounter, BinaryStatistics
from ..tfutils import get_op_tensor_name, get_op_var_name
from .base import Callback
@@ -124,7 +123,7 @@ class InferenceRunner(Callback):
sess = tf.get_default_session()
self.ds.reset_state()
with tqdm(total=self.ds.size(), **get_tqdm_kwargs()) as pbar:
with get_tqdm(total=self.ds.size()) as pbar:
for dp in self.ds.get_data():
outputs = self.pred_func(dp)
for inf, tensormap in zip(self.infs, self.inf_to_tensors):
......
@@ -8,7 +8,7 @@ import numpy as np
from collections import deque, defaultdict
from six.moves import range, map
from .base import DataFlow, ProxyDataFlow, RNGDataFlow
from ..utils import *
from ..utils import logger, get_tqdm
__all__ = ['BatchData', 'FixedSizeData', 'MapData',
'RepeatedData', 'MapDataComponent', 'RandomChooseData',
@@ -21,8 +21,7 @@ class TestDataSpeed(ProxyDataFlow):
self.test_size = size
def get_data(self):
from tqdm import tqdm
with tqdm(range(self.test_size), **get_tqdm_kwargs()) as pbar:
with get_tqdm(total=self.test_size) as pbar:
for dp in self.ds.get_data():
pbar.update()
for dp in self.ds.get_data():
......
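For reference, the pattern TestDataSpeed implements is simply "drive the underlying dataflow once while counting datapoints on a progress bar"; the bar's elapsed time and rate are the measurement. A generic, self-contained sketch of that pattern (plain Python plus tqdm, not the tensorpack class; benchmark_generator is a hypothetical name):

from tqdm import tqdm

def benchmark_generator(gen, size):
    """Consume a datapoint generator once; the tqdm bar reports elapsed time and rate."""
    with tqdm(total=size) as pbar:
        for _ in gen:
            pbar.update()

benchmark_generator(iter(range(1000)), 1000)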
@@ -3,10 +3,9 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import numpy as np
from tqdm import tqdm
from six.moves import range
from ..utils import logger, get_rng, get_tqdm_kwargs
from ..utils import logger, get_rng, get_tqdm
from ..utils.timer import timed_operation
from ..utils.loadcaffe import get_caffe_pb
from .base import RNGDataFlow
@@ -82,7 +81,7 @@ class LMDBData(RNGDataFlow):
if not self.keys:
self.keys = []
with timed_operation("Loading LMDB keys ...", log_start=True), \
tqdm(get_tqdm_kwargs(total=self._size)) as pbar:
get_tqdm(total=self._size) as pbar:
for k in self._txn.cursor():
if k != '__keys__':
self.keys.append(k)
......
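The same swap is applied to LMDBData's key-caching loop above. For reference, a self-contained sketch of walking an LMDB database's keys under a progress bar with the plain lmdb and tqdm packages (the path is a placeholder and this is not tensorpack's LMDBData; it only mirrors the __keys__-skipping loop in the hunk):

import lmdb
from tqdm import tqdm

# Open the database read-only and cache every key, skipping the special
# '__keys__' entry, as the loop in the hunk above does.
env = lmdb.open('/path/to/data.lmdb', readonly=True, lock=False)
with env.begin() as txn:
    keys = []
    with tqdm(total=txn.stat()['entries']) as pbar:
        for k, _ in txn.cursor():
            if k != b'__keys__':
                keys.append(k)
            pbar.update()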
@@ -4,7 +4,6 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from six.moves import range, zip
from tqdm import tqdm
from abc import ABCMeta, abstractmethod
import multiprocessing
import os
@@ -12,7 +11,7 @@ import os
from ..dataflow import DataFlow, BatchData
from ..dataflow.dftools import dataflow_to_process_queue
from ..utils.concurrency import ensure_proc_terminate, OrderedResultGatherProc, DIE
from ..utils import logger
from ..utils import logger, get_tqdm
from ..utils.gpu import change_gpu
from .concurrency import MultiProcessQueuePredictWorker
@@ -60,7 +59,7 @@ class SimpleDatasetPredictor(DatasetPredictorBase):
sz = self.dataset.size()
except NotImplementedError:
sz = 0
with tqdm(total=sz, disable=(sz==0)) as pbar:
with get_tqdm(total=sz, disable=(sz==0)) as pbar:
for dp in self.dataset.get_data():
res = self.predictor(dp)
yield res
@@ -68,13 +67,15 @@ class SimpleDatasetPredictor(DatasetPredictorBase):
# TODO allow unordered
class MultiProcessDatasetPredictor(DatasetPredictorBase):
def __init__(self, config, dataset, nr_proc, use_gpu=True):
def __init__(self, config, dataset, nr_proc, use_gpu=True, ordered=True):
"""
Run prediction in multiple processes, on either CPU or GPU. Mixed mode is not supported.
:param nr_proc: number of processes to use
:param use_gpu: use GPU or CPU.
If GPU, then nr_proc cannot be more than what's in CUDA_VISIBLE_DEVICES
:param ordered: produce results in the original order of the
dataflow. A bit slower.
"""
if config.return_input:
logger.warn("Using the option `return_input` in MultiProcessDatasetPredictor might be slow")
@@ -82,10 +83,11 @@ class MultiProcessDatasetPredictor(DatasetPredictorBase):
super(MultiProcessDatasetPredictor, self).__init__(config, dataset)
self.nr_proc = nr_proc
self.ordered = ordered
self.inqueue, self.inqueue_proc = dataflow_to_process_queue(
self.dataset, nr_proc * 2, self.nr_proc)
self.outqueue = multiprocessing.Queue()
self.dataset, nr_proc * 2, self.nr_proc) # put (idx, dp) to inqueue
if use_gpu:
try:
gpus = os.environ['CUDA_VISIBLE_DEVICES'].split(',')
@@ -97,13 +99,13 @@ class MultiProcessDatasetPredictor(DatasetPredictorBase):
gpus = list(range(self.nr_proc))
else:
gpus = ['-1'] * self.nr_proc
# worker produces (idx, result) to outqueue
self.outqueue = multiprocessing.Queue()
self.workers = [MultiProcessQueuePredictWorker(
i, self.inqueue, self.outqueue, self.config)
for i in range(self.nr_proc)]
self.result_queue = OrderedResultGatherProc(
self.outqueue, nr_producer=self.nr_proc)
# setup all the procs
# start inqueue and workers
self.inqueue_proc.start()
for p, gpuid in zip(self.workers, gpus):
if gpuid == '-1':
@@ -112,15 +114,22 @@ class MultiProcessDatasetPredictor(DatasetPredictorBase):
logger.info("Worker {} uses GPU {}".format(p.idx, gpuid))
with change_gpu(gpuid):
p.start()
self.result_queue.start()
ensure_proc_terminate(self.workers + [self.result_queue, self.inqueue_proc])
if ordered:
self.result_queue = OrderedResultGatherProc(
self.outqueue, nr_producer=self.nr_proc)
self.result_queue.start()
ensure_proc_terminate(self.result_queue)
else:
self.result_queue = self.outqueue
ensure_proc_terminate(self.workers + [self.inqueue_proc])
def get_result(self):
try:
sz = self.dataset.size()
except NotImplementedError:
sz = 0
with tqdm(total=sz, disable=(sz==0)) as pbar:
with get_tqdm(total=sz, disable=(sz==0)) as pbar:
die_cnt = 0
while True:
res = self.result_queue.get()
@@ -133,7 +142,8 @@ class MultiProcessDatasetPredictor(DatasetPredictorBase):
break
self.inqueue_proc.join()
self.inqueue_proc.terminate()
self.result_queue.join()
self.result_queue.terminate()
if self.ordered: # if ordered, then result_queue is a Process
self.result_queue.join()
self.result_queue.terminate()
for p in self.workers:
p.join(); p.terminate()
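Workers put (idx, result) pairs on outqueue; with ordered=True an OrderedResultGatherProc re-sequences them by idx before get_result() sees them, while ordered=False hands the raw queue to get_result() directly. A self-contained sketch of the re-sequencing idea (illustrative only, not tensorpack's implementation; gather_in_order is a hypothetical name):

import heapq
from six.moves import queue

def gather_in_order(outqueue, nr_items):
    """Yield the result part of (idx, result) pairs in idx order,
    buffering pairs that arrive early."""
    heap, next_idx = [], 0
    for _ in range(nr_items):
        heapq.heappush(heap, outqueue.get())
        while heap and heap[0][0] == next_idx:
            yield heapq.heappop(heap)[1]
            next_idx += 1

q = queue.Queue()
for pair in [(2, 'c'), (0, 'a'), (1, 'b')]:   # results arriving out of order
    q.put(pair)
print(list(gather_in_order(q, 3)))            # -> ['a', 'b', 'c']

Skipping that buffering is why ordered=False can be a bit faster: one slow datapoint no longer holds back results that finished after it. With the signature added above, a hypothetical call would be MultiProcessDatasetPredictor(config, dataset, nr_proc=4, use_gpu=True, ordered=False), iterating get_result() as usual.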
@@ -6,6 +6,7 @@ import os, sys
from contextlib import contextmanager
import inspect
from datetime import datetime
from tqdm import tqdm
import time
import numpy as np
@@ -13,6 +14,7 @@ __all__ = ['change_env',
'get_rng',
'get_dataset_path',
'get_tqdm_kwargs',
'get_tqdm',
'execute_only_once'
]
@@ -73,3 +75,6 @@ def get_tqdm_kwargs(**kwargs):
default['mininterval'] = 60
default.update(kwargs)
return default
def get_tqdm(**kwargs):
return tqdm(**get_tqdm_kwargs(**kwargs))
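The helper above is what every call site in this commit switches to: instead of repeating tqdm(..., **get_tqdm_kwargs()), callers write get_tqdm(...) and inherit the shared defaults (such as the mininterval = 60 visible in the hunk, presumably applied when output is not a live terminal). A minimal stand-alone sketch of the same wrapper pattern; the demo_ names and the exact defaults are illustrative, not tensorpack's:

import sys
from tqdm import tqdm

def demo_tqdm_kwargs(**kwargs):
    """Hypothetical stand-in for get_tqdm_kwargs: shared defaults, caller kwargs win."""
    default = dict(ascii=True, dynamic_ncols=True)   # illustrative defaults
    if not sys.stdout.isatty():
        default['mininterval'] = 60                  # assumption: rate-limit bars on redirected output
    default.update(kwargs)
    return default

def demo_get_tqdm(**kwargs):
    """Same shape as the new get_tqdm: build a bar from the shared defaults."""
    return tqdm(**demo_tqdm_kwargs(**kwargs))

with demo_get_tqdm(total=100) as pbar:
    for _ in range(100):
        pbar.update()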