Commit bbc17cb1 authored by Yuxin Wu

some naming and code migration

parent 14a28c01
tensorpack.RL package
=====================

Submodules
----------

tensorpack.RL.atari module
--------------------------

.. automodule:: tensorpack.RL.atari
    :members:
    :undoc-members:
    :show-inheritance:

tensorpack.RL.common module
---------------------------

.. automodule:: tensorpack.RL.common
    :members:
    :undoc-members:
    :show-inheritance:

tensorpack.RL.envbase module
----------------------------

.. automodule:: tensorpack.RL.envbase
    :members:
    :undoc-members:
    :show-inheritance:

tensorpack.RL.expreplay module
------------------------------

.. automodule:: tensorpack.RL.expreplay
    :members:
    :undoc-members:
    :show-inheritance:

tensorpack.RL.history module
----------------------------

.. automodule:: tensorpack.RL.history
    :members:
    :undoc-members:
    :show-inheritance:

tensorpack.RL.simulator module
------------------------------

.. automodule:: tensorpack.RL.simulator
    :members:
    :undoc-members:
    :show-inheritance:

Module contents
---------------

.. automodule:: tensorpack.RL
    :members:
    :undoc-members:
    :show-inheritance:

tensorpack.predict package
==========================

Submodules
----------

tensorpack.predict.common module
--------------------------------

.. automodule:: tensorpack.predict.common
    :members:
    :undoc-members:
    :show-inheritance:

tensorpack.predict.concurrency module
-------------------------------------

.. automodule:: tensorpack.predict.concurrency
    :members:
    :undoc-members:
    :show-inheritance:

tensorpack.predict.dataset module
---------------------------------

.. automodule:: tensorpack.predict.dataset
    :members:
    :undoc-members:
    :show-inheritance:

Module contents
---------------

.. automodule:: tensorpack.predict
    :members:
    :undoc-members:
    :show-inheritance:
@@ -196,7 +196,7 @@ if __name__ == '__main__':
     cfg = PredictConfig(
         model=Model(),
         session_init=SaverRestore(args.load),
-        input_var_names=['state']
+        input_var_names=['state'],
         output_var_names=['fct/output:0'])
     if args.task == 'play':
         play_model(cfg)
...
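For context on the renaming this commit performs, here is a minimal sketch (not part of the commit) of how a config like the one above is consumed once `get_predict_func` gives way to `OfflinePredictor`. The `state` input and `fct/output:0` output names come from the hunk above; everything else is assumed:

```python
# Hypothetical usage sketch, assuming tensorpack's API at this commit.
from tensorpack.predict.base import OfflinePredictor

predictor = OfflinePredictor(cfg)   # replaces the old get_predict_func(cfg)
outputs = predictor([state])        # one array per name in input_var_names
q_values = outputs[0]               # corresponds to 'fct/output:0'
```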
@@ -16,7 +16,7 @@ from ..utils import logger
 from ..utils.timer import *
 from ..tfutils import *
-from .common import *
+from .base import OfflinePredictor

 try:
     if six.PY2:
@@ -24,7 +24,7 @@ try:
     else:
         from concurrent.futures import Future
 except ImportError:
-    logger.warn("Cannot import Future in either tornado.concurrent or py3 standard lib. MultiThreadAsyncPredictor won't be available.")
+    logger.warn("Cannot import Future in tornado.concurrent. MultiThreadAsyncPredictor won't be available.")
     __all__ = ['MultiProcessPredictWorker', 'MultiProcessQueuePredictWorker']
 else:
     __all__ = ['MultiProcessPredictWorker', 'MultiProcessQueuePredictWorker',
@@ -32,41 +32,31 @@ else:
 class MultiProcessPredictWorker(multiprocessing.Process):
     """ Base class for predict worker that runs offline in multiprocess"""
-    def __init__(self, idx, gpuid, config):
+    def __init__(self, idx, config):
         """
         :param idx: index of the worker. the 0th worker will print log.
-        :param gpuid: absolute id of the GPU to be used. set to -1 to use CPU.
         :param config: a `PredictConfig`
         """
         super(MultiProcessPredictWorker, self).__init__()
         self.idx = idx
-        self.gpuid = gpuid
         self.config = config

     def _init_runtime(self):
-        if self.gpuid >= 0:
-            logger.info("Worker {} uses GPU {}".format(self.idx, self.gpuid))
-            os.environ['CUDA_VISIBLE_DEVICES'] = str(self.gpuid)
-        else:
-            logger.info("Worker {} uses CPU".format(self.idx))
-            os.environ['CUDA_VISIBLE_DEVICES'] = ''
-        G = tf.Graph()     # build a graph for each process, because they don't need to share anything
-        with G.as_default():
-            if self.idx != 0:
-                from tensorpack.models._common import disable_layer_logging
-                disable_layer_logging()
-            self.func = get_predict_func(self.config)
-            if self.idx == 0:
-                describe_model()
+        if self.idx != 0:
+            from tensorpack.models._common import disable_layer_logging
+            disable_layer_logging()
+        self.func = OfflinePredictor(self.config)
+        if self.idx == 0:
+            describe_model()

 class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):
     """ An offline predictor worker that takes input and produces output by queue"""
-    def __init__(self, idx, gpuid, inqueue, outqueue, config):
+    def __init__(self, idx, inqueue, outqueue, config):
         """
         :param inqueue: input queue to get data point. elements are (task_id, dp)
         :param outqueue: output queue put result. elements are (task_id, output)
         """
-        super(MultiProcessQueuePredictWorker, self).__init__(idx, gpuid, config)
+        super(MultiProcessQueuePredictWorker, self).__init__(idx, config)
         self.inqueue = inqueue
         self.outqueue = outqueue
         assert isinstance(self.inqueue, multiprocessing.Queue)
...
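A hedged sketch of the worker API after this change: `gpuid` is gone from both constructors, and device assignment now happens in the parent process (see the `change_gpu` loop in the dataset.py hunk below). The queues, config, and worker count here are illustrative:

```python
# Illustrative only: construct queue-based predict workers with the new
# (idx, inqueue, outqueue, config) signature -- no gpuid argument.
import multiprocessing

inqueue = multiprocessing.Queue()
outqueue = multiprocessing.Queue()
workers = [MultiProcessQueuePredictWorker(i, inqueue, outqueue, cfg)
           for i in range(4)]
```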
@@ -3,7 +3,7 @@
 # File: dataset.py
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>

-from six.moves import range
+from six.moves import range, zip
 from tqdm import tqdm
 from abc import ABCMeta, abstractmethod
 import multiprocessing
@@ -14,7 +14,8 @@ from ..dataflow.dftools import dataflow_to_process_queue
 from ..utils.concurrency import ensure_proc_terminate, OrderedResultGatherProc, DIE
 from .concurrency import MultiProcessQueuePredictWorker
-from .common import *
+from .common import PredictConfig
+from .base import OfflinePredictor

 __all__ = ['DatasetPredictorBase', 'SimpleDatasetPredictor',
            'MultiProcessDatasetPredictor']
@@ -34,7 +35,7 @@ class DatasetPredictorBase(object):
     @abstractmethod
     def get_result(self):
-        """ Generate (inpupt, output) pair of output, for each input in dataset"""
+        """ A generator function, produce output for each input in dataset"""
         pass

     def get_all_result(self):
@@ -49,17 +50,14 @@ class SimpleDatasetPredictor(DatasetPredictorBase):
     """
     def __init__(self, config, dataset):
         super(SimpleDatasetPredictor, self).__init__(config, dataset)
-        self.func = get_predict_func(config)
+        self.predictor = OfflinePredictor(config)

     def get_result(self):
         """ A generator to produce prediction for each data"""
         with tqdm(total=self.dataset.size()) as pbar:
             for dp in self.dataset.get_data():
-                res = self.func(dp)
-                if self.config.return_input:
-                    yield (dp, res)
-                else:
-                    yield res
+                res = self.predictor(dp)
+                yield res
                 pbar.update()

 class MultiProcessDatasetPredictor(DatasetPredictorBase):
@@ -69,11 +67,11 @@ class MultiProcessDatasetPredictor(DatasetPredictorBase):
         :param nr_proc: number of processes to use
         :param use_gpu: use GPU or CPU.
-            If GPU, then nr_proc cannot be larger than the total number of GPUs available
-            in CUDA_VISIBLE_DEVICES or in the system.
+            If GPU, then nr_proc cannot be more than what's in CUDA_VISIBLE_DEVICES
         """
-        assert config.return_input == False, "return_input not supported for MultiProcessDatasetPredictor"
-        assert nr_proc > 1
+        if config.return_input:
+            logger.warn("Using the option `return_input` in MultiProcessDatasetPredictor might be slow")
+        assert nr_proc > 1, nr_proc
         super(MultiProcessDatasetPredictor, self).__init__(config, dataset)
         self.nr_proc = nr_proc
@@ -91,7 +89,7 @@ class MultiProcessDatasetPredictor(DatasetPredictorBase):
             # TODO number of GPUs not checked
             gpus = list(range(self.nr_gpu))
         else:
-            gpus = [-1] * self.nr_proc
+            gpus = [''] * self.nr_proc
         self.workers = [MultiProcessQueuePredictWorker(
             i, gpus[i], self.inqueue, self.outqueue, self.config)
             for i in range(self.nr_proc)]
@@ -100,7 +98,13 @@ class MultiProcessDatasetPredictor(DatasetPredictorBase):
         # setup all the procs
         self.inqueue_proc.start()
-        for p in self.workers: p.start()
+        for p, gpuid in zip(self.workers, gpus):
+            if gpuid == '':
+                logger.info("Worker {} uses CPU".format(p.idx))
+            else:
+                logger.info("Worker {} uses GPU {}".format(p.idx, gpuid))
+            with change_gpu(gpuid):
+                p.start()
         self.result_queue.start()
         ensure_proc_terminate(self.workers + [self.result_queue, self.inqueue_proc])
...
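As a usage sketch against the API after this commit (the dataflow and handler below are placeholders, not names from the repo):

```python
# Iterate over per-datapoint outputs; get_result() is a generator.
pred = SimpleDatasetPredictor(cfg, dataflow)
for outputs in pred.get_result():
    handle(outputs)               # placeholder for downstream processing
# Or, instead of iterating, collect everything at once:
results = pred.get_all_result()
```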
@@ -211,7 +211,7 @@ class QueueInputTrainer(Trainer):
     def get_predict_func(self, input_names, output_names, tower=0):
         """
         :param tower: return the kth predict_func
-        :returns: a predictor function
+        :returns: an `OnlinePredictor`
         """
         tower = self.predict_tower[tower % len(self.predict_tower)]
         raw_input_vars = get_vars_by_names(input_names)
@@ -220,7 +220,7 @@ class QueueInputTrainer(Trainer):
         return OnlinePredictor(self.sess, raw_input_vars, output_vars)

     def get_predict_funcs(self, input_names, output_names, n):
-        """ return n predicts functions evenly on each predict_tower"""
+        """ return n predictors evenly on each predict_tower"""
         return [self.get_predict_func(input_names, output_names, k)
                 for k in range(n)]
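Illustrative call of the renamed trainer methods (trainer construction omitted; the input and output names are reused from the first hunk, and `state` is a placeholder):

```python
# Hypothetical: two predictors spread evenly over the predict towers.
predictors = trainer.get_predict_funcs(['state'], ['fct/output:0'], 2)
outputs = predictors[0]([state])   # call like any OnlinePredictor
```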