Commit e0190688 authored by Yuxin Wu

apidoc for predict/ and RL/

parent 06ea1c0a
......@@ -73,6 +73,7 @@ extensions = [
]
napoleon_google_docstring = True
napoleon_include_init_with_doc = True
napoleon_include_special_with_doc = True
napoleon_numpy_docstring = False
napoleon_use_rtype = False
......
......@@ -15,14 +15,16 @@ class PreventStuckPlayer(ProxyPlayer):
""" Prevent the player from getting stuck (repeating a no-op)
by inserting a different action. Useful in games such as Atari Breakout
where the agent needs to press the 'start' button to start playing.
It does auto-reset, but doesn't auto-restart the underlying player.
"""
# TODO hash the state as well?
def __init__(self, player, nr_repeat, action):
"""
It does auto-reset, but doesn't auto-restart the underlying player.
:param nr_repeat: trigger the 'action' after this many of repeated action
:param action: the action to be triggered to get out of stuck
Args:
nr_repeat (int): trigger the 'action' after this many repeated actions.
action: the action to be triggered to get out of stuck.
"""
super(PreventStuckPlayer, self).__init__(player)
self.act_que = deque(maxlen=nr_repeat)
......@@ -44,10 +46,14 @@ class PreventStuckPlayer(ProxyPlayer):
class LimitLengthPlayer(ProxyPlayer):
""" Limit the total number of actions in an episode.
Will auto restart the underlying player on timeout
Will restart the underlying player on timeout.
"""
def __init__(self, player, limit):
"""
Args:
limit (int): the maximum number of actions allowed in an episode.
"""
super(LimitLengthPlayer, self).__init__(player)
self.limit = limit
self.cnt = 0
......@@ -70,7 +76,8 @@ class LimitLengthPlayer(ProxyPlayer):
class AutoRestartPlayer(ProxyPlayer):
""" Auto-restart the player on episode ends,
in case some player wasn't designed to do so. """
in case some player wasn't designed to do so.
"""
def action(self, act):
r, isOver = self.player.action(act)
......@@ -81,8 +88,13 @@ class AutoRestartPlayer(ProxyPlayer):
class MapPlayerState(ProxyPlayer):
""" Map the state of the underlying player by a function. """
def __init__(self, player, func):
"""
Args:
func: takes the old state and returns a new state.
"""
super(MapPlayerState, self).__init__(player)
self.func = func
......
......@@ -15,6 +15,7 @@ __all__ = ['RLEnvironment', 'NaiveRLEnvironment', 'ProxyPlayer',
@six.add_metaclass(ABCMeta)
class RLEnvironment(object):
""" Base class of RL environment. """
def __init__(self):
self.reset_stat()
......@@ -29,8 +30,11 @@ class RLEnvironment(object):
def action(self, act):
"""
Perform an action. Will automatically start a new episode if isOver==True
:param act: the action
:returns: (reward, isOver)
Args:
act: the action
Returns:
tuple: (reward, isOver)
"""
def restart_episode(self):
......@@ -38,22 +42,26 @@ class RLEnvironment(object):
raise NotImplementedError()
def finish_episode(self):
""" get called when an episode finished"""
""" Called when an episode finishes. """
pass
def get_action_space(self):
""" return an `ActionSpace` instance"""
""" Returns:
:class:`ActionSpace` """
raise NotImplementedError()
def reset_stat(self):
""" reset all statistics counter"""
""" Reset all statistics counters. """
self.stats = defaultdict(list)
def play_one_episode(self, func, stat='score'):
""" play one episode for eval.
:param func: call with the state and return an action
:param stat: a key or list of keys in stats
:returns: the stat(s) after running this episode
""" Play one episode for eval.
Args:
func: the policy function. Takes a state and returns an action.
stat: a key or list of keys in stats to return.
Returns:
the stat(s) after running this episode
"""
if not isinstance(stat, list):
stat = [stat]
......@@ -101,7 +109,7 @@ class DiscreteActionSpace(ActionSpace):
class NaiveRLEnvironment(RLEnvironment):
""" for testing only"""
""" For testing only"""
def __init__(self):
self.k = 0
......@@ -116,7 +124,7 @@ class NaiveRLEnvironment(RLEnvironment):
class ProxyPlayer(RLEnvironment):
""" Serve as a proxy another player """
""" Serve as a proxy to another player """
def __init__(self, player):
self.player = player
......
......@@ -23,10 +23,14 @@ Experience = namedtuple('Experience',
class ExpReplay(DataFlow, Callback):
"""
Implement experience replay in the paper
`Human-level control through deep reinforcement learning`.
`Human-level control through deep reinforcement learning
<http://www.nature.com/nature/journal/v518/n7540/full/nature14236.html>`_.
This implementation provides the interface as an DataFlow.
This DataFlow is not fork-safe (doesn't support multiprocess prefetching)
This implementation provides the interface as a :class:`DataFlow`.
This DataFlow is **not** fork-safe (thus doesn't support multiprocess prefetching).
This implementation only works with Q-learning. It assumes that state is
batch-able, and the network takes batched inputs.
"""
def __init__(self,
......@@ -43,12 +47,14 @@ class ExpReplay(DataFlow, Callback):
history_len=1
):
"""
:param predictor: a callabale running the up-to-date network.
called with a state, return a distribution.
:param player: an `RLEnvironment`
:param history_len: length of history frames to concat. zero-filled initial frames
:param update_frequency: number of new transitions to add to memory
after sampling a batch of transitions for training
Args:
predictor_io_names (tuple of list of str): input/output names to
predict Q value from state.
player (RLEnvironment): the player.
history_len (int): length of history frames to concat. Zero-filled
initial frames.
update_frequency (int): number of new transitions to add to memory
after sampling a batch of transitions for training.
"""
init_memory_size = int(init_memory_size)
......
......@@ -30,10 +30,17 @@ _ENV_LOCK = threading.Lock()
class GymEnv(RLEnvironment):
"""
An OpenAI/gym wrapper. Can optionally auto restart.
Only support discrete action space now
Only support discrete action space for now.
"""
def __init__(self, name, dumpdir=None, viz=False, auto_restart=True):
"""
Args:
name (str): the gym environment name.
dumpdir (str): the directory to dump recordings to.
viz (bool): whether to start visualization.
auto_restart (bool): whether to restart after episode ends.
"""
with _ENV_LOCK:
self.gymenv = gym.make(name)
if dumpdir:
......
......@@ -11,14 +11,15 @@ __all__ = ['HistoryFramePlayer']
class HistoryFramePlayer(ProxyPlayer):
""" Include history frames in state, or use black images
Assume player will do auto-restart.
""" Include history frames in state, or use black images.
It assumes the underlying player will do auto-restart.
"""
def __init__(self, player, hist_len):
"""
:param hist_len: total length of the state, including the current
and `hist_len-1` history
Args:
hist_len (int): total length of the state, including the current
and `hist_len-1` history.
"""
super(HistoryFramePlayer, self).__init__(player)
self.history = deque(maxlen=hist_len)
......
......@@ -92,6 +92,10 @@ class LinearWrap(object):
return LinearWrap(ret)
def __call__(self):
"""
Returns:
tf.Tensor: the underlying wrapped tensor.
"""
return self._t
def tensor(self):
......
......@@ -11,8 +11,8 @@ from ..utils.naming import PREDICT_TOWER
from ..utils import logger
from ..tfutils import get_tensors_by_names, TowerContext
__all__ = ['OnlinePredictor', 'OfflinePredictor',
'AsyncPredictorBase',
__all__ = ['PredictorBase', 'AsyncPredictorBase',
'OnlinePredictor', 'OfflinePredictor',
'MultiTowerOfflinePredictor', 'build_multi_tower_prediction_graph',
'DataParallelOfflinePredictor']
......@@ -20,15 +20,29 @@ __all__ = ['OnlinePredictor', 'OfflinePredictor',
@six.add_metaclass(ABCMeta)
class PredictorBase(object):
"""
Available attributes:
session
return_input
Base class for all predictors.
Attributes:
session (tf.Session):
return_input (bool): whether the call will also return (inputs, outputs)
or just outputs.
"""
def __call__(self, *args):
"""
if len(args) == 1, assume args[0] is a datapoint (a list)
else, assume args is a datapoinnt
Call the predictor on some inputs.
If ``len(args) == 1``, assume ``args[0]`` is a datapoint (a list).
Otherwise, assume ``args`` is a datapoint.
Examples:
When you have a predictor which takes a datapoint [e1, e2], you
can call it in two ways:
.. code-block:: python
predictor(e1, e2)
predictor([e1, e2])
"""
if len(args) != 1:
dp = args
......@@ -49,15 +63,18 @@ class PredictorBase(object):
class AsyncPredictorBase(PredictorBase):
""" Base class for all async predictors. """
@abstractmethod
def put_task(self, dp, callback=None):
"""
:param dp: A data point (list of component) as inputs.
(It should be either batched or not batched depending on the predictor implementation)
:param callback: a thread-safe callback to get called with
either outputs or (inputs, outputs)
:return: a Future of results
Args:
dp (list): A datapoint as inputs. It could be either batched or not
batched depending on the predictor implementation.
callback: a thread-safe callback to get called with
either outputs or (inputs, outputs).
Returns:
concurrent.futures.Future: a Future of results
"""
@abstractmethod
......@@ -72,8 +89,16 @@ class AsyncPredictorBase(PredictorBase):
class OnlinePredictor(PredictorBase):
""" A predictor which directly use an existing session. """
def __init__(self, sess, input_tensors, output_tensors, return_input=False):
"""
Args:
sess (tf.Session): an existing session.
input_tensors (list): list of names.
output_tensors (list): list of names.
return_input (bool): same as :attr:`PredictorBase.return_input`.
"""
self.session = sess
self.return_input = return_input
......@@ -89,9 +114,13 @@ class OnlinePredictor(PredictorBase):
class OfflinePredictor(OnlinePredictor):
""" Build a predictor from a given config, in an independent graph"""
""" A predictor built from a given config, in a new graph. """
def __init__(self, config):
"""
Args:
config (PredictConfig): the config to use.
"""
self.graph = tf.Graph()
with self.graph.as_default():
input_placehdrs = config.model.get_input_vars()
......@@ -109,8 +138,10 @@ class OfflinePredictor(OnlinePredictor):
def build_multi_tower_prediction_graph(build_tower_fn, towers):
"""
:param build_tower_fn: the function to be called inside each tower, taking tower as the argument
:param towers: a list of gpu relative id.
Args:
build_tower_fn: a function that will be called inside each tower,
taking tower id as the argument.
towers: a list of relative GPU id.
"""
for k in towers:
logger.info(
......@@ -122,8 +153,14 @@ def build_multi_tower_prediction_graph(build_tower_fn, towers):
class MultiTowerOfflinePredictor(OnlinePredictor):
""" A multi-tower multi-GPU predictor. """
def __init__(self, config, towers):
"""
Args:
config (PredictConfig): the config to use.
towers: a list of relative GPU id.
"""
self.graph = tf.Graph()
self.predictors = []
with self.graph.as_default():
......@@ -149,12 +186,24 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
return self.predictors[0]._do_call(dp)
def get_predictors(self, n):
"""
Returns:
list: a list of ``n`` predictors, cycling through the available per-tower predictors when ``n`` exceeds their number.
"""
return [self.predictors[k % len(self.predictors)] for k in range(n)]
class DataParallelOfflinePredictor(OnlinePredictor):
""" A data-parallel predictor.
It runs different towers in parallel.
"""
def __init__(self, config, towers):
"""
Args:
config (PredictConfig): the config to use.
towers: a list of relative GPU id.
"""
self.graph = tf.Graph()
with self.graph.as_default():
sess = tf.Session(config=config.session_config)
......
......@@ -2,7 +2,6 @@
# File: common.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
from collections import namedtuple
import six
from tensorpack.models import ModelDesc
......@@ -10,25 +9,20 @@ from ..tfutils import get_default_sess_config
from ..tfutils.sessinit import SessionInit, JustCurrentSession
from .base import OfflinePredictor
__all__ = ['PredictConfig', 'get_predict_func', 'PredictResult']
PredictResult = namedtuple('PredictResult', ['input', 'output'])
__all__ = ['PredictConfig', 'get_predict_func']
class PredictConfig(object):
def __init__(self, **kwargs):
"""
The config used by `get_predict_func`.
:param session_init: a `utils.sessinit.SessionInit` instance to
initialize variables of a session.
:param model: a `ModelDesc` instance
:param input_names: a list of input variable names.
:param output_names: a list of names of the output tensors to predict, the
variables can be any computable tensor in the graph.
Predict specific output might not require all input variables.
:param return_input: whether to return (input, output) pair or just output. default to False.
Args:
session_init (SessionInit): how to initialize variables of the session.
model (ModelDesc): the model to use.
input_names (list): a list of input tensor names.
output_names (list): a list of names of the output tensors to predict, the
tensors can be any computable tensor in the graph.
return_input: same as in :attr:`PredictorBase.return_input`.
"""
# TODO use the name "tensor" instead of "variable"
def assert_type(v, tp):
......@@ -68,10 +62,6 @@ class PredictConfig(object):
def get_predict_func(config):
"""
Produce a offline predictor run inside a new session.
:param config: a `PredictConfig` instance.
:returns: A callable predictor that takes a list of input values, and return
a list of output values defined in ``config.output_var_names``.
Equivalent to ``OfflinePredictor(config)``.
"""
return OfflinePredictor(config)
......@@ -32,8 +32,9 @@ class MultiProcessPredictWorker(multiprocessing.Process):
def __init__(self, idx, config):
"""
:param idx: index of the worker. the 0th worker will print log.
:param config: a `PredictConfig`
Args:
idx (int): index of the worker. the 0th worker will print log.
config (PredictConfig): the config to use.
"""
super(MultiProcessPredictWorker, self).__init__()
self.idx = idx
......@@ -53,12 +54,17 @@ class MultiProcessPredictWorker(multiprocessing.Process):
class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):
""" An offline predictor worker that takes input and produces output by queue"""
"""
An offline predictor worker that takes input and produces output by queue.
Each process will exit when it sees :class:`DIE`.
"""
def __init__(self, idx, inqueue, outqueue, config):
"""
:param inqueue: input queue to get data point. elements are (task_id, dp)
:param outqueue: output queue put result. elements are (task_id, output)
Args:
idx, config: same as in :class:`MultiProcessPredictWorker`.
inqueue (multiprocessing.Queue): input queue to get data point. elements are (task_id, dp)
outqueue (multiprocessing.Queue): output queue to put result. elements are (task_id, output)
"""
super(MultiProcessQueuePredictWorker, self).__init__(idx, config)
self.inqueue = inqueue
......@@ -125,12 +131,16 @@ class PredictorWorkerThread(threading.Thread):
class MultiThreadAsyncPredictor(AsyncPredictorBase):
"""
An multithread online async predictor which run a list of PredictorBase.
A multithread online async predictor which runs a list of PredictorBase.
It would do an extra batching internally.
"""
def __init__(self, predictors, batch_size=5):
""" :param predictors: a list of OnlinePredictor"""
"""
Args:
predictors (list): a list of OnlinePredictor available to use.
batch_size (int): the maximum of an internal batch.
"""
assert len(predictors)
for k in predictors:
# assert isinstance(k, OnlinePredictor), type(k)
......@@ -156,7 +166,7 @@ class MultiThreadAsyncPredictor(AsyncPredictorBase):
def put_task(self, dp, callback=None):
"""
dp must be non-batched, i.e. single instance
Same as in :meth:`AsyncPredictorBase.put_task`.
"""
f = Future()
if callback is not None:
......
......@@ -25,11 +25,15 @@ __all__ = ['DatasetPredictorBase', 'SimpleDatasetPredictor',
@six.add_metaclass(ABCMeta)
class DatasetPredictorBase(object):
""" Base class for dataset predictors.
These are predictors which run over a :class:`DataFlow`.
"""
def __init__(self, config, dataset):
"""
:param config: a `PredictConfig` instance.
:param dataset: a `DataFlow` instance.
Args:
config (PredictConfig): the config of predictor.
dataset (DataFlow): the DataFlow to run on.
"""
assert isinstance(dataset, DataFlow)
assert isinstance(config, PredictConfig)
......@@ -38,27 +42,29 @@ class DatasetPredictorBase(object):
@abstractmethod
def get_result(self):
""" A generator function, produce output for each input in dataset"""
"""
Yields:
output for each datapoint in the DataFlow.
"""
pass
def get_all_result(self):
"""
Run over the dataset and return a list of all predictions.
Returns:
list: all outputs for all datapoints in the DataFlow.
"""
return list(self.get_result())
class SimpleDatasetPredictor(DatasetPredictorBase):
"""
Run the predict_config on a given `DataFlow`.
Simply create one predictor and run it on the DataFlow.
"""
def __init__(self, config, dataset):
super(SimpleDatasetPredictor, self).__init__(config, dataset)
self.predictor = OfflinePredictor(config)
def get_result(self):
""" A generator to produce prediction for each data"""
self.dataset.reset_state()
try:
sz = self.dataset.size()
......@@ -70,20 +76,26 @@ class SimpleDatasetPredictor(DatasetPredictorBase):
yield res
pbar.update()
# TODO allow unordered
class MultiProcessDatasetPredictor(DatasetPredictorBase):
"""
Run prediction in multiprocesses, on either CPU or GPU.
Each process fetch datapoints as tasks and run predictions independently.
"""
# TODO allow unordered
def __init__(self, config, dataset, nr_proc, use_gpu=True, ordered=True):
"""
Run prediction in multiprocesses, on either CPU or GPU. Mix mode not supported.
:param nr_proc: number of processes to use
:param use_gpu: use GPU or CPU.
If GPU, then nr_proc cannot be more than what's in CUDA_VISIBLE_DEVICES
:param ordered: produce results with the original order of the
dataflow. a bit slower.
Args:
config: same as in :class:`DatasetPredictorBase`.
dataset: same as in :class:`DatasetPredictorBase`.
nr_proc (int): number of processes to use
use_gpu (bool): use GPU or CPU.
If GPU, then ``nr_proc`` cannot be more than what's in
CUDA_VISIBLE_DEVICES.
ordered (bool): produce outputs in the original order of the
datapoints. This will be a bit slower. Otherwise, :meth:`get_result` will produce
outputs in any order.
"""
if config.return_input:
logger.warn("Using the option `return_input` in MultiProcessDatasetPredictor might be slow")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment