Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
e0190688
Commit
e0190688
authored
Jan 07, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
apidoc for predict/ and RL/
parent
06ea1c0a
Changes
11
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
188 additions
and
88 deletions
+188
-88
docs/conf.py
docs/conf.py
+1
-0
tensorpack/RL/common.py
tensorpack/RL/common.py
+17
-5
tensorpack/RL/envbase.py
tensorpack/RL/envbase.py
+19
-11
tensorpack/RL/expreplay.py
tensorpack/RL/expreplay.py
+15
-9
tensorpack/RL/gymenv.py
tensorpack/RL/gymenv.py
+8
-1
tensorpack/RL/history.py
tensorpack/RL/history.py
+5
-4
tensorpack/models/__init__.py
tensorpack/models/__init__.py
+4
-0
tensorpack/predict/base.py
tensorpack/predict/base.py
+64
-15
tensorpack/predict/common.py
tensorpack/predict/common.py
+9
-19
tensorpack/predict/concurrency.py
tensorpack/predict/concurrency.py
+18
-8
tensorpack/predict/dataset.py
tensorpack/predict/dataset.py
+28
-16
No files found.
docs/conf.py
View file @
e0190688
...
@@ -73,6 +73,7 @@ extensions = [
...
@@ -73,6 +73,7 @@ extensions = [
]
]
napoleon_google_docstring
=
True
napoleon_google_docstring
=
True
napoleon_include_init_with_doc
=
True
napoleon_include_init_with_doc
=
True
napoleon_include_special_with_doc
=
True
napoleon_numpy_docstring
=
False
napoleon_numpy_docstring
=
False
napoleon_use_rtype
=
False
napoleon_use_rtype
=
False
...
...
tensorpack/RL/common.py
View file @
e0190688
...
@@ -15,14 +15,16 @@ class PreventStuckPlayer(ProxyPlayer):
...
@@ -15,14 +15,16 @@ class PreventStuckPlayer(ProxyPlayer):
""" Prevent the player from getting stuck (repeating a no-op)
""" Prevent the player from getting stuck (repeating a no-op)
by inserting a different action. Useful in games such as Atari Breakout
by inserting a different action. Useful in games such as Atari Breakout
where the agent needs to press the 'start' button to start playing.
where the agent needs to press the 'start' button to start playing.
It does auto-reset, but doesn't auto-restart the underlying player.
"""
"""
# TODO hash the state as well?
# TODO hash the state as well?
def
__init__
(
self
,
player
,
nr_repeat
,
action
):
def
__init__
(
self
,
player
,
nr_repeat
,
action
):
"""
"""
It does auto-reset, but doesn't auto-restart the underlying player.
Args:
:param nr_repeat: trigger the 'action' after this many of repeated action
nr_repeat (int): trigger the 'action' after this many of repeated action.
:param action: the action to be triggered to get out of stuck
action: the action to be triggered to get out of stuck.
"""
"""
super
(
PreventStuckPlayer
,
self
)
.
__init__
(
player
)
super
(
PreventStuckPlayer
,
self
)
.
__init__
(
player
)
self
.
act_que
=
deque
(
maxlen
=
nr_repeat
)
self
.
act_que
=
deque
(
maxlen
=
nr_repeat
)
...
@@ -44,10 +46,14 @@ class PreventStuckPlayer(ProxyPlayer):
...
@@ -44,10 +46,14 @@ class PreventStuckPlayer(ProxyPlayer):
class
LimitLengthPlayer
(
ProxyPlayer
):
class
LimitLengthPlayer
(
ProxyPlayer
):
""" Limit the total number of actions in an episode.
""" Limit the total number of actions in an episode.
Will
auto restart the underlying player on timeout
Will
restart the underlying player on timeout.
"""
"""
def
__init__
(
self
,
player
,
limit
):
def
__init__
(
self
,
player
,
limit
):
"""
Args:
limit(int): the time limit
"""
super
(
LimitLengthPlayer
,
self
)
.
__init__
(
player
)
super
(
LimitLengthPlayer
,
self
)
.
__init__
(
player
)
self
.
limit
=
limit
self
.
limit
=
limit
self
.
cnt
=
0
self
.
cnt
=
0
...
@@ -70,7 +76,8 @@ class LimitLengthPlayer(ProxyPlayer):
...
@@ -70,7 +76,8 @@ class LimitLengthPlayer(ProxyPlayer):
class
AutoRestartPlayer
(
ProxyPlayer
):
class
AutoRestartPlayer
(
ProxyPlayer
):
""" Auto-restart the player on episode ends,
""" Auto-restart the player on episode ends,
in case some player wasn't designed to do so. """
in case some player wasn't designed to do so.
"""
def
action
(
self
,
act
):
def
action
(
self
,
act
):
r
,
isOver
=
self
.
player
.
action
(
act
)
r
,
isOver
=
self
.
player
.
action
(
act
)
...
@@ -81,8 +88,13 @@ class AutoRestartPlayer(ProxyPlayer):
...
@@ -81,8 +88,13 @@ class AutoRestartPlayer(ProxyPlayer):
class
MapPlayerState
(
ProxyPlayer
):
class
MapPlayerState
(
ProxyPlayer
):
""" Map the state of the underlying player by a function. """
def
__init__
(
self
,
player
,
func
):
def
__init__
(
self
,
player
,
func
):
"""
Args:
func: takes the old state and return a new state.
"""
super
(
MapPlayerState
,
self
)
.
__init__
(
player
)
super
(
MapPlayerState
,
self
)
.
__init__
(
player
)
self
.
func
=
func
self
.
func
=
func
...
...
tensorpack/RL/envbase.py
View file @
e0190688
...
@@ -15,6 +15,7 @@ __all__ = ['RLEnvironment', 'NaiveRLEnvironment', 'ProxyPlayer',
...
@@ -15,6 +15,7 @@ __all__ = ['RLEnvironment', 'NaiveRLEnvironment', 'ProxyPlayer',
@
six
.
add_metaclass
(
ABCMeta
)
@
six
.
add_metaclass
(
ABCMeta
)
class
RLEnvironment
(
object
):
class
RLEnvironment
(
object
):
""" Base class of RL environment. """
def
__init__
(
self
):
def
__init__
(
self
):
self
.
reset_stat
()
self
.
reset_stat
()
...
@@ -29,8 +30,11 @@ class RLEnvironment(object):
...
@@ -29,8 +30,11 @@ class RLEnvironment(object):
def
action
(
self
,
act
):
def
action
(
self
,
act
):
"""
"""
Perform an action. Will automatically start a new episode if isOver==True
Perform an action. Will automatically start a new episode if isOver==True
:param act: the action
:returns: (reward, isOver)
Args:
act: the action
Returns:
tuple: (reward, isOver)
"""
"""
def
restart_episode
(
self
):
def
restart_episode
(
self
):
...
@@ -38,22 +42,26 @@ class RLEnvironment(object):
...
@@ -38,22 +42,26 @@ class RLEnvironment(object):
raise
NotImplementedError
()
raise
NotImplementedError
()
def
finish_episode
(
self
):
def
finish_episode
(
self
):
"""
g
et called when an episode finished"""
"""
G
et called when an episode finished"""
pass
pass
def
get_action_space
(
self
):
def
get_action_space
(
self
):
""" return an `ActionSpace` instance"""
""" Returns:
:class:`ActionSpace` """
raise
NotImplementedError
()
raise
NotImplementedError
()
def
reset_stat
(
self
):
def
reset_stat
(
self
):
"""
r
eset all statistics counter"""
"""
R
eset all statistics counter"""
self
.
stats
=
defaultdict
(
list
)
self
.
stats
=
defaultdict
(
list
)
def
play_one_episode
(
self
,
func
,
stat
=
'score'
):
def
play_one_episode
(
self
,
func
,
stat
=
'score'
):
""" play one episode for eval.
""" Play one episode for eval.
:param func: call with the state and return an action
:param stat: a key or list of keys in stats
Args:
:returns: the stat(s) after running this episode
func: the policy function. Takes a state and returns an action.
stat: a key or list of keys in stats to return.
Returns:
the stat(s) after running this episode
"""
"""
if
not
isinstance
(
stat
,
list
):
if
not
isinstance
(
stat
,
list
):
stat
=
[
stat
]
stat
=
[
stat
]
...
@@ -101,7 +109,7 @@ class DiscreteActionSpace(ActionSpace):
...
@@ -101,7 +109,7 @@ class DiscreteActionSpace(ActionSpace):
class
NaiveRLEnvironment
(
RLEnvironment
):
class
NaiveRLEnvironment
(
RLEnvironment
):
"""
f
or testing only"""
"""
F
or testing only"""
def
__init__
(
self
):
def
__init__
(
self
):
self
.
k
=
0
self
.
k
=
0
...
@@ -116,7 +124,7 @@ class NaiveRLEnvironment(RLEnvironment):
...
@@ -116,7 +124,7 @@ class NaiveRLEnvironment(RLEnvironment):
class
ProxyPlayer
(
RLEnvironment
):
class
ProxyPlayer
(
RLEnvironment
):
""" Serve as a proxy another player """
""" Serve as a proxy
to
another player """
def
__init__
(
self
,
player
):
def
__init__
(
self
,
player
):
self
.
player
=
player
self
.
player
=
player
...
...
tensorpack/RL/expreplay.py
View file @
e0190688
...
@@ -23,10 +23,14 @@ Experience = namedtuple('Experience',
...
@@ -23,10 +23,14 @@ Experience = namedtuple('Experience',
class
ExpReplay
(
DataFlow
,
Callback
):
class
ExpReplay
(
DataFlow
,
Callback
):
"""
"""
Implement experience replay in the paper
Implement experience replay in the paper
`Human-level control through deep reinforcement learning`.
`Human-level control through deep reinforcement learning
<http://www.nature.com/nature/journal/v518/n7540/full/nature14236.html>`_.
This implementation provides the interface as an DataFlow.
This implementation provides the interface as a :class:`DataFlow`.
This DataFlow is not fork-safe (doesn't support multiprocess prefetching)
This DataFlow is __not__ fork-safe (thus doesn't support multiprocess prefetching).
This implementation only works with Q-learning. It assumes that state is
batch-able, and the network takes batched inputs.
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
...
@@ -43,12 +47,14 @@ class ExpReplay(DataFlow, Callback):
...
@@ -43,12 +47,14 @@ class ExpReplay(DataFlow, Callback):
history_len
=
1
history_len
=
1
):
):
"""
"""
:param predictor: a callabale running the up-to-date network.
Args:
called with a state, return a distribution.
predictor_io_names (tuple of list of str): input/output names to
:param player: an `RLEnvironment`
predict Q value from state.
:param history_len: length of history frames to concat. zero-filled initial frames
player (RLEnvironment): the player.
:param update_frequency: number of new transitions to add to memory
history_len (int): length of history frames to concat. Zero-filled
after sampling a batch of transitions for training
initial frames.
update_frequency (int): number of new transitions to add to memory
after sampling a batch of transitions for training.
"""
"""
init_memory_size
=
int
(
init_memory_size
)
init_memory_size
=
int
(
init_memory_size
)
...
...
tensorpack/RL/gymenv.py
View file @
e0190688
...
@@ -30,10 +30,17 @@ _ENV_LOCK = threading.Lock()
...
@@ -30,10 +30,17 @@ _ENV_LOCK = threading.Lock()
class
GymEnv
(
RLEnvironment
):
class
GymEnv
(
RLEnvironment
):
"""
"""
An OpenAI/gym wrapper. Can optionally auto restart.
An OpenAI/gym wrapper. Can optionally auto restart.
Only support discrete action space
now
Only support discrete action space
for now.
"""
"""
def
__init__
(
self
,
name
,
dumpdir
=
None
,
viz
=
False
,
auto_restart
=
True
):
def
__init__
(
self
,
name
,
dumpdir
=
None
,
viz
=
False
,
auto_restart
=
True
):
"""
Args:
name (str): the gym environment name.
dumpdir (str): the directory to dump recordings to.
viz (bool): whether to start visualization.
auto_restart (bool): whether to restart after episode ends.
"""
with
_ENV_LOCK
:
with
_ENV_LOCK
:
self
.
gymenv
=
gym
.
make
(
name
)
self
.
gymenv
=
gym
.
make
(
name
)
if
dumpdir
:
if
dumpdir
:
...
...
tensorpack/RL/history.py
View file @
e0190688
...
@@ -11,14 +11,15 @@ __all__ = ['HistoryFramePlayer']
...
@@ -11,14 +11,15 @@ __all__ = ['HistoryFramePlayer']
class
HistoryFramePlayer
(
ProxyPlayer
):
class
HistoryFramePlayer
(
ProxyPlayer
):
""" Include history frames in state, or use black images
""" Include history frames in state, or use black images
.
Assume
player will do auto-restart.
It assumes the underlying
player will do auto-restart.
"""
"""
def
__init__
(
self
,
player
,
hist_len
):
def
__init__
(
self
,
player
,
hist_len
):
"""
"""
:param hist_len: total length of the state, including the current
Args:
and `hist_len-1` history
hist_len (int): total length of the state, including the current
and `hist_len-1` history.
"""
"""
super
(
HistoryFramePlayer
,
self
)
.
__init__
(
player
)
super
(
HistoryFramePlayer
,
self
)
.
__init__
(
player
)
self
.
history
=
deque
(
maxlen
=
hist_len
)
self
.
history
=
deque
(
maxlen
=
hist_len
)
...
...
tensorpack/models/__init__.py
View file @
e0190688
...
@@ -92,6 +92,10 @@ class LinearWrap(object):
...
@@ -92,6 +92,10 @@ class LinearWrap(object):
return
LinearWrap
(
ret
)
return
LinearWrap
(
ret
)
def
__call__
(
self
):
def
__call__
(
self
):
"""
Returns:
tf.Tensor: the underlying wrapped tensor.
"""
return
self
.
_t
return
self
.
_t
def
tensor
(
self
):
def
tensor
(
self
):
...
...
tensorpack/predict/base.py
View file @
e0190688
...
@@ -11,8 +11,8 @@ from ..utils.naming import PREDICT_TOWER
...
@@ -11,8 +11,8 @@ from ..utils.naming import PREDICT_TOWER
from
..utils
import
logger
from
..utils
import
logger
from
..tfutils
import
get_tensors_by_names
,
TowerContext
from
..tfutils
import
get_tensors_by_names
,
TowerContext
__all__
=
[
'
OnlinePredictor'
,
'OfflinePredictor
'
,
__all__
=
[
'
PredictorBase'
,
'AsyncPredictorBase
'
,
'
AsyncPredictorBase
'
,
'
OnlinePredictor'
,
'OfflinePredictor
'
,
'MultiTowerOfflinePredictor'
,
'build_multi_tower_prediction_graph'
,
'MultiTowerOfflinePredictor'
,
'build_multi_tower_prediction_graph'
,
'DataParallelOfflinePredictor'
]
'DataParallelOfflinePredictor'
]
...
@@ -20,15 +20,29 @@ __all__ = ['OnlinePredictor', 'OfflinePredictor',
...
@@ -20,15 +20,29 @@ __all__ = ['OnlinePredictor', 'OfflinePredictor',
@
six
.
add_metaclass
(
ABCMeta
)
@
six
.
add_metaclass
(
ABCMeta
)
class
PredictorBase
(
object
):
class
PredictorBase
(
object
):
"""
"""
Available attributes:
Base class for all predictors.
session
return_input
Attributes:
session (tf.Session):
return_input (bool): whether the call will also return (inputs, outputs)
or just outpus
"""
"""
def
__call__
(
self
,
*
args
):
def
__call__
(
self
,
*
args
):
"""
"""
if len(args) == 1, assume args[0] is a datapoint (a list)
Call the predictor on some inputs.
else, assume args is a datapoinnt
If ``len(args) == 1``, assume ``args[0]`` is a datapoint (a list).
otherwise, assume ``args`` is a datapoinnt
Examples:
When you have a predictor which takes a datapoint [e1, e2], you
can call it in two ways:
.. code-block:: python
predictor(e1, e2)
predictor([e1, e2])
"""
"""
if
len
(
args
)
!=
1
:
if
len
(
args
)
!=
1
:
dp
=
args
dp
=
args
...
@@ -49,15 +63,18 @@ class PredictorBase(object):
...
@@ -49,15 +63,18 @@ class PredictorBase(object):
class
AsyncPredictorBase
(
PredictorBase
):
class
AsyncPredictorBase
(
PredictorBase
):
""" Base class for all async predictors. """
@
abstractmethod
@
abstractmethod
def
put_task
(
self
,
dp
,
callback
=
None
):
def
put_task
(
self
,
dp
,
callback
=
None
):
"""
"""
:param dp: A data point (list of component) as inputs.
Args:
(It should be either batched or not batched depending on the predictor implementation)
dp (list): A datapoint as inputs. It could be either batched or not
:param callback: a thread-safe callback to get called with
batched depending on the predictor implementation).
either outputs or (inputs, outputs)
callback: a thread-safe callback to get called with
:return: a Future of results
either outputs or (inputs, outputs).
Returns:
concurrent.futures.Future: a Future of results
"""
"""
@
abstractmethod
@
abstractmethod
...
@@ -72,8 +89,16 @@ class AsyncPredictorBase(PredictorBase):
...
@@ -72,8 +89,16 @@ class AsyncPredictorBase(PredictorBase):
class
OnlinePredictor
(
PredictorBase
):
class
OnlinePredictor
(
PredictorBase
):
""" A predictor which directly use an existing session. """
def
__init__
(
self
,
sess
,
input_tensors
,
output_tensors
,
return_input
=
False
):
def
__init__
(
self
,
sess
,
input_tensors
,
output_tensors
,
return_input
=
False
):
"""
Args:
sess (tf.Session): an existing session.
input_tensors (list): list of names.
output_tensors (list): list of names.
return_input (bool): same as :attr:`PredictorBase.return_input`.
"""
self
.
session
=
sess
self
.
session
=
sess
self
.
return_input
=
return_input
self
.
return_input
=
return_input
...
@@ -89,9 +114,13 @@ class OnlinePredictor(PredictorBase):
...
@@ -89,9 +114,13 @@ class OnlinePredictor(PredictorBase):
class
OfflinePredictor
(
OnlinePredictor
):
class
OfflinePredictor
(
OnlinePredictor
):
"""
Build a predictor from a given config, in an independent graph
"""
"""
A predictor built from a given config, in a new graph.
"""
def
__init__
(
self
,
config
):
def
__init__
(
self
,
config
):
"""
Args:
config (PredictConfig): the config to use.
"""
self
.
graph
=
tf
.
Graph
()
self
.
graph
=
tf
.
Graph
()
with
self
.
graph
.
as_default
():
with
self
.
graph
.
as_default
():
input_placehdrs
=
config
.
model
.
get_input_vars
()
input_placehdrs
=
config
.
model
.
get_input_vars
()
...
@@ -109,8 +138,10 @@ class OfflinePredictor(OnlinePredictor):
...
@@ -109,8 +138,10 @@ class OfflinePredictor(OnlinePredictor):
def
build_multi_tower_prediction_graph
(
build_tower_fn
,
towers
):
def
build_multi_tower_prediction_graph
(
build_tower_fn
,
towers
):
"""
"""
:param build_tower_fn: the function to be called inside each tower, taking tower as the argument
Args:
:param towers: a list of gpu relative id.
build_tower_fn: a function that will be called inside each tower,
taking tower id as the argument.
towers: a list of relative GPU id.
"""
"""
for
k
in
towers
:
for
k
in
towers
:
logger
.
info
(
logger
.
info
(
...
@@ -122,8 +153,14 @@ def build_multi_tower_prediction_graph(build_tower_fn, towers):
...
@@ -122,8 +153,14 @@ def build_multi_tower_prediction_graph(build_tower_fn, towers):
class
MultiTowerOfflinePredictor
(
OnlinePredictor
):
class
MultiTowerOfflinePredictor
(
OnlinePredictor
):
""" A multi-tower multi-GPU predictor. """
def
__init__
(
self
,
config
,
towers
):
def
__init__
(
self
,
config
,
towers
):
"""
Args:
config (PredictConfig): the config to use.
towers: a list of relative GPU id.
"""
self
.
graph
=
tf
.
Graph
()
self
.
graph
=
tf
.
Graph
()
self
.
predictors
=
[]
self
.
predictors
=
[]
with
self
.
graph
.
as_default
():
with
self
.
graph
.
as_default
():
...
@@ -149,12 +186,24 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
...
@@ -149,12 +186,24 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
return
self
.
predictors
[
0
]
.
_do_call
(
dp
)
return
self
.
predictors
[
0
]
.
_do_call
(
dp
)
def
get_predictors
(
self
,
n
):
def
get_predictors
(
self
,
n
):
"""
Returns:
PredictorBase: the nth predictor on the nth GPU.
"""
return
[
self
.
predictors
[
k
%
len
(
self
.
predictors
)]
for
k
in
range
(
n
)]
return
[
self
.
predictors
[
k
%
len
(
self
.
predictors
)]
for
k
in
range
(
n
)]
class
DataParallelOfflinePredictor
(
OnlinePredictor
):
class
DataParallelOfflinePredictor
(
OnlinePredictor
):
""" A data-parallel predictor.
It runs different towers in parallel.
"""
def
__init__
(
self
,
config
,
towers
):
def
__init__
(
self
,
config
,
towers
):
"""
Args:
config (PredictConfig): the config to use.
towers: a list of relative GPU id.
"""
self
.
graph
=
tf
.
Graph
()
self
.
graph
=
tf
.
Graph
()
with
self
.
graph
.
as_default
():
with
self
.
graph
.
as_default
():
sess
=
tf
.
Session
(
config
=
config
.
session_config
)
sess
=
tf
.
Session
(
config
=
config
.
session_config
)
...
...
tensorpack/predict/common.py
View file @
e0190688
...
@@ -2,7 +2,6 @@
...
@@ -2,7 +2,6 @@
# File: common.py
# File: common.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
from
collections
import
namedtuple
import
six
import
six
from
tensorpack.models
import
ModelDesc
from
tensorpack.models
import
ModelDesc
...
@@ -10,25 +9,20 @@ from ..tfutils import get_default_sess_config
...
@@ -10,25 +9,20 @@ from ..tfutils import get_default_sess_config
from
..tfutils.sessinit
import
SessionInit
,
JustCurrentSession
from
..tfutils.sessinit
import
SessionInit
,
JustCurrentSession
from
.base
import
OfflinePredictor
from
.base
import
OfflinePredictor
__all__
=
[
'PredictConfig'
,
'get_predict_func'
,
'PredictResult'
]
__all__
=
[
'PredictConfig'
,
'get_predict_func'
]
PredictResult
=
namedtuple
(
'PredictResult'
,
[
'input'
,
'output'
])
class
PredictConfig
(
object
):
class
PredictConfig
(
object
):
def
__init__
(
self
,
**
kwargs
):
def
__init__
(
self
,
**
kwargs
):
"""
"""
The config used by `get_predict_func`.
Args:
session_init (SessionInit): how to initialize variables of the session.
:param session_init: a `utils.sessinit.SessionInit` instance to
model (ModelDesc): the model to use.
initialize variables of a session.
input_names (list): a list of input tensor names.
:param model: a `ModelDesc` instance
output_names (list): a list of names of the output tensors to predict, the
:param input_names: a list of input variable names.
tensors can be any computable tensor in the graph.
:param output_names: a list of names of the output tensors to predict, the
return_input: same as in :attr:`PredictorBase.return_input`.
variables can be any computable tensor in the graph.
Predict specific output might not require all input variables.
:param return_input: whether to return (input, output) pair or just output. default to False.
"""
"""
# TODO use the name "tensor" instead of "variable"
# TODO use the name "tensor" instead of "variable"
def
assert_type
(
v
,
tp
):
def
assert_type
(
v
,
tp
):
...
@@ -68,10 +62,6 @@ class PredictConfig(object):
...
@@ -68,10 +62,6 @@ class PredictConfig(object):
def
get_predict_func
(
config
):
def
get_predict_func
(
config
):
"""
"""
Produce a offline predictor run inside a new session.
Equivalent to ``OfflinePredictor(config)``.
:param config: a `PredictConfig` instance.
:returns: A callable predictor that takes a list of input values, and return
a list of output values defined in ``config.output_var_names``.
"""
"""
return
OfflinePredictor
(
config
)
return
OfflinePredictor
(
config
)
tensorpack/predict/concurrency.py
View file @
e0190688
...
@@ -32,8 +32,9 @@ class MultiProcessPredictWorker(multiprocessing.Process):
...
@@ -32,8 +32,9 @@ class MultiProcessPredictWorker(multiprocessing.Process):
def
__init__
(
self
,
idx
,
config
):
def
__init__
(
self
,
idx
,
config
):
"""
"""
:param idx: index of the worker. the 0th worker will print log.
Args:
:param config: a `PredictConfig`
idx (int): index of the worker. the 0th worker will print log.
config (PredictConfig): the config to use.
"""
"""
super
(
MultiProcessPredictWorker
,
self
)
.
__init__
()
super
(
MultiProcessPredictWorker
,
self
)
.
__init__
()
self
.
idx
=
idx
self
.
idx
=
idx
...
@@ -53,12 +54,17 @@ class MultiProcessPredictWorker(multiprocessing.Process):
...
@@ -53,12 +54,17 @@ class MultiProcessPredictWorker(multiprocessing.Process):
class
MultiProcessQueuePredictWorker
(
MultiProcessPredictWorker
):
class
MultiProcessQueuePredictWorker
(
MultiProcessPredictWorker
):
""" An offline predictor worker that takes input and produces output by queue"""
"""
An offline predictor worker that takes input and produces output by queue.
Each process will exit when they see :class:`DIE`.
"""
def
__init__
(
self
,
idx
,
inqueue
,
outqueue
,
config
):
def
__init__
(
self
,
idx
,
inqueue
,
outqueue
,
config
):
"""
"""
:param inqueue: input queue to get data point. elements are (task_id, dp)
Args:
:param outqueue: output queue put result. elements are (task_id, output)
idx, config: same as in :class:`MultiProcessPredictWorker`.
inqueue (multiprocessing.Queue): input queue to get data point. elements are (task_id, dp)
outqueue (multiprocessing.Queue): output queue to put result. elements are (task_id, output)
"""
"""
super
(
MultiProcessQueuePredictWorker
,
self
)
.
__init__
(
idx
,
config
)
super
(
MultiProcessQueuePredictWorker
,
self
)
.
__init__
(
idx
,
config
)
self
.
inqueue
=
inqueue
self
.
inqueue
=
inqueue
...
@@ -125,12 +131,16 @@ class PredictorWorkerThread(threading.Thread):
...
@@ -125,12 +131,16 @@ class PredictorWorkerThread(threading.Thread):
class
MultiThreadAsyncPredictor
(
AsyncPredictorBase
):
class
MultiThreadAsyncPredictor
(
AsyncPredictorBase
):
"""
"""
An multithread online async predictor which run a list of PredictorBase.
An multithread online async predictor which run
s
a list of PredictorBase.
It would do an extra batching internally.
It would do an extra batching internally.
"""
"""
def
__init__
(
self
,
predictors
,
batch_size
=
5
):
def
__init__
(
self
,
predictors
,
batch_size
=
5
):
""" :param predictors: a list of OnlinePredictor"""
"""
Args:
predictors (list): a list of OnlinePredictor avaiable to use.
batch_size (int): the maximum of an internal batch.
"""
assert
len
(
predictors
)
assert
len
(
predictors
)
for
k
in
predictors
:
for
k
in
predictors
:
# assert isinstance(k, OnlinePredictor), type(k)
# assert isinstance(k, OnlinePredictor), type(k)
...
@@ -156,7 +166,7 @@ class MultiThreadAsyncPredictor(AsyncPredictorBase):
...
@@ -156,7 +166,7 @@ class MultiThreadAsyncPredictor(AsyncPredictorBase):
def
put_task
(
self
,
dp
,
callback
=
None
):
def
put_task
(
self
,
dp
,
callback
=
None
):
"""
"""
dp must be non-batched, i.e. single instance
Same as in :meth:`AsyncPredictorBase.put_task`.
"""
"""
f
=
Future
()
f
=
Future
()
if
callback
is
not
None
:
if
callback
is
not
None
:
...
...
tensorpack/predict/dataset.py
View file @
e0190688
...
@@ -25,11 +25,15 @@ __all__ = ['DatasetPredictorBase', 'SimpleDatasetPredictor',
...
@@ -25,11 +25,15 @@ __all__ = ['DatasetPredictorBase', 'SimpleDatasetPredictor',
@
six
.
add_metaclass
(
ABCMeta
)
@
six
.
add_metaclass
(
ABCMeta
)
class
DatasetPredictorBase
(
object
):
class
DatasetPredictorBase
(
object
):
""" Base class for dataset predictors.
These are predictors which run over a :class:`DataFlow`.
"""
def
__init__
(
self
,
config
,
dataset
):
def
__init__
(
self
,
config
,
dataset
):
"""
"""
:param config: a `PredictConfig` instance.
Args:
:param dataset: a `DataFlow` instance.
config (PredictConfig): the config of predictor.
dataset (DataFlow): the DataFlow to run on.
"""
"""
assert
isinstance
(
dataset
,
DataFlow
)
assert
isinstance
(
dataset
,
DataFlow
)
assert
isinstance
(
config
,
PredictConfig
)
assert
isinstance
(
config
,
PredictConfig
)
...
@@ -38,27 +42,29 @@ class DatasetPredictorBase(object):
...
@@ -38,27 +42,29 @@ class DatasetPredictorBase(object):
@
abstractmethod
@
abstractmethod
def
get_result
(
self
):
def
get_result
(
self
):
""" A generator function, produce output for each input in dataset"""
"""
Yields:
output for each datapoint in the DataFlow.
"""
pass
pass
def
get_all_result
(
self
):
def
get_all_result
(
self
):
"""
"""
Run over the dataset and return a list of all predictions.
Returns:
list: all outputs for all datapoints in the DataFlow.
"""
"""
return
list
(
self
.
get_result
())
return
list
(
self
.
get_result
())
class
SimpleDatasetPredictor
(
DatasetPredictorBase
):
class
SimpleDatasetPredictor
(
DatasetPredictorBase
):
"""
"""
Run the predict_config on a given `DataFlow`
.
Simply create one predictor and run it on the DataFlow
.
"""
"""
def
__init__
(
self
,
config
,
dataset
):
def
__init__
(
self
,
config
,
dataset
):
super
(
SimpleDatasetPredictor
,
self
)
.
__init__
(
config
,
dataset
)
super
(
SimpleDatasetPredictor
,
self
)
.
__init__
(
config
,
dataset
)
self
.
predictor
=
OfflinePredictor
(
config
)
self
.
predictor
=
OfflinePredictor
(
config
)
def
get_result
(
self
):
def
get_result
(
self
):
""" A generator to produce prediction for each data"""
self
.
dataset
.
reset_state
()
self
.
dataset
.
reset_state
()
try
:
try
:
sz
=
self
.
dataset
.
size
()
sz
=
self
.
dataset
.
size
()
...
@@ -70,20 +76,26 @@ class SimpleDatasetPredictor(DatasetPredictorBase):
...
@@ -70,20 +76,26 @@ class SimpleDatasetPredictor(DatasetPredictorBase):
yield
res
yield
res
pbar
.
update
()
pbar
.
update
()
# TODO allow unordered
class
MultiProcessDatasetPredictor
(
DatasetPredictorBase
):
class
MultiProcessDatasetPredictor
(
DatasetPredictorBase
):
"""
Run prediction in multiprocesses, on either CPU or GPU.
Each process fetch datapoints as tasks and run predictions independently.
"""
# TODO allow unordered
def
__init__
(
self
,
config
,
dataset
,
nr_proc
,
use_gpu
=
True
,
ordered
=
True
):
def
__init__
(
self
,
config
,
dataset
,
nr_proc
,
use_gpu
=
True
,
ordered
=
True
):
"""
"""
Run prediction in multiprocesses, on either CPU or GPU. Mix mode not supported.
Args:
config: same as in :class:`DatasetPredictorBase`.
:param nr_proc: number of processes to use
dataset: same as in :class:`DatasetPredictorBase`.
:param use_gpu: use GPU or CPU.
nr_proc (int): number of processes to use
If GPU, then nr_proc cannot be more than what's in CUDA_VISIBLE_DEVICES
use_gpu (bool): use GPU or CPU.
:param ordered: produce results with the original order of the
If GPU, then ``nr_proc`` cannot be more than what's in
dataflow. a bit slower.
CUDA_VISIBLE_DEVICES.
ordered (bool): produce outputs in the original order of the
datapoints. This will be a bit slower. Otherwise, :meth:`get_result` will produce
outputs in any order.
"""
"""
if
config
.
return_input
:
if
config
.
return_input
:
logger
.
warn
(
"Using the option `return_input` in MultiProcessDatasetPredictor might be slow"
)
logger
.
warn
(
"Using the option `return_input` in MultiProcessDatasetPredictor might be slow"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment