Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
e0190688
Commit
e0190688
authored
Jan 07, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
apidoc for predict/ and RL/
parent
06ea1c0a
Changes
11
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
188 additions
and
88 deletions
+188
-88
docs/conf.py
docs/conf.py
+1
-0
tensorpack/RL/common.py
tensorpack/RL/common.py
+17
-5
tensorpack/RL/envbase.py
tensorpack/RL/envbase.py
+19
-11
tensorpack/RL/expreplay.py
tensorpack/RL/expreplay.py
+15
-9
tensorpack/RL/gymenv.py
tensorpack/RL/gymenv.py
+8
-1
tensorpack/RL/history.py
tensorpack/RL/history.py
+5
-4
tensorpack/models/__init__.py
tensorpack/models/__init__.py
+4
-0
tensorpack/predict/base.py
tensorpack/predict/base.py
+64
-15
tensorpack/predict/common.py
tensorpack/predict/common.py
+9
-19
tensorpack/predict/concurrency.py
tensorpack/predict/concurrency.py
+18
-8
tensorpack/predict/dataset.py
tensorpack/predict/dataset.py
+28
-16
No files found.
docs/conf.py
View file @
e0190688
...
...
@@ -73,6 +73,7 @@ extensions = [
]
napoleon_google_docstring
=
True
napoleon_include_init_with_doc
=
True
napoleon_include_special_with_doc
=
True
napoleon_numpy_docstring
=
False
napoleon_use_rtype
=
False
...
...
tensorpack/RL/common.py
View file @
e0190688
...
...
@@ -15,14 +15,16 @@ class PreventStuckPlayer(ProxyPlayer):
""" Prevent the player from getting stuck (repeating a no-op)
by inserting a different action. Useful in games such as Atari Breakout
where the agent needs to press the 'start' button to start playing.
It does auto-reset, but doesn't auto-restart the underlying player.
"""
# TODO hash the state as well?
def
__init__
(
self
,
player
,
nr_repeat
,
action
):
"""
It does auto-reset, but doesn't auto-restart the underlying player.
:param nr_repeat: trigger the 'action' after this many of repeated action
:param action: the action to be triggered to get out of stuck
Args:
nr_repeat (int): trigger the 'action' after this many of repeated action.
action: the action to be triggered to get out of stuck.
"""
super
(
PreventStuckPlayer
,
self
)
.
__init__
(
player
)
self
.
act_que
=
deque
(
maxlen
=
nr_repeat
)
...
...
@@ -44,10 +46,14 @@ class PreventStuckPlayer(ProxyPlayer):
class
LimitLengthPlayer
(
ProxyPlayer
):
""" Limit the total number of actions in an episode.
Will
auto restart the underlying player on timeout
Will
restart the underlying player on timeout.
"""
def
__init__
(
self
,
player
,
limit
):
"""
Args:
limit(int): the time limit
"""
super
(
LimitLengthPlayer
,
self
)
.
__init__
(
player
)
self
.
limit
=
limit
self
.
cnt
=
0
...
...
@@ -70,7 +76,8 @@ class LimitLengthPlayer(ProxyPlayer):
class
AutoRestartPlayer
(
ProxyPlayer
):
""" Auto-restart the player on episode ends,
in case some player wasn't designed to do so. """
in case some player wasn't designed to do so.
"""
def
action
(
self
,
act
):
r
,
isOver
=
self
.
player
.
action
(
act
)
...
...
@@ -81,8 +88,13 @@ class AutoRestartPlayer(ProxyPlayer):
class
MapPlayerState
(
ProxyPlayer
):
""" Map the state of the underlying player by a function. """
def
__init__
(
self
,
player
,
func
):
"""
Args:
func: takes the old state and return a new state.
"""
super
(
MapPlayerState
,
self
)
.
__init__
(
player
)
self
.
func
=
func
...
...
tensorpack/RL/envbase.py
View file @
e0190688
...
...
@@ -15,6 +15,7 @@ __all__ = ['RLEnvironment', 'NaiveRLEnvironment', 'ProxyPlayer',
@
six
.
add_metaclass
(
ABCMeta
)
class
RLEnvironment
(
object
):
""" Base class of RL environment. """
def
__init__
(
self
):
self
.
reset_stat
()
...
...
@@ -29,8 +30,11 @@ class RLEnvironment(object):
def
action
(
self
,
act
):
"""
Perform an action. Will automatically start a new episode if isOver==True
:param act: the action
:returns: (reward, isOver)
Args:
act: the action
Returns:
tuple: (reward, isOver)
"""
def
restart_episode
(
self
):
...
...
@@ -38,22 +42,26 @@ class RLEnvironment(object):
raise
NotImplementedError
()
def
finish_episode
(
self
):
"""
g
et called when an episode finished"""
"""
G
et called when an episode finished"""
pass
def
get_action_space
(
self
):
""" return an `ActionSpace` instance"""
""" Returns:
:class:`ActionSpace` """
raise
NotImplementedError
()
def
reset_stat
(
self
):
"""
r
eset all statistics counter"""
"""
R
eset all statistics counter"""
self
.
stats
=
defaultdict
(
list
)
def
play_one_episode
(
self
,
func
,
stat
=
'score'
):
""" play one episode for eval.
:param func: call with the state and return an action
:param stat: a key or list of keys in stats
:returns: the stat(s) after running this episode
""" Play one episode for eval.
Args:
func: the policy function. Takes a state and returns an action.
stat: a key or list of keys in stats to return.
Returns:
the stat(s) after running this episode
"""
if
not
isinstance
(
stat
,
list
):
stat
=
[
stat
]
...
...
@@ -101,7 +109,7 @@ class DiscreteActionSpace(ActionSpace):
class
NaiveRLEnvironment
(
RLEnvironment
):
"""
f
or testing only"""
"""
F
or testing only"""
def
__init__
(
self
):
self
.
k
=
0
...
...
@@ -116,7 +124,7 @@ class NaiveRLEnvironment(RLEnvironment):
class
ProxyPlayer
(
RLEnvironment
):
""" Serve as a proxy another player """
""" Serve as a proxy
to
another player """
def
__init__
(
self
,
player
):
self
.
player
=
player
...
...
tensorpack/RL/expreplay.py
View file @
e0190688
...
...
@@ -23,10 +23,14 @@ Experience = namedtuple('Experience',
class
ExpReplay
(
DataFlow
,
Callback
):
"""
Implement experience replay in the paper
`Human-level control through deep reinforcement learning`.
`Human-level control through deep reinforcement learning
<http://www.nature.com/nature/journal/v518/n7540/full/nature14236.html>`_.
This implementation provides the interface as an DataFlow.
This DataFlow is not fork-safe (doesn't support multiprocess prefetching)
This implementation provides the interface as a :class:`DataFlow`.
This DataFlow is __not__ fork-safe (thus doesn't support multiprocess prefetching).
This implementation only works with Q-learning. It assumes that state is
batch-able, and the network takes batched inputs.
"""
def
__init__
(
self
,
...
...
@@ -43,12 +47,14 @@ class ExpReplay(DataFlow, Callback):
history_len
=
1
):
"""
:param predictor: a callabale running the up-to-date network.
called with a state, return a distribution.
:param player: an `RLEnvironment`
:param history_len: length of history frames to concat. zero-filled initial frames
:param update_frequency: number of new transitions to add to memory
after sampling a batch of transitions for training
Args:
predictor_io_names (tuple of list of str): input/output names to
predict Q value from state.
player (RLEnvironment): the player.
history_len (int): length of history frames to concat. Zero-filled
initial frames.
update_frequency (int): number of new transitions to add to memory
after sampling a batch of transitions for training.
"""
init_memory_size
=
int
(
init_memory_size
)
...
...
tensorpack/RL/gymenv.py
View file @
e0190688
...
...
@@ -30,10 +30,17 @@ _ENV_LOCK = threading.Lock()
class
GymEnv
(
RLEnvironment
):
"""
An OpenAI/gym wrapper. Can optionally auto restart.
Only support discrete action space
now
Only support discrete action space
for now.
"""
def
__init__
(
self
,
name
,
dumpdir
=
None
,
viz
=
False
,
auto_restart
=
True
):
"""
Args:
name (str): the gym environment name.
dumpdir (str): the directory to dump recordings to.
viz (bool): whether to start visualization.
auto_restart (bool): whether to restart after episode ends.
"""
with
_ENV_LOCK
:
self
.
gymenv
=
gym
.
make
(
name
)
if
dumpdir
:
...
...
tensorpack/RL/history.py
View file @
e0190688
...
...
@@ -11,14 +11,15 @@ __all__ = ['HistoryFramePlayer']
class
HistoryFramePlayer
(
ProxyPlayer
):
""" Include history frames in state, or use black images
Assume
player will do auto-restart.
""" Include history frames in state, or use black images
.
It assumes the underlying
player will do auto-restart.
"""
def
__init__
(
self
,
player
,
hist_len
):
"""
:param hist_len: total length of the state, including the current
and `hist_len-1` history
Args:
hist_len (int): total length of the state, including the current
and `hist_len-1` history.
"""
super
(
HistoryFramePlayer
,
self
)
.
__init__
(
player
)
self
.
history
=
deque
(
maxlen
=
hist_len
)
...
...
tensorpack/models/__init__.py
View file @
e0190688
...
...
@@ -92,6 +92,10 @@ class LinearWrap(object):
return
LinearWrap
(
ret
)
def
__call__
(
self
):
"""
Returns:
tf.Tensor: the underlying wrapped tensor.
"""
return
self
.
_t
def
tensor
(
self
):
...
...
tensorpack/predict/base.py
View file @
e0190688
...
...
@@ -11,8 +11,8 @@ from ..utils.naming import PREDICT_TOWER
from
..utils
import
logger
from
..tfutils
import
get_tensors_by_names
,
TowerContext
__all__
=
[
'
OnlinePredictor'
,
'OfflinePredictor
'
,
'
AsyncPredictorBase
'
,
__all__
=
[
'
PredictorBase'
,
'AsyncPredictorBase
'
,
'
OnlinePredictor'
,
'OfflinePredictor
'
,
'MultiTowerOfflinePredictor'
,
'build_multi_tower_prediction_graph'
,
'DataParallelOfflinePredictor'
]
...
...
@@ -20,15 +20,29 @@ __all__ = ['OnlinePredictor', 'OfflinePredictor',
@
six
.
add_metaclass
(
ABCMeta
)
class
PredictorBase
(
object
):
"""
Available attributes:
session
return_input
Base class for all predictors.
Attributes:
session (tf.Session):
return_input (bool): whether the call will also return (inputs, outputs)
or just outpus
"""
def
__call__
(
self
,
*
args
):
"""
if len(args) == 1, assume args[0] is a datapoint (a list)
else, assume args is a datapoinnt
Call the predictor on some inputs.
If ``len(args) == 1``, assume ``args[0]`` is a datapoint (a list).
otherwise, assume ``args`` is a datapoinnt
Examples:
When you have a predictor which takes a datapoint [e1, e2], you
can call it in two ways:
.. code-block:: python
predictor(e1, e2)
predictor([e1, e2])
"""
if
len
(
args
)
!=
1
:
dp
=
args
...
...
@@ -49,15 +63,18 @@ class PredictorBase(object):
class
AsyncPredictorBase
(
PredictorBase
):
""" Base class for all async predictors. """
@
abstractmethod
def
put_task
(
self
,
dp
,
callback
=
None
):
"""
:param dp: A data point (list of component) as inputs.
(It should be either batched or not batched depending on the predictor implementation)
:param callback: a thread-safe callback to get called with
either outputs or (inputs, outputs)
:return: a Future of results
Args:
dp (list): A datapoint as inputs. It could be either batched or not
batched depending on the predictor implementation).
callback: a thread-safe callback to get called with
either outputs or (inputs, outputs).
Returns:
concurrent.futures.Future: a Future of results
"""
@
abstractmethod
...
...
@@ -72,8 +89,16 @@ class AsyncPredictorBase(PredictorBase):
class
OnlinePredictor
(
PredictorBase
):
""" A predictor which directly use an existing session. """
def
__init__
(
self
,
sess
,
input_tensors
,
output_tensors
,
return_input
=
False
):
"""
Args:
sess (tf.Session): an existing session.
input_tensors (list): list of names.
output_tensors (list): list of names.
return_input (bool): same as :attr:`PredictorBase.return_input`.
"""
self
.
session
=
sess
self
.
return_input
=
return_input
...
...
@@ -89,9 +114,13 @@ class OnlinePredictor(PredictorBase):
class
OfflinePredictor
(
OnlinePredictor
):
"""
Build a predictor from a given config, in an independent graph
"""
"""
A predictor built from a given config, in a new graph.
"""
def
__init__
(
self
,
config
):
"""
Args:
config (PredictConfig): the config to use.
"""
self
.
graph
=
tf
.
Graph
()
with
self
.
graph
.
as_default
():
input_placehdrs
=
config
.
model
.
get_input_vars
()
...
...
@@ -109,8 +138,10 @@ class OfflinePredictor(OnlinePredictor):
def
build_multi_tower_prediction_graph
(
build_tower_fn
,
towers
):
"""
:param build_tower_fn: the function to be called inside each tower, taking tower as the argument
:param towers: a list of gpu relative id.
Args:
build_tower_fn: a function that will be called inside each tower,
taking tower id as the argument.
towers: a list of relative GPU id.
"""
for
k
in
towers
:
logger
.
info
(
...
...
@@ -122,8 +153,14 @@ def build_multi_tower_prediction_graph(build_tower_fn, towers):
class
MultiTowerOfflinePredictor
(
OnlinePredictor
):
""" A multi-tower multi-GPU predictor. """
def
__init__
(
self
,
config
,
towers
):
"""
Args:
config (PredictConfig): the config to use.
towers: a list of relative GPU id.
"""
self
.
graph
=
tf
.
Graph
()
self
.
predictors
=
[]
with
self
.
graph
.
as_default
():
...
...
@@ -149,12 +186,24 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
return
self
.
predictors
[
0
]
.
_do_call
(
dp
)
def
get_predictors
(
self
,
n
):
"""
Returns:
PredictorBase: the nth predictor on the nth GPU.
"""
return
[
self
.
predictors
[
k
%
len
(
self
.
predictors
)]
for
k
in
range
(
n
)]
class
DataParallelOfflinePredictor
(
OnlinePredictor
):
""" A data-parallel predictor.
It runs different towers in parallel.
"""
def
__init__
(
self
,
config
,
towers
):
"""
Args:
config (PredictConfig): the config to use.
towers: a list of relative GPU id.
"""
self
.
graph
=
tf
.
Graph
()
with
self
.
graph
.
as_default
():
sess
=
tf
.
Session
(
config
=
config
.
session_config
)
...
...
tensorpack/predict/common.py
View file @
e0190688
...
...
@@ -2,7 +2,6 @@
# File: common.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
from
collections
import
namedtuple
import
six
from
tensorpack.models
import
ModelDesc
...
...
@@ -10,25 +9,20 @@ from ..tfutils import get_default_sess_config
from
..tfutils.sessinit
import
SessionInit
,
JustCurrentSession
from
.base
import
OfflinePredictor
__all__
=
[
'PredictConfig'
,
'get_predict_func'
,
'PredictResult'
]
PredictResult
=
namedtuple
(
'PredictResult'
,
[
'input'
,
'output'
])
__all__
=
[
'PredictConfig'
,
'get_predict_func'
]
class
PredictConfig
(
object
):
def
__init__
(
self
,
**
kwargs
):
"""
The config used by `get_predict_func`.
:param session_init: a `utils.sessinit.SessionInit` instance to
initialize variables of a session.
:param model: a `ModelDesc` instance
:param input_names: a list of input variable names.
:param output_names: a list of names of the output tensors to predict, the
variables can be any computable tensor in the graph.
Predict specific output might not require all input variables.
:param return_input: whether to return (input, output) pair or just output. default to False.
Args:
session_init (SessionInit): how to initialize variables of the session.
model (ModelDesc): the model to use.
input_names (list): a list of input tensor names.
output_names (list): a list of names of the output tensors to predict, the
tensors can be any computable tensor in the graph.
return_input: same as in :attr:`PredictorBase.return_input`.
"""
# TODO use the name "tensor" instead of "variable"
def
assert_type
(
v
,
tp
):
...
...
@@ -68,10 +62,6 @@ class PredictConfig(object):
def
get_predict_func
(
config
):
"""
Produce a offline predictor run inside a new session.
:param config: a `PredictConfig` instance.
:returns: A callable predictor that takes a list of input values, and return
a list of output values defined in ``config.output_var_names``.
Equivalent to ``OfflinePredictor(config)``.
"""
return
OfflinePredictor
(
config
)
tensorpack/predict/concurrency.py
View file @
e0190688
...
...
@@ -32,8 +32,9 @@ class MultiProcessPredictWorker(multiprocessing.Process):
def
__init__
(
self
,
idx
,
config
):
"""
:param idx: index of the worker. the 0th worker will print log.
:param config: a `PredictConfig`
Args:
idx (int): index of the worker. the 0th worker will print log.
config (PredictConfig): the config to use.
"""
super
(
MultiProcessPredictWorker
,
self
)
.
__init__
()
self
.
idx
=
idx
...
...
@@ -53,12 +54,17 @@ class MultiProcessPredictWorker(multiprocessing.Process):
class
MultiProcessQueuePredictWorker
(
MultiProcessPredictWorker
):
""" An offline predictor worker that takes input and produces output by queue"""
"""
An offline predictor worker that takes input and produces output by queue.
Each process will exit when they see :class:`DIE`.
"""
def
__init__
(
self
,
idx
,
inqueue
,
outqueue
,
config
):
"""
:param inqueue: input queue to get data point. elements are (task_id, dp)
:param outqueue: output queue put result. elements are (task_id, output)
Args:
idx, config: same as in :class:`MultiProcessPredictWorker`.
inqueue (multiprocessing.Queue): input queue to get data point. elements are (task_id, dp)
outqueue (multiprocessing.Queue): output queue to put result. elements are (task_id, output)
"""
super
(
MultiProcessQueuePredictWorker
,
self
)
.
__init__
(
idx
,
config
)
self
.
inqueue
=
inqueue
...
...
@@ -125,12 +131,16 @@ class PredictorWorkerThread(threading.Thread):
class
MultiThreadAsyncPredictor
(
AsyncPredictorBase
):
"""
An multithread online async predictor which run a list of PredictorBase.
An multithread online async predictor which run
s
a list of PredictorBase.
It would do an extra batching internally.
"""
def
__init__
(
self
,
predictors
,
batch_size
=
5
):
""" :param predictors: a list of OnlinePredictor"""
"""
Args:
predictors (list): a list of OnlinePredictor avaiable to use.
batch_size (int): the maximum of an internal batch.
"""
assert
len
(
predictors
)
for
k
in
predictors
:
# assert isinstance(k, OnlinePredictor), type(k)
...
...
@@ -156,7 +166,7 @@ class MultiThreadAsyncPredictor(AsyncPredictorBase):
def
put_task
(
self
,
dp
,
callback
=
None
):
"""
dp must be non-batched, i.e. single instance
Same as in :meth:`AsyncPredictorBase.put_task`.
"""
f
=
Future
()
if
callback
is
not
None
:
...
...
tensorpack/predict/dataset.py
View file @
e0190688
...
...
@@ -25,11 +25,15 @@ __all__ = ['DatasetPredictorBase', 'SimpleDatasetPredictor',
@
six
.
add_metaclass
(
ABCMeta
)
class
DatasetPredictorBase
(
object
):
""" Base class for dataset predictors.
These are predictors which run over a :class:`DataFlow`.
"""
def
__init__
(
self
,
config
,
dataset
):
"""
:param config: a `PredictConfig` instance.
:param dataset: a `DataFlow` instance.
Args:
config (PredictConfig): the config of predictor.
dataset (DataFlow): the DataFlow to run on.
"""
assert
isinstance
(
dataset
,
DataFlow
)
assert
isinstance
(
config
,
PredictConfig
)
...
...
@@ -38,27 +42,29 @@ class DatasetPredictorBase(object):
@
abstractmethod
def
get_result
(
self
):
""" A generator function, produce output for each input in dataset"""
"""
Yields:
output for each datapoint in the DataFlow.
"""
pass
def
get_all_result
(
self
):
"""
Run over the dataset and return a list of all predictions.
Returns:
list: all outputs for all datapoints in the DataFlow.
"""
return
list
(
self
.
get_result
())
class
SimpleDatasetPredictor
(
DatasetPredictorBase
):
"""
Run the predict_config on a given `DataFlow`
.
Simply create one predictor and run it on the DataFlow
.
"""
def
__init__
(
self
,
config
,
dataset
):
super
(
SimpleDatasetPredictor
,
self
)
.
__init__
(
config
,
dataset
)
self
.
predictor
=
OfflinePredictor
(
config
)
def
get_result
(
self
):
""" A generator to produce prediction for each data"""
self
.
dataset
.
reset_state
()
try
:
sz
=
self
.
dataset
.
size
()
...
...
@@ -70,20 +76,26 @@ class SimpleDatasetPredictor(DatasetPredictorBase):
yield
res
pbar
.
update
()
# TODO allow unordered
class
MultiProcessDatasetPredictor
(
DatasetPredictorBase
):
"""
Run prediction in multiprocesses, on either CPU or GPU.
Each process fetch datapoints as tasks and run predictions independently.
"""
# TODO allow unordered
def
__init__
(
self
,
config
,
dataset
,
nr_proc
,
use_gpu
=
True
,
ordered
=
True
):
"""
Run prediction in multiprocesses, on either CPU or GPU. Mix mode not supported.
:param nr_proc: number of processes to use
:param use_gpu: use GPU or CPU.
If GPU, then nr_proc cannot be more than what's in CUDA_VISIBLE_DEVICES
:param ordered: produce results with the original order of the
dataflow. a bit slower.
Args:
config: same as in :class:`DatasetPredictorBase`.
dataset: same as in :class:`DatasetPredictorBase`.
nr_proc (int): number of processes to use
use_gpu (bool): use GPU or CPU.
If GPU, then ``nr_proc`` cannot be more than what's in
CUDA_VISIBLE_DEVICES.
ordered (bool): produce outputs in the original order of the
datapoints. This will be a bit slower. Otherwise, :meth:`get_result` will produce
outputs in any order.
"""
if
config
.
return_input
:
logger
.
warn
(
"Using the option `return_input` in MultiProcessDatasetPredictor might be slow"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment