Commit e0190688 authored by Yuxin Wu's avatar Yuxin Wu

apidoc for predict/ and RL/

parent 06ea1c0a
...@@ -73,6 +73,7 @@ extensions = [ ...@@ -73,6 +73,7 @@ extensions = [
] ]
napoleon_google_docstring = True napoleon_google_docstring = True
napoleon_include_init_with_doc = True napoleon_include_init_with_doc = True
napoleon_include_special_with_doc = True
napoleon_numpy_docstring = False napoleon_numpy_docstring = False
napoleon_use_rtype = False napoleon_use_rtype = False
......
...@@ -15,14 +15,16 @@ class PreventStuckPlayer(ProxyPlayer): ...@@ -15,14 +15,16 @@ class PreventStuckPlayer(ProxyPlayer):
""" Prevent the player from getting stuck (repeating a no-op) """ Prevent the player from getting stuck (repeating a no-op)
by inserting a different action. Useful in games such as Atari Breakout by inserting a different action. Useful in games such as Atari Breakout
where the agent needs to press the 'start' button to start playing. where the agent needs to press the 'start' button to start playing.
It does auto-reset, but doesn't auto-restart the underlying player.
""" """
# TODO hash the state as well? # TODO hash the state as well?
def __init__(self, player, nr_repeat, action): def __init__(self, player, nr_repeat, action):
""" """
It does auto-reset, but doesn't auto-restart the underlying player. Args:
:param nr_repeat: trigger the 'action' after this many of repeated action nr_repeat (int): trigger the 'action' after this many of repeated action.
:param action: the action to be triggered to get out of stuck action: the action to be triggered to get out of stuck.
""" """
super(PreventStuckPlayer, self).__init__(player) super(PreventStuckPlayer, self).__init__(player)
self.act_que = deque(maxlen=nr_repeat) self.act_que = deque(maxlen=nr_repeat)
...@@ -44,10 +46,14 @@ class PreventStuckPlayer(ProxyPlayer): ...@@ -44,10 +46,14 @@ class PreventStuckPlayer(ProxyPlayer):
class LimitLengthPlayer(ProxyPlayer): class LimitLengthPlayer(ProxyPlayer):
""" Limit the total number of actions in an episode. """ Limit the total number of actions in an episode.
Will auto restart the underlying player on timeout Will restart the underlying player on timeout.
""" """
def __init__(self, player, limit): def __init__(self, player, limit):
"""
Args:
limit(int): the time limit
"""
super(LimitLengthPlayer, self).__init__(player) super(LimitLengthPlayer, self).__init__(player)
self.limit = limit self.limit = limit
self.cnt = 0 self.cnt = 0
...@@ -70,7 +76,8 @@ class LimitLengthPlayer(ProxyPlayer): ...@@ -70,7 +76,8 @@ class LimitLengthPlayer(ProxyPlayer):
class AutoRestartPlayer(ProxyPlayer): class AutoRestartPlayer(ProxyPlayer):
""" Auto-restart the player on episode ends, """ Auto-restart the player on episode ends,
in case some player wasn't designed to do so. """ in case some player wasn't designed to do so.
"""
def action(self, act): def action(self, act):
r, isOver = self.player.action(act) r, isOver = self.player.action(act)
...@@ -81,8 +88,13 @@ class AutoRestartPlayer(ProxyPlayer): ...@@ -81,8 +88,13 @@ class AutoRestartPlayer(ProxyPlayer):
class MapPlayerState(ProxyPlayer): class MapPlayerState(ProxyPlayer):
""" Map the state of the underlying player by a function. """
def __init__(self, player, func): def __init__(self, player, func):
"""
Args:
func: takes the old state and return a new state.
"""
super(MapPlayerState, self).__init__(player) super(MapPlayerState, self).__init__(player)
self.func = func self.func = func
......
...@@ -15,6 +15,7 @@ __all__ = ['RLEnvironment', 'NaiveRLEnvironment', 'ProxyPlayer', ...@@ -15,6 +15,7 @@ __all__ = ['RLEnvironment', 'NaiveRLEnvironment', 'ProxyPlayer',
@six.add_metaclass(ABCMeta) @six.add_metaclass(ABCMeta)
class RLEnvironment(object): class RLEnvironment(object):
""" Base class of RL environment. """
def __init__(self): def __init__(self):
self.reset_stat() self.reset_stat()
...@@ -29,8 +30,11 @@ class RLEnvironment(object): ...@@ -29,8 +30,11 @@ class RLEnvironment(object):
def action(self, act): def action(self, act):
""" """
Perform an action. Will automatically start a new episode if isOver==True Perform an action. Will automatically start a new episode if isOver==True
:param act: the action
:returns: (reward, isOver) Args:
act: the action
Returns:
tuple: (reward, isOver)
""" """
def restart_episode(self): def restart_episode(self):
...@@ -38,22 +42,26 @@ class RLEnvironment(object): ...@@ -38,22 +42,26 @@ class RLEnvironment(object):
raise NotImplementedError() raise NotImplementedError()
def finish_episode(self): def finish_episode(self):
""" get called when an episode finished""" """ Get called when an episode finished"""
pass pass
def get_action_space(self): def get_action_space(self):
""" return an `ActionSpace` instance""" """ Returns:
:class:`ActionSpace` """
raise NotImplementedError() raise NotImplementedError()
def reset_stat(self): def reset_stat(self):
""" reset all statistics counter""" """ Reset all statistics counter"""
self.stats = defaultdict(list) self.stats = defaultdict(list)
def play_one_episode(self, func, stat='score'): def play_one_episode(self, func, stat='score'):
""" play one episode for eval. """ Play one episode for eval.
:param func: call with the state and return an action
:param stat: a key or list of keys in stats Args:
:returns: the stat(s) after running this episode func: the policy function. Takes a state and returns an action.
stat: a key or list of keys in stats to return.
Returns:
the stat(s) after running this episode
""" """
if not isinstance(stat, list): if not isinstance(stat, list):
stat = [stat] stat = [stat]
...@@ -101,7 +109,7 @@ class DiscreteActionSpace(ActionSpace): ...@@ -101,7 +109,7 @@ class DiscreteActionSpace(ActionSpace):
class NaiveRLEnvironment(RLEnvironment): class NaiveRLEnvironment(RLEnvironment):
""" for testing only""" """ For testing only"""
def __init__(self): def __init__(self):
self.k = 0 self.k = 0
...@@ -116,7 +124,7 @@ class NaiveRLEnvironment(RLEnvironment): ...@@ -116,7 +124,7 @@ class NaiveRLEnvironment(RLEnvironment):
class ProxyPlayer(RLEnvironment): class ProxyPlayer(RLEnvironment):
""" Serve as a proxy another player """ """ Serve as a proxy to another player """
def __init__(self, player): def __init__(self, player):
self.player = player self.player = player
......
...@@ -23,10 +23,14 @@ Experience = namedtuple('Experience', ...@@ -23,10 +23,14 @@ Experience = namedtuple('Experience',
class ExpReplay(DataFlow, Callback): class ExpReplay(DataFlow, Callback):
""" """
Implement experience replay in the paper Implement experience replay in the paper
`Human-level control through deep reinforcement learning`. `Human-level control through deep reinforcement learning
<http://www.nature.com/nature/journal/v518/n7540/full/nature14236.html>`_.
This implementation provides the interface as an DataFlow. This implementation provides the interface as a :class:`DataFlow`.
This DataFlow is not fork-safe (doesn't support multiprocess prefetching) This DataFlow is __not__ fork-safe (thus doesn't support multiprocess prefetching).
This implementation only works with Q-learning. It assumes that state is
batch-able, and the network takes batched inputs.
""" """
def __init__(self, def __init__(self,
...@@ -43,12 +47,14 @@ class ExpReplay(DataFlow, Callback): ...@@ -43,12 +47,14 @@ class ExpReplay(DataFlow, Callback):
history_len=1 history_len=1
): ):
""" """
:param predictor: a callabale running the up-to-date network. Args:
called with a state, return a distribution. predictor_io_names (tuple of list of str): input/output names to
:param player: an `RLEnvironment` predict Q value from state.
:param history_len: length of history frames to concat. zero-filled initial frames player (RLEnvironment): the player.
:param update_frequency: number of new transitions to add to memory history_len (int): length of history frames to concat. Zero-filled
after sampling a batch of transitions for training initial frames.
update_frequency (int): number of new transitions to add to memory
after sampling a batch of transitions for training.
""" """
init_memory_size = int(init_memory_size) init_memory_size = int(init_memory_size)
......
...@@ -30,10 +30,17 @@ _ENV_LOCK = threading.Lock() ...@@ -30,10 +30,17 @@ _ENV_LOCK = threading.Lock()
class GymEnv(RLEnvironment): class GymEnv(RLEnvironment):
""" """
An OpenAI/gym wrapper. Can optionally auto restart. An OpenAI/gym wrapper. Can optionally auto restart.
Only support discrete action space now Only support discrete action space for now.
""" """
def __init__(self, name, dumpdir=None, viz=False, auto_restart=True): def __init__(self, name, dumpdir=None, viz=False, auto_restart=True):
"""
Args:
name (str): the gym environment name.
dumpdir (str): the directory to dump recordings to.
viz (bool): whether to start visualization.
auto_restart (bool): whether to restart after episode ends.
"""
with _ENV_LOCK: with _ENV_LOCK:
self.gymenv = gym.make(name) self.gymenv = gym.make(name)
if dumpdir: if dumpdir:
......
...@@ -11,14 +11,15 @@ __all__ = ['HistoryFramePlayer'] ...@@ -11,14 +11,15 @@ __all__ = ['HistoryFramePlayer']
class HistoryFramePlayer(ProxyPlayer): class HistoryFramePlayer(ProxyPlayer):
""" Include history frames in state, or use black images """ Include history frames in state, or use black images.
Assume player will do auto-restart. It assumes the underlying player will do auto-restart.
""" """
def __init__(self, player, hist_len): def __init__(self, player, hist_len):
""" """
:param hist_len: total length of the state, including the current Args:
and `hist_len-1` history hist_len (int): total length of the state, including the current
and `hist_len-1` history.
""" """
super(HistoryFramePlayer, self).__init__(player) super(HistoryFramePlayer, self).__init__(player)
self.history = deque(maxlen=hist_len) self.history = deque(maxlen=hist_len)
......
...@@ -92,6 +92,10 @@ class LinearWrap(object): ...@@ -92,6 +92,10 @@ class LinearWrap(object):
return LinearWrap(ret) return LinearWrap(ret)
def __call__(self): def __call__(self):
"""
Returns:
tf.Tensor: the underlying wrapped tensor.
"""
return self._t return self._t
def tensor(self): def tensor(self):
......
...@@ -11,8 +11,8 @@ from ..utils.naming import PREDICT_TOWER ...@@ -11,8 +11,8 @@ from ..utils.naming import PREDICT_TOWER
from ..utils import logger from ..utils import logger
from ..tfutils import get_tensors_by_names, TowerContext from ..tfutils import get_tensors_by_names, TowerContext
__all__ = ['OnlinePredictor', 'OfflinePredictor', __all__ = ['PredictorBase', 'AsyncPredictorBase',
'AsyncPredictorBase', 'OnlinePredictor', 'OfflinePredictor',
'MultiTowerOfflinePredictor', 'build_multi_tower_prediction_graph', 'MultiTowerOfflinePredictor', 'build_multi_tower_prediction_graph',
'DataParallelOfflinePredictor'] 'DataParallelOfflinePredictor']
...@@ -20,15 +20,29 @@ __all__ = ['OnlinePredictor', 'OfflinePredictor', ...@@ -20,15 +20,29 @@ __all__ = ['OnlinePredictor', 'OfflinePredictor',
@six.add_metaclass(ABCMeta) @six.add_metaclass(ABCMeta)
class PredictorBase(object): class PredictorBase(object):
""" """
Available attributes: Base class for all predictors.
session
return_input Attributes:
session (tf.Session):
return_input (bool): whether the call will also return (inputs, outputs)
or just outpus
""" """
def __call__(self, *args): def __call__(self, *args):
""" """
if len(args) == 1, assume args[0] is a datapoint (a list) Call the predictor on some inputs.
else, assume args is a datapoinnt
If ``len(args) == 1``, assume ``args[0]`` is a datapoint (a list).
otherwise, assume ``args`` is a datapoinnt
Examples:
When you have a predictor which takes a datapoint [e1, e2], you
can call it in two ways:
.. code-block:: python
predictor(e1, e2)
predictor([e1, e2])
""" """
if len(args) != 1: if len(args) != 1:
dp = args dp = args
...@@ -49,15 +63,18 @@ class PredictorBase(object): ...@@ -49,15 +63,18 @@ class PredictorBase(object):
class AsyncPredictorBase(PredictorBase): class AsyncPredictorBase(PredictorBase):
""" Base class for all async predictors. """
@abstractmethod @abstractmethod
def put_task(self, dp, callback=None): def put_task(self, dp, callback=None):
""" """
:param dp: A data point (list of component) as inputs. Args:
(It should be either batched or not batched depending on the predictor implementation) dp (list): A datapoint as inputs. It could be either batched or not
:param callback: a thread-safe callback to get called with batched depending on the predictor implementation).
either outputs or (inputs, outputs) callback: a thread-safe callback to get called with
:return: a Future of results either outputs or (inputs, outputs).
Returns:
concurrent.futures.Future: a Future of results
""" """
@abstractmethod @abstractmethod
...@@ -72,8 +89,16 @@ class AsyncPredictorBase(PredictorBase): ...@@ -72,8 +89,16 @@ class AsyncPredictorBase(PredictorBase):
class OnlinePredictor(PredictorBase): class OnlinePredictor(PredictorBase):
""" A predictor which directly use an existing session. """
def __init__(self, sess, input_tensors, output_tensors, return_input=False): def __init__(self, sess, input_tensors, output_tensors, return_input=False):
"""
Args:
sess (tf.Session): an existing session.
input_tensors (list): list of names.
output_tensors (list): list of names.
return_input (bool): same as :attr:`PredictorBase.return_input`.
"""
self.session = sess self.session = sess
self.return_input = return_input self.return_input = return_input
...@@ -89,9 +114,13 @@ class OnlinePredictor(PredictorBase): ...@@ -89,9 +114,13 @@ class OnlinePredictor(PredictorBase):
class OfflinePredictor(OnlinePredictor): class OfflinePredictor(OnlinePredictor):
""" Build a predictor from a given config, in an independent graph""" """ A predictor built from a given config, in a new graph. """
def __init__(self, config): def __init__(self, config):
"""
Args:
config (PredictConfig): the config to use.
"""
self.graph = tf.Graph() self.graph = tf.Graph()
with self.graph.as_default(): with self.graph.as_default():
input_placehdrs = config.model.get_input_vars() input_placehdrs = config.model.get_input_vars()
...@@ -109,8 +138,10 @@ class OfflinePredictor(OnlinePredictor): ...@@ -109,8 +138,10 @@ class OfflinePredictor(OnlinePredictor):
def build_multi_tower_prediction_graph(build_tower_fn, towers): def build_multi_tower_prediction_graph(build_tower_fn, towers):
""" """
:param build_tower_fn: the function to be called inside each tower, taking tower as the argument Args:
:param towers: a list of gpu relative id. build_tower_fn: a function that will be called inside each tower,
taking tower id as the argument.
towers: a list of relative GPU id.
""" """
for k in towers: for k in towers:
logger.info( logger.info(
...@@ -122,8 +153,14 @@ def build_multi_tower_prediction_graph(build_tower_fn, towers): ...@@ -122,8 +153,14 @@ def build_multi_tower_prediction_graph(build_tower_fn, towers):
class MultiTowerOfflinePredictor(OnlinePredictor): class MultiTowerOfflinePredictor(OnlinePredictor):
""" A multi-tower multi-GPU predictor. """
def __init__(self, config, towers): def __init__(self, config, towers):
"""
Args:
config (PredictConfig): the config to use.
towers: a list of relative GPU id.
"""
self.graph = tf.Graph() self.graph = tf.Graph()
self.predictors = [] self.predictors = []
with self.graph.as_default(): with self.graph.as_default():
...@@ -149,12 +186,24 @@ class MultiTowerOfflinePredictor(OnlinePredictor): ...@@ -149,12 +186,24 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
return self.predictors[0]._do_call(dp) return self.predictors[0]._do_call(dp)
def get_predictors(self, n): def get_predictors(self, n):
"""
Returns:
PredictorBase: the nth predictor on the nth GPU.
"""
return [self.predictors[k % len(self.predictors)] for k in range(n)] return [self.predictors[k % len(self.predictors)] for k in range(n)]
class DataParallelOfflinePredictor(OnlinePredictor): class DataParallelOfflinePredictor(OnlinePredictor):
""" A data-parallel predictor.
It runs different towers in parallel.
"""
def __init__(self, config, towers): def __init__(self, config, towers):
"""
Args:
config (PredictConfig): the config to use.
towers: a list of relative GPU id.
"""
self.graph = tf.Graph() self.graph = tf.Graph()
with self.graph.as_default(): with self.graph.as_default():
sess = tf.Session(config=config.session_config) sess = tf.Session(config=config.session_config)
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
# File: common.py # File: common.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
from collections import namedtuple
import six import six
from tensorpack.models import ModelDesc from tensorpack.models import ModelDesc
...@@ -10,25 +9,20 @@ from ..tfutils import get_default_sess_config ...@@ -10,25 +9,20 @@ from ..tfutils import get_default_sess_config
from ..tfutils.sessinit import SessionInit, JustCurrentSession from ..tfutils.sessinit import SessionInit, JustCurrentSession
from .base import OfflinePredictor from .base import OfflinePredictor
__all__ = ['PredictConfig', 'get_predict_func', 'PredictResult'] __all__ = ['PredictConfig', 'get_predict_func']
PredictResult = namedtuple('PredictResult', ['input', 'output'])
class PredictConfig(object): class PredictConfig(object):
def __init__(self, **kwargs): def __init__(self, **kwargs):
""" """
The config used by `get_predict_func`. Args:
session_init (SessionInit): how to initialize variables of the session.
:param session_init: a `utils.sessinit.SessionInit` instance to model (ModelDesc): the model to use.
initialize variables of a session. input_names (list): a list of input tensor names.
:param model: a `ModelDesc` instance output_names (list): a list of names of the output tensors to predict, the
:param input_names: a list of input variable names. tensors can be any computable tensor in the graph.
:param output_names: a list of names of the output tensors to predict, the return_input: same as in :attr:`PredictorBase.return_input`.
variables can be any computable tensor in the graph.
Predict specific output might not require all input variables.
:param return_input: whether to return (input, output) pair or just output. default to False.
""" """
# TODO use the name "tensor" instead of "variable" # TODO use the name "tensor" instead of "variable"
def assert_type(v, tp): def assert_type(v, tp):
...@@ -68,10 +62,6 @@ class PredictConfig(object): ...@@ -68,10 +62,6 @@ class PredictConfig(object):
def get_predict_func(config): def get_predict_func(config):
""" """
Produce a offline predictor run inside a new session. Equivalent to ``OfflinePredictor(config)``.
:param config: a `PredictConfig` instance.
:returns: A callable predictor that takes a list of input values, and return
a list of output values defined in ``config.output_var_names``.
""" """
return OfflinePredictor(config) return OfflinePredictor(config)
...@@ -32,8 +32,9 @@ class MultiProcessPredictWorker(multiprocessing.Process): ...@@ -32,8 +32,9 @@ class MultiProcessPredictWorker(multiprocessing.Process):
def __init__(self, idx, config): def __init__(self, idx, config):
""" """
:param idx: index of the worker. the 0th worker will print log. Args:
:param config: a `PredictConfig` idx (int): index of the worker. the 0th worker will print log.
config (PredictConfig): the config to use.
""" """
super(MultiProcessPredictWorker, self).__init__() super(MultiProcessPredictWorker, self).__init__()
self.idx = idx self.idx = idx
...@@ -53,12 +54,17 @@ class MultiProcessPredictWorker(multiprocessing.Process): ...@@ -53,12 +54,17 @@ class MultiProcessPredictWorker(multiprocessing.Process):
class MultiProcessQueuePredictWorker(MultiProcessPredictWorker): class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):
""" An offline predictor worker that takes input and produces output by queue""" """
An offline predictor worker that takes input and produces output by queue.
Each process will exit when they see :class:`DIE`.
"""
def __init__(self, idx, inqueue, outqueue, config): def __init__(self, idx, inqueue, outqueue, config):
""" """
:param inqueue: input queue to get data point. elements are (task_id, dp) Args:
:param outqueue: output queue put result. elements are (task_id, output) idx, config: same as in :class:`MultiProcessPredictWorker`.
inqueue (multiprocessing.Queue): input queue to get data point. elements are (task_id, dp)
outqueue (multiprocessing.Queue): output queue to put result. elements are (task_id, output)
""" """
super(MultiProcessQueuePredictWorker, self).__init__(idx, config) super(MultiProcessQueuePredictWorker, self).__init__(idx, config)
self.inqueue = inqueue self.inqueue = inqueue
...@@ -125,12 +131,16 @@ class PredictorWorkerThread(threading.Thread): ...@@ -125,12 +131,16 @@ class PredictorWorkerThread(threading.Thread):
class MultiThreadAsyncPredictor(AsyncPredictorBase): class MultiThreadAsyncPredictor(AsyncPredictorBase):
""" """
An multithread online async predictor which run a list of PredictorBase. An multithread online async predictor which runs a list of PredictorBase.
It would do an extra batching internally. It would do an extra batching internally.
""" """
def __init__(self, predictors, batch_size=5): def __init__(self, predictors, batch_size=5):
""" :param predictors: a list of OnlinePredictor""" """
Args:
predictors (list): a list of OnlinePredictor avaiable to use.
batch_size (int): the maximum of an internal batch.
"""
assert len(predictors) assert len(predictors)
for k in predictors: for k in predictors:
# assert isinstance(k, OnlinePredictor), type(k) # assert isinstance(k, OnlinePredictor), type(k)
...@@ -156,7 +166,7 @@ class MultiThreadAsyncPredictor(AsyncPredictorBase): ...@@ -156,7 +166,7 @@ class MultiThreadAsyncPredictor(AsyncPredictorBase):
def put_task(self, dp, callback=None): def put_task(self, dp, callback=None):
""" """
dp must be non-batched, i.e. single instance Same as in :meth:`AsyncPredictorBase.put_task`.
""" """
f = Future() f = Future()
if callback is not None: if callback is not None:
......
...@@ -25,11 +25,15 @@ __all__ = ['DatasetPredictorBase', 'SimpleDatasetPredictor', ...@@ -25,11 +25,15 @@ __all__ = ['DatasetPredictorBase', 'SimpleDatasetPredictor',
@six.add_metaclass(ABCMeta) @six.add_metaclass(ABCMeta)
class DatasetPredictorBase(object): class DatasetPredictorBase(object):
""" Base class for dataset predictors.
These are predictors which run over a :class:`DataFlow`.
"""
def __init__(self, config, dataset): def __init__(self, config, dataset):
""" """
:param config: a `PredictConfig` instance. Args:
:param dataset: a `DataFlow` instance. config (PredictConfig): the config of predictor.
dataset (DataFlow): the DataFlow to run on.
""" """
assert isinstance(dataset, DataFlow) assert isinstance(dataset, DataFlow)
assert isinstance(config, PredictConfig) assert isinstance(config, PredictConfig)
...@@ -38,27 +42,29 @@ class DatasetPredictorBase(object): ...@@ -38,27 +42,29 @@ class DatasetPredictorBase(object):
@abstractmethod @abstractmethod
def get_result(self): def get_result(self):
""" A generator function, produce output for each input in dataset""" """
Yields:
output for each datapoint in the DataFlow.
"""
pass pass
def get_all_result(self): def get_all_result(self):
""" """
Run over the dataset and return a list of all predictions. Returns:
list: all outputs for all datapoints in the DataFlow.
""" """
return list(self.get_result()) return list(self.get_result())
class SimpleDatasetPredictor(DatasetPredictorBase): class SimpleDatasetPredictor(DatasetPredictorBase):
""" """
Run the predict_config on a given `DataFlow`. Simply create one predictor and run it on the DataFlow.
""" """
def __init__(self, config, dataset): def __init__(self, config, dataset):
super(SimpleDatasetPredictor, self).__init__(config, dataset) super(SimpleDatasetPredictor, self).__init__(config, dataset)
self.predictor = OfflinePredictor(config) self.predictor = OfflinePredictor(config)
def get_result(self): def get_result(self):
""" A generator to produce prediction for each data"""
self.dataset.reset_state() self.dataset.reset_state()
try: try:
sz = self.dataset.size() sz = self.dataset.size()
...@@ -70,20 +76,26 @@ class SimpleDatasetPredictor(DatasetPredictorBase): ...@@ -70,20 +76,26 @@ class SimpleDatasetPredictor(DatasetPredictorBase):
yield res yield res
pbar.update() pbar.update()
# TODO allow unordered
class MultiProcessDatasetPredictor(DatasetPredictorBase): class MultiProcessDatasetPredictor(DatasetPredictorBase):
"""
Run prediction in multiprocesses, on either CPU or GPU.
Each process fetch datapoints as tasks and run predictions independently.
"""
# TODO allow unordered
def __init__(self, config, dataset, nr_proc, use_gpu=True, ordered=True): def __init__(self, config, dataset, nr_proc, use_gpu=True, ordered=True):
""" """
Run prediction in multiprocesses, on either CPU or GPU. Mix mode not supported. Args:
config: same as in :class:`DatasetPredictorBase`.
:param nr_proc: number of processes to use dataset: same as in :class:`DatasetPredictorBase`.
:param use_gpu: use GPU or CPU. nr_proc (int): number of processes to use
If GPU, then nr_proc cannot be more than what's in CUDA_VISIBLE_DEVICES use_gpu (bool): use GPU or CPU.
:param ordered: produce results with the original order of the If GPU, then ``nr_proc`` cannot be more than what's in
dataflow. a bit slower. CUDA_VISIBLE_DEVICES.
ordered (bool): produce outputs in the original order of the
datapoints. This will be a bit slower. Otherwise, :meth:`get_result` will produce
outputs in any order.
""" """
if config.return_input: if config.return_input:
logger.warn("Using the option `return_input` in MultiProcessDatasetPredictor might be slow") logger.warn("Using the option `return_input` in MultiProcessDatasetPredictor might be slow")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment