Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
4fc21080
Commit
4fc21080
authored
Jul 17, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
async predictor base
parent
e04d846a
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
65 additions
and
22 deletions
+65
-22
examples/Atari2600/README.md
examples/Atari2600/README.md
+1
-1
tensorpack/predict/base.py
tensorpack/predict/base.py
+39
-2
tensorpack/predict/concurrency.py
tensorpack/predict/concurrency.py
+22
-17
tensorpack/train/base.py
tensorpack/train/base.py
+1
-1
tensorpack/train/trainer.py
tensorpack/train/trainer.py
+2
-1
No files found.
examples/Atari2600/README.md
View file @
4fc21080
Reproduce the following methods:
Reproduce the following
reinforcement learning
methods:
+
Nature-DQN in:
+
Nature-DQN in:
[
Human-level Control Through Deep Reinforcement Learning
](
http://www.nature.com/nature/journal/v518/n7540/full/nature14236.html
)
[
Human-level Control Through Deep Reinforcement Learning
](
http://www.nature.com/nature/journal/v518/n7540/full/nature14236.html
)
...
...
tensorpack/predict/base.py
View file @
4fc21080
...
@@ -5,9 +5,10 @@
...
@@ -5,9 +5,10 @@
from
abc
import
abstractmethod
,
ABCMeta
,
abstractproperty
from
abc
import
abstractmethod
,
ABCMeta
,
abstractproperty
import
tensorflow
as
tf
import
tensorflow
as
tf
import
six
from
..tfutils
import
get_vars_by_names
from
..tfutils
import
get_vars_by_names
__all__
=
[
'OnlinePredictor'
,
'OfflinePredictor'
]
__all__
=
[
'OnlinePredictor'
,
'OfflinePredictor'
,
'AsyncPredictorBase'
]
class
PredictorBase
(
object
):
class
PredictorBase
(
object
):
...
@@ -31,7 +32,27 @@ class PredictorBase(object):
...
@@ -31,7 +32,27 @@ class PredictorBase(object):
:param dp: input datapoint. must have the same length as input_var_names
:param dp: input datapoint. must have the same length as input_var_names
:return: output as defined by the config
:return: output as defined by the config
"""
"""
pass
class
AsyncPredictorBase
(
PredictorBase
):
@
abstractmethod
def
put_task
(
self
,
dp
,
callback
=
None
):
"""
:param dp: A data point (list of component) as inputs.
(It should be either batched or not batched depending on the predictor implementation)
:param callback: a thread-safe callback to get called with the list of
outputs of (inputs, outputs) pair
:return: a Future of outputs
"""
@
abstractmethod
def
start
(
self
):
""" Start workers """
def
_do_call
(
self
,
dp
):
assert
six
.
PY3
,
"With Python2, sync methods not available for async predictor"
fut
=
self
.
put_task
(
dp
)
# in Tornado, Future.result() doesn't wait
return
fut
.
result
()
class
OnlinePredictor
(
PredictorBase
):
class
OnlinePredictor
(
PredictorBase
):
def
__init__
(
self
,
sess
,
input_vars
,
output_vars
,
return_input
=
False
):
def
__init__
(
self
,
sess
,
input_vars
,
output_vars
,
return_input
=
False
):
...
@@ -64,3 +85,19 @@ class OfflinePredictor(OnlinePredictor):
...
@@ -64,3 +85,19 @@ class OfflinePredictor(OnlinePredictor):
config
.
session_init
.
init
(
sess
)
config
.
session_init
.
init
(
sess
)
super
(
OfflinePredictor
,
self
)
.
__init__
(
super
(
OfflinePredictor
,
self
)
.
__init__
(
sess
,
input_vars
,
output_vars
,
config
.
return_input
)
sess
,
input_vars
,
output_vars
,
config
.
return_input
)
class
AsyncOnlinePredictor
(
PredictorBase
):
def
__init__
(
self
,
sess
,
enqueue_op
,
output_vars
,
return_input
=
False
):
"""
:param enqueue_op: an op to feed inputs with.
:param output_vars: a list of directly-runnable (no extra feeding requirements)
vars producing the outputs.
"""
self
.
session
=
sess
self
.
enqop
=
enqueue_op
self
.
output_vars
=
output_vars
self
.
return_input
=
return_input
def
put_task
(
self
,
dp
,
callback
):
pass
tensorpack/predict/concurrency.py
View file @
4fc21080
...
@@ -16,7 +16,7 @@ from ..utils import logger
...
@@ -16,7 +16,7 @@ from ..utils import logger
from
..utils.timer
import
*
from
..utils.timer
import
*
from
..tfutils
import
*
from
..tfutils
import
*
from
.base
import
OfflinePredictor
from
.base
import
*
try
:
try
:
if
six
.
PY2
:
if
six
.
PY2
:
...
@@ -116,34 +116,39 @@ class PredictorWorkerThread(threading.Thread):
...
@@ -116,34 +116,39 @@ class PredictorWorkerThread(threading.Thread):
cnt
+=
1
cnt
+=
1
return
batched
,
futures
return
batched
,
futures
class
MultiThreadAsyncPredictor
(
object
):
class
MultiThreadAsyncPredictor
(
AsyncPredictorBase
):
"""
"""
An multithread
predictor which run a list of predict func
.
An multithread
online async predictor which run a list of OnlinePredictor
.
Use async interface, support multi-thread and multi-GPU
.
It would do an extra batching internally
.
"""
"""
def
__init__
(
self
,
funcs
,
batch_size
=
5
):
def
__init__
(
self
,
predictors
,
batch_size
=
5
):
""" :param funcs: a list of predict func"""
""" :param predictors: a list of OnlinePredictor"""
self
.
input_queue
=
queue
.
Queue
(
maxsize
=
len
(
funcs
)
*
10
)
for
k
in
predictors
:
assert
isinstance
(
k
,
OnlinePredictor
),
type
(
k
)
self
.
input_queue
=
queue
.
Queue
(
maxsize
=
len
(
predictors
)
*
10
)
self
.
threads
=
[
self
.
threads
=
[
PredictorWorkerThread
(
PredictorWorkerThread
(
self
.
input_queue
,
f
,
id
,
batch_size
=
batch_size
)
self
.
input_queue
,
f
,
id
,
batch_size
=
batch_size
)
for
id
,
f
in
enumerate
(
func
s
)]
for
id
,
f
in
enumerate
(
predictor
s
)]
# TODO XXX set logging here to avoid affecting TF logging
if
six
.
PY2
:
import
tornado.options
as
options
# TODO XXX set logging here to avoid affecting TF logging
options
.
parse_command_line
([
'--logging=debug'
])
import
tornado.options
as
options
options
.
parse_command_line
([
'--logging=debug'
])
def
run
(
self
):
def
start
(
self
):
for
t
in
self
.
threads
:
for
t
in
self
.
threads
:
t
.
start
()
t
.
start
()
def
put_task
(
self
,
inputs
,
callback
=
None
):
def
run
(
self
):
# temporarily for back-compatibility
self
.
start
()
def
put_task
(
self
,
dp
,
callback
=
None
):
"""
dp must be non-batched, i.e. single instance
"""
"""
:param inputs: a data point (list of component) matching input_names (not batched)
:param callback: a thread-safe callback to get called with the list of outputs
:returns: a Future of output."""
f
=
Future
()
f
=
Future
()
if
callback
is
not
None
:
if
callback
is
not
None
:
f
.
add_done_callback
(
callback
)
f
.
add_done_callback
(
callback
)
self
.
input_queue
.
put
((
inputs
,
f
))
self
.
input_queue
.
put
((
dp
,
f
))
return
f
return
f
tensorpack/train/base.py
View file @
4fc21080
...
@@ -62,7 +62,7 @@ class Trainer(object):
...
@@ -62,7 +62,7 @@ class Trainer(object):
Can be overwritten by subclasses to exploit more
Can be overwritten by subclasses to exploit more
parallelism among funcs.
parallelism among funcs.
"""
"""
return
[
self
.
get_predict_func
(
input_name
,
output_names
)
for
k
in
range
(
n
)]
return
[
self
.
get_predict_func
(
input_name
s
,
output_names
)
for
k
in
range
(
n
)]
def
trigger_epoch
(
self
):
def
trigger_epoch
(
self
):
self
.
_trigger_epoch
()
self
.
_trigger_epoch
()
...
...
tensorpack/train/trainer.py
View file @
4fc21080
...
@@ -30,6 +30,7 @@ class PredictorFactory(object):
...
@@ -30,6 +30,7 @@ class PredictorFactory(object):
self
.
tower_built
=
False
self
.
tower_built
=
False
def
get_predictor
(
self
,
input_names
,
output_names
,
tower
):
def
get_predictor
(
self
,
input_names
,
output_names
,
tower
):
""" Return an online predictor"""
if
not
self
.
tower_built
:
if
not
self
.
tower_built
:
self
.
_build_predict_tower
()
self
.
_build_predict_tower
()
tower
=
self
.
towers
[
tower
%
len
(
self
.
towers
)]
tower
=
self
.
towers
[
tower
%
len
(
self
.
towers
)]
...
@@ -204,7 +205,7 @@ class QueueInputTrainer(Trainer):
...
@@ -204,7 +205,7 @@ class QueueInputTrainer(Trainer):
self
.
main_loop
()
self
.
main_loop
()
def
run_step
(
self
):
def
run_step
(
self
):
"""
just
run self.train_op"""
"""
Simply
run self.train_op"""
self
.
sess
.
run
(
self
.
train_op
)
self
.
sess
.
run
(
self
.
train_op
)
#run_metadata = tf.RunMetadata()
#run_metadata = tf.RunMetadata()
#self.sess.run([self.train_op],
#self.sess.run([self.train_op],
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment