Commit bbc17cb1 authored by Yuxin Wu

some naming and code migration

parent 14a28c01
tensorpack.RL package
=====================

Submodules
----------

tensorpack.RL.atari module
--------------------------

.. automodule:: tensorpack.RL.atari
    :members:
    :undoc-members:
    :show-inheritance:

tensorpack.RL.common module
---------------------------

.. automodule:: tensorpack.RL.common
    :members:
    :undoc-members:
    :show-inheritance:

tensorpack.RL.envbase module
----------------------------

.. automodule:: tensorpack.RL.envbase
    :members:
    :undoc-members:
    :show-inheritance:

tensorpack.RL.expreplay module
------------------------------

.. automodule:: tensorpack.RL.expreplay
    :members:
    :undoc-members:
    :show-inheritance:

tensorpack.RL.history module
----------------------------

.. automodule:: tensorpack.RL.history
    :members:
    :undoc-members:
    :show-inheritance:

tensorpack.RL.simulator module
------------------------------

.. automodule:: tensorpack.RL.simulator
    :members:
    :undoc-members:
    :show-inheritance:

Module contents
---------------

.. automodule:: tensorpack.RL
    :members:
    :undoc-members:
    :show-inheritance:

tensorpack.predict package
==========================

Submodules
----------

tensorpack.predict.common module
--------------------------------

.. automodule:: tensorpack.predict.common
    :members:
    :undoc-members:
    :show-inheritance:

tensorpack.predict.concurrency module
-------------------------------------

.. automodule:: tensorpack.predict.concurrency
    :members:
    :undoc-members:
    :show-inheritance:

tensorpack.predict.dataset module
---------------------------------

.. automodule:: tensorpack.predict.dataset
    :members:
    :undoc-members:
    :show-inheritance:

Module contents
---------------

.. automodule:: tensorpack.predict
    :members:
    :undoc-members:
    :show-inheritance:
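Each automodule directive above corresponds to an ordinary importable submodule; a hedged illustration (not part of the commit itself):

# Illustrative only: the documented pages map one-to-one onto submodules.
import tensorpack.RL.expreplay        # e.g. the experience-replay module
import tensorpack.predict.dataset     # the dataset-level predictors diffed below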
@@ -196,7 +196,7 @@ if __name__ == '__main__':
     cfg = PredictConfig(
         model=Model(),
         session_init=SaverRestore(args.load),
-        input_var_names=['state']
+        input_var_names=['state'],
         output_var_names=['fct/output:0'])
     if args.task == 'play':
         play_model(cfg)
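The one-line fix above adds the comma separating `input_var_names` from `output_var_names`. For orientation, a hedged sketch of how such a config is consumed after this commit; the checkpoint path and `state_batch` are placeholders, and `Model` comes from the example script:

# Sketch: turning the fixed PredictConfig into a standalone predictor.
cfg = PredictConfig(
    model=Model(),
    session_init=SaverRestore('train_log/checkpoint'),  # placeholder path
    input_var_names=['state'],            # graph inputs to feed
    output_var_names=['fct/output:0'])    # tensors to fetch
predictor = OfflinePredictor(cfg)         # builds its own graph and session
outputs = predictor([state_batch])        # one array per requested output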
@@ -16,7 +16,7 @@ from ..utils import logger
 from ..utils.timer import *
 from ..tfutils import *
-from .common import *
+from .base import OfflinePredictor

 try:
     if six.PY2:
         from tornado.concurrent import Future
@@ -24,7 +24,7 @@ try:
     else:
         from concurrent.futures import Future
 except ImportError:
-    logger.warn("Cannot import Future in either tornado.concurrent or py3 standard lib. MultiThreadAsyncPredictor won't be available.")
+    logger.warn("Cannot import Future in tornado.concurrent. MultiThreadAsyncPredictor won't be available.")
     __all__ = ['MultiProcessPredictWorker', 'MultiProcessQueuePredictWorker']
 else:
     __all__ = ['MultiProcessPredictWorker', 'MultiProcessQueuePredictWorker',
@@ -32,41 +32,31 @@ else:
 class MultiProcessPredictWorker(multiprocessing.Process):
     """ Base class for predict worker that runs offline in multiprocess"""
-    def __init__(self, idx, gpuid, config):
+    def __init__(self, idx, config):
         """
         :param idx: index of the worker. the 0th worker will print log.
-        :param gpuid: absolute id of the GPU to be used. set to -1 to use CPU.
         :param config: a `PredictConfig`
         """
         super(MultiProcessPredictWorker, self).__init__()
         self.idx = idx
-        self.gpuid = gpuid
         self.config = config

     def _init_runtime(self):
-        if self.gpuid >= 0:
-            logger.info("Worker {} uses GPU {}".format(self.idx, self.gpuid))
-            os.environ['CUDA_VISIBLE_DEVICES'] = str(self.gpuid)
-        else:
-            logger.info("Worker {} uses CPU".format(self.idx))
-            os.environ['CUDA_VISIBLE_DEVICES'] = ''
-        G = tf.Graph()  # build a graph for each process, because they don't need to share anything
-        with G.as_default():
-            if self.idx != 0:
-                from tensorpack.models._common import disable_layer_logging
-                disable_layer_logging()
-            self.func = get_predict_func(self.config)
-            if self.idx == 0:
-                describe_model()
+        if self.idx != 0:
+            from tensorpack.models._common import disable_layer_logging
+            disable_layer_logging()
+        self.func = OfflinePredictor(self.config)
+        if self.idx == 0:
+            describe_model()

 class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):
     """ An offline predictor worker that takes input and produces output by queue"""
-    def __init__(self, idx, gpuid, inqueue, outqueue, config):
+    def __init__(self, idx, inqueue, outqueue, config):
         """
         :param inqueue: input queue to get data point. elements are (task_id, dp)
         :param outqueue: output queue put result. elements are (task_id, output)
         """
-        super(MultiProcessQueuePredictWorker, self).__init__(idx, gpuid, config)
+        super(MultiProcessQueuePredictWorker, self).__init__(idx, config)
         self.inqueue = inqueue
         self.outqueue = outqueue
         assert isinstance(self.inqueue, multiprocessing.Queue)
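To make the documented queue protocol concrete, here is a hedged sketch of driving one worker by hand: `cfg` and `state_batch` are placeholders, and using `DIE` (imported in dataset.py below) as a shutdown sentinel is an assumption:

import multiprocessing

# Hand-driving one worker: (task_id, dp) goes in, (task_id, output) comes out.
inqueue, outqueue = multiprocessing.Queue(), multiprocessing.Queue()
worker = MultiProcessQueuePredictWorker(0, inqueue, outqueue, cfg)
worker.start()
inqueue.put((0, [state_batch]))      # task 0, one datapoint
task_id, output = outqueue.get()     # results are tagged with their task_id
inqueue.put((DIE, None))             # assumed shutdown sentinel
worker.join()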
@@ -3,7 +3,7 @@
 # File: dataset.py
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>

-from six.moves import range
+from six.moves import range, zip
 from tqdm import tqdm
 from abc import ABCMeta, abstractmethod
 import multiprocessing
@@ -14,7 +14,8 @@ from ..dataflow.dftools import dataflow_to_process_queue
 from ..utils.concurrency import ensure_proc_terminate, OrderedResultGatherProc, DIE
 from .concurrency import MultiProcessQueuePredictWorker
-from .common import *
+from .common import PredictConfig
+from .base import OfflinePredictor

 __all__ = ['DatasetPredictorBase', 'SimpleDatasetPredictor',
            'MultiProcessDatasetPredictor']
@@ -34,7 +35,7 @@ class DatasetPredictorBase(object):
     @abstractmethod
     def get_result(self):
-        """ Generate (inpupt, output) pair of output, for each input in dataset"""
+        """ A generator function, produce output for each input in dataset"""
        pass

     def get_all_result(self):
@@ -49,17 +50,14 @@ class SimpleDatasetPredictor(DatasetPredictorBase):
     """
     def __init__(self, config, dataset):
         super(SimpleDatasetPredictor, self).__init__(config, dataset)
-        self.func = get_predict_func(config)
+        self.predictor = OfflinePredictor(config)

     def get_result(self):
-        """ A generator to produce prediction for each data"""
         with tqdm(total=self.dataset.size()) as pbar:
             for dp in self.dataset.get_data():
-                res = self.func(dp)
-                if self.config.return_input:
-                    yield (dp, res)
-                else:
-                    yield res
+                res = self.predictor(dp)
+                yield res
                 pbar.update()
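With the `return_input` branch gone, `get_result` is a plain generator over predictor outputs. A hedged usage sketch; `cfg` and `my_dataflow` are placeholders for a `PredictConfig` and any DataFlow whose datapoints match `input_var_names`:

# Sketch: run prediction over a whole dataset (placeholders: cfg, my_dataflow).
pred = SimpleDatasetPredictor(cfg, my_dataflow)
for output in pred.get_result():     # one output per datapoint, with a tqdm bar
    handle(output)                   # hypothetical consumer
results = pred.get_all_result()      # or materialize everything at once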
@@ -69,11 +67,11 @@ class MultiProcessDatasetPredictor(DatasetPredictorBase):
         :param nr_proc: number of processes to use
         :param use_gpu: use GPU or CPU.
-            If GPU, then nr_proc cannot be larger than the total number of GPUs available
-            in CUDA_VISIBLE_DEVICES or in the system.
+            If GPU, then nr_proc cannot be more than what's in CUDA_VISIBLE_DEVICES
         """
-        assert config.return_input == False, "return_input not supported for MultiProcessDatasetPredictor"
-        assert nr_proc > 1
+        if config.return_input:
+            logger.warn("Using the option `return_input` in MultiProcessDatasetPredictor might be slow")
+        assert nr_proc > 1, nr_proc
         super(MultiProcessDatasetPredictor, self).__init__(config, dataset)
         self.nr_proc = nr_proc
@@ -91,7 +89,7 @@ class MultiProcessDatasetPredictor(DatasetPredictorBase):
             # TODO number of GPUs not checked
             gpus = list(range(self.nr_gpu))
         else:
-            gpus = [-1] * self.nr_proc
+            gpus = [''] * self.nr_proc
         self.workers = [MultiProcessQueuePredictWorker(
             i, gpus[i], self.inqueue, self.outqueue, self.config)
             for i in range(self.nr_proc)]
@@ -100,7 +98,13 @@ class MultiProcessDatasetPredictor(DatasetPredictorBase):
         # setup all the procs
         self.inqueue_proc.start()
-        for p in self.workers: p.start()
+        for p, gpuid in zip(self.workers, gpus):
+            if gpuid == '':
+                logger.info("Worker {} uses CPU".format(p.idx))
+            else:
+                logger.info("Worker {} uses GPU {}".format(p.idx, gpuid))
+            with change_gpu(gpuid):
+                p.start()
         self.result_queue.start()
         ensure_proc_terminate(self.workers + [self.result_queue, self.inqueue_proc])
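The startup loop now defers device selection to a `change_gpu` context around each `p.start()`, replacing the per-worker logic deleted from `_init_runtime` above: each forked child inherits a `CUDA_VISIBLE_DEVICES` naming exactly one GPU, or nothing when `gpuid` is `''`. A conceptual sketch of such a context manager (an assumption, not tensorpack's implementation):

import os
from contextlib import contextmanager

@contextmanager
def change_gpu(gpuid):
    # Temporarily pin CUDA_VISIBLE_DEVICES so a child process forked inside
    # the block sees one GPU ('' means CPU only); restore the old value after.
    old = os.environ.get('CUDA_VISIBLE_DEVICES')
    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpuid)
    try:
        yield
    finally:
        if old is None:
            os.environ.pop('CUDA_VISIBLE_DEVICES', None)
        else:
            os.environ['CUDA_VISIBLE_DEVICES'] = old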
@@ -211,7 +211,7 @@ class QueueInputTrainer(Trainer):
     def get_predict_func(self, input_names, output_names, tower=0):
         """
         :param tower: return the kth predict_func
-        :returns: a predictor function
+        :returns: an `OnlinePredictor`
         """
         tower = self.predict_tower[tower % len(self.predict_tower)]
         raw_input_vars = get_vars_by_names(input_names)
@@ -220,7 +220,7 @@ class QueueInputTrainer(Trainer):
         return OnlinePredictor(self.sess, raw_input_vars, output_vars)

     def get_predict_funcs(self, input_names, output_names, n):
-        """ return n predicts functions evenly on each predict_tower"""
+        """ return n predictors evenly on each predict_tower"""
         return [self.get_predict_func(input_names, output_names, k)
                 for k in range(n)]
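In use, these methods hand back `OnlinePredictor`s that run inside the training session, assigned round-robin over the predict towers. A hedged sketch; `trainer` and `state_batch` are placeholders and the tensor names reuse the earlier example:

# Sketch: fetch per-tower predictors from a live QueueInputTrainer.
pred = trainer.get_predict_func(['state'], ['fct/output:0'])   # tower 0
q_values = pred([state_batch])[0]      # runs in the trainer's session

# e.g. one predictor per worker thread, spread evenly across towers:
preds = trainer.get_predict_funcs(['state'], ['fct/output:0'], 4)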