Shashank Suhas / seminar-breakout · Commits · 64a63c5e

Commit 64a63c5e, authored Jul 29, 2016 by Yuxin Wu
multi tower prediction graph
parent af2c0e9c

Showing 7 changed files with 86 additions and 42 deletions:
    examples/Atari2600/README.md        +1   -1
    tensorpack/RL/common.py             +4   -1
    tensorpack/models/model_desc.py     +2   -1
    tensorpack/predict/base.py          +45  -15
    tensorpack/predict/concurrency.py   +6   -2
    tensorpack/tfutils/gradproc.py      +7   -4
    tensorpack/train/trainer.py         +21  -18
examples/Atari2600/README.md  (+1 -1)

@@ -21,7 +21,7 @@ Both were trained on one GPU with an extra GPU for simulation.
 This is probably the fastest RL trainer you'd find.
 The x-axis is the number of iterations, not wall time.
-Iteration speed on Tesla M40 is about 10.7it/s for B-A3C.
+Iteration speed on Tesla M40 is about 9.7it/s for B-A3C.
 D-DQN is faster at the beginning but will converge to 12it/s due of exploration annealing.
 A demo trained with Double-DQN on breakout is available at [youtube](https://youtu.be/o21mddZtE5Y).
tensorpack/RL/common.py  (+4 -1)

@@ -21,6 +21,7 @@ class PreventStuckPlayer(ProxyPlayer):
         """
         :param nr_repeat: trigger the 'action' after this many of repeated action
         :param action: the action to be triggered to get out of stuck
+        Does auto-reset, but doesn't auto-restart the underlying player.
         """
         super(PreventStuckPlayer, self).__init__(player)
         self.act_que = deque(maxlen=nr_repeat)
@@ -41,7 +42,7 @@ class PreventStuckPlayer(ProxyPlayer):
 class LimitLengthPlayer(ProxyPlayer):
     """ Limit the total number of actions in an episode.
-        Does not auto restart.
+        Does auto-reset, but doesn't auto-restart the underlying player.
     """
     def __init__(self, player, limit):
         super(LimitLengthPlayer, self).__init__(player)
@@ -53,6 +54,8 @@ class LimitLengthPlayer(ProxyPlayer):
         self.cnt += 1
         if self.cnt >= self.limit:
             isOver = True
+        if isOver:
+            self.cnt = 0
         return (r, isOver)

     def restart_episode(self):
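For context, a minimal sketch of how these proxies might be chained around a raw environment. The AtariPlayer name, its constructor argument, and the PreventStuckPlayer argument order are assumptions inferred from the docstrings above, not part of this diff:

    from tensorpack.RL.common import PreventStuckPlayer, LimitLengthPlayer

    player = AtariPlayer('breakout.bin')            # hypothetical base player
    # fire action 1 after 30 identical actions in a row, to escape stuck states
    player = PreventStuckPlayer(player, nr_repeat=30, action=1)
    # cap each episode at 30000 actions; per this commit the counter now resets
    # when the episode ends, but the underlying player is not restarted
    player = LimitLengthPlayer(player, 30000)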
tensorpack/models/model_desc.py  (+2 -1)

@@ -71,7 +71,8 @@ class ModelDesc(object):
     def get_gradient_processor(self):
         """ Return a list of GradientProcessor. They will be executed in order"""
-        return [CheckGradient()] #, SummaryGradient()]
+        return [#SummaryGradient(),
+                CheckGradient()]

 class ModelFromMetaGraph(ModelDesc):
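Since the returned processors run in order, a user model can override this hook. A hedged fragment (the MyModel name is illustrative; only the no-argument constructors shown in this commit are used):

    from tensorpack.models.model_desc import ModelDesc
    from tensorpack.tfutils.gradproc import SummaryGradient, CheckGradient

    class MyModel(ModelDesc):
        # other ModelDesc methods omitted
        def get_gradient_processor(self):
            # re-enable gradient summaries, then run the finiteness check
            return [SummaryGradient(), CheckGradient()]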
tensorpack/predict/base.py  (+45 -15)

@@ -6,9 +6,13 @@
 from abc import abstractmethod, ABCMeta, abstractproperty
 import tensorflow as tf
 import six

 from ..utils import logger
 from ..tfutils import get_vars_by_names

-__all__ = ['OnlinePredictor', 'OfflinePredictor',
-           'AsyncPredictorBase']
+__all__ = ['OnlinePredictor', 'OfflinePredictor',
+           'AsyncPredictorBase',
+           'MultiTowerOfflinePredictor', 'build_multi_tower_prediction_graph']

 class PredictorBase(object):
@@ -87,17 +91,43 @@ class OfflinePredictor(OnlinePredictor):
                 sess, input_vars, output_vars, config.return_input)

-#class AsyncOnlinePredictor(PredictorBase):
-    #def __init__(self, sess, enqueue_op, output_vars, return_input=False):
-        #"""
-        #:param enqueue_op: an op to feed inputs with.
-        #:param output_vars: a list of directly-runnable (no extra feeding requirements)
-            #vars producing the outputs.
-        #"""
-        #self.session = sess
-        #self.enqop = enqueue_op
-        #self.output_vars = output_vars
-        #self.return_input = return_input
-
-    #def put_task(self, dp, callback):
-        #pass

+def build_multi_tower_prediction_graph(model, towers, prefix='towerp'):
+    """
+    :param towers: a list of gpu relative id.
+    """
+    input_vars = model.get_input_vars()
+    for k in towers:
+        logger.info("Building graph for predictor tower {}...".format(k))
+        with tf.device('/gpu:{}'.format(k) if k >= 0 else '/cpu:0'), \
+                tf.name_scope('{}{}'.format(prefix, k)):
+            model._build_graph(input_vars, False)
+            tf.get_variable_scope().reuse_variables()

+class MultiTowerOfflinePredictor(OnlinePredictor):
+    PREFIX = 'towerp'
+
+    def __init__(self, config, towers):
+        self.graph = tf.Graph()
+        self.predictors = []
+        with self.graph.as_default():
+            # TODO backup summary keys?
+            build_multi_tower_prediction_graph(config.model, towers, self.PREFIX)
+            self.sess = tf.Session(config=config.session_config)
+            config.session_init.init(self.sess)
+            input_vars = get_vars_by_names(config.input_var_names)
+
+            # use the first tower for compatible PredictorBase interface
+            for k in towers:
+                output_vars = get_vars_by_names(
+                    ['{}{}/'.format(self.PREFIX, k) + n
+                     for n in config.output_var_names])
+                self.predictors.append(OnlinePredictor(
+                    self.sess, input_vars, output_vars, config.return_input))
+
+    def _do_call(self, dp):
+        return self.predictors[0]._do_call(dp)
+
+    def get_predictors(self, n):
+        return [self.predictors[k % len(self.predictors)] for k in range(n)]
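A rough usage sketch of the new offline predictor. The config fields mirror the config.* attributes read above; the model class, variable names, checkpoint path, and GPU ids are assumptions for illustration, not part of this commit:

    from tensorpack.predict.base import MultiTowerOfflinePredictor

    # PredictConfig is assumed to carry: model, session_config, session_init,
    # input_var_names, output_var_names, return_input
    config = PredictConfig(
        model=MyModel(),                              # hypothetical ModelDesc subclass
        input_var_names=['state'],                    # hypothetical variable names
        output_var_names=['policy', 'value'],
        session_init=SaverRestore('train_log/model-10000'))   # hypothetical checkpoint

    # build one prediction tower per listed GPU (relative ids 0 and 1)
    predictor = MultiTowerOfflinePredictor(config, towers=[0, 1])
    # round-robin handles over the towers, e.g. to feed several worker threads
    workers = predictor.get_predictors(8)
    policy, value = predictor([batch_of_states])      # uses the first tower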
tensorpack/predict/concurrency.py  (+6 -2)

@@ -42,6 +42,9 @@ class MultiProcessPredictWorker(multiprocessing.Process):
         self.config = config

     def _init_runtime(self):
+        """ Call _init_runtime under different CUDA_VISIBLE_DEVICES, you'll
+            have workers that run on multiGPUs
+        """
         if self.idx != 0:
             from tensorpack.models._common import disable_layer_logging
             disable_layer_logging()
@@ -72,6 +75,7 @@ class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):
             else:
                 self.outqueue.put((tid, self.func(dp)))
+
 class PredictorWorkerThread(threading.Thread):
     def __init__(self, queue, pred_func, id, batch_size=5):
         super(PredictorWorkerThread, self).__init__()
@@ -118,13 +122,13 @@ class PredictorWorkerThread(threading.Thread):
 class MultiThreadAsyncPredictor(AsyncPredictorBase):
     """
-    An multithread online async predictor which run a list of OnlinePredictor.
+    An multithread online async predictor which run a list of PredictorBase.
     It would do an extra batching internally.
     """
     def __init__(self, predictors, batch_size=5):
         """ :param predictors: a list of OnlinePredictor"""
         for k in predictors:
-            assert isinstance(k, OnlinePredictor), type(k)
+            #assert isinstance(k, OnlinePredictor), type(k)
             # TODO use predictors.return_input here
             assert k.return_input == False
         self.input_queue = queue.Queue(maxsize=len(predictors)*100)
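These two pieces compose: the per-tower OnlinePredictors returned by MultiTowerOfflinePredictor.get_predictors() can back the batching worker threads here. A hedged sketch, reusing the predictor from the earlier sketch; the start()/put_task(dp, callback) usage follows the commented-out AsyncOnlinePredictor stub in base.py and is an assumption, not confirmed by this diff:

    from tensorpack.predict.concurrency import MultiThreadAsyncPredictor

    # one OnlinePredictor per worker thread, spread round-robin over the towers
    async_pred = MultiThreadAsyncPredictor(predictor.get_predictors(8), batch_size=16)
    async_pred.start()                       # assumed to launch the worker threads

    def on_result(outputs):
        # hypothetical: outputs correspond to the configured output_var_names
        policy, value = outputs

    async_pred.put_task([some_state], on_result)   # batched behind the scenes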
tensorpack/tfutils/gradproc.py  (+7 -4)

@@ -48,16 +48,19 @@ class SummaryGradient(GradientProcessor):
                 name=name + '/RMS'))
         return grads


 class CheckGradient(GradientProcessor):
     """
     Check for numeric issue
     """
     def _process(self, grads):
+        ret = []
         for grad, var in grads:
-            # TODO make assert work
-            tf.Assert(tf.reduce_all(tf.is_finite(var)), [var])
-        return grads
+            op = tf.Assert(tf.reduce_all(tf.is_finite(var)), [var], summarize=100)
+            with tf.control_dependencies([op]):
+                grad = tf.identity(grad)
+            ret.append((grad, var))
+        return ret

 class ScaleGradient(GradientProcessor):
     """
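The point of this change: the old code created tf.Assert ops that nothing depended on, so they never executed. Re-creating each gradient with tf.identity under a control dependency ties the check to the gradient, so it runs whenever the gradient is used. A minimal standalone sketch of the same pattern, assuming the graph-mode TensorFlow API of that era (tf.is_finite, tf.placeholder, tf.Session):

    import tensorflow as tf

    x = tf.placeholder(tf.float32, [None])
    y = x * 2.0
    # the assert only runs because 'y' is re-created under its control dependency
    check = tf.Assert(tf.reduce_all(tf.is_finite(x)), [x], summarize=100)
    with tf.control_dependencies([check]):
        y = tf.identity(y)

    with tf.Session() as sess:
        print(sess.run(y, feed_dict={x: [1., 2.]}))      # fine
        # sess.run(y, feed_dict={x: [float('nan')]})     # would raise InvalidArgumentError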
tensorpack/train/trainer.py  (+21 -18)

@@ -16,7 +16,7 @@ from ..tfutils.modelutils import describe_model
 from ..utils import *
 from ..tfutils import *
 from ..tfutils.summary import add_moving_summary
-from ..predict import OnlinePredictor
+from ..predict import OnlinePredictor, build_multi_tower_prediction_graph

 __all__ = ['SimpleTrainer', 'QueueInputTrainer']
@@ -24,32 +24,35 @@ class PredictorFactory(object):
     """ Make predictors for a trainer"""
     PREFIX = 'towerp'

-    def __init__(self, trainer, towers):
-        self.trainer = trainer
+    def __init__(self, sess, model, towers):
+        """
+        :param towers: list of gpu relative id
+        """
+        self.sess = sess
+        self.model = model
         self.towers = towers
         self.tower_built = False

     def get_predictor(self, input_names, output_names, tower):
-        """ Return an online predictor"""
+        """
+        :param tower: need the kth tower (not the gpu id)
+        :returns: an online predictor
+        """
         if not self.tower_built:
             self._build_predict_tower()
         tower = self.towers[tower % len(self.towers)]
         raw_input_vars = get_vars_by_names(input_names)
         output_names = ['{}{}/'.format(self.PREFIX, tower) + n for n in output_names]
         output_vars = get_vars_by_names(output_names)
-        return OnlinePredictor(self.trainer.sess, raw_input_vars, output_vars)
+        return OnlinePredictor(self.sess, raw_input_vars, output_vars)

     def _build_predict_tower(self):
         # build_predict_tower might get called anywhere, but 'towerp' should be the outermost name scope
         tf.get_variable_scope().reuse_variables()
         with tf.name_scope(None), \
                 freeze_collection(SUMMARY_BACKUP_KEYS):
-            inputs = self.trainer.model.get_input_vars()
-            tf.get_variable_scope().reuse_variables()
-            for k in self.towers:
-                logger.info("Building graph for predictor tower {}...".format(k))
-                with tf.device('/gpu:{}'.format(k) if k >= 0 else '/cpu:0'), \
-                        tf.name_scope('{}{}'.format(self.PREFIX, k)):
-                    self.trainer.model.build_graph(inputs, False)
+            build_multi_tower_prediction_graph(
+                self.model, self.towers, prefix=self.PREFIX)
         self.tower_built = True

 class SimpleTrainer(Trainer):
@@ -89,7 +92,7 @@ class SimpleTrainer(Trainer):
     def get_predict_func(self, input_names, output_names):
         if not hasattr(self, 'predictor_factory'):
-            self.predictor_factory = PredictorFactory(self, [0])
+            self.predictor_factory = PredictorFactory(self.sess, self.model, [0])
         return self.predictor_factory.get_predictor(input_names, output_names, 0)

 class EnqueueThread(threading.Thread):
@@ -150,11 +153,8 @@ class QueueInputTrainer(Trainer):
         else:
             self.input_queue = input_queue

-        if predict_tower is None:
-            # by default, use the first training gpu for prediction
-            predict_tower = [0]
-        self.predictor_factory = PredictorFactory(self, predict_tower)
+        # by default, use the first training gpu for prediction
+        self.predict_tower = predict_tower or [0]
         self.dequed_inputs = None

     def _get_model_inputs(self):
@@ -233,6 +233,9 @@ class QueueInputTrainer(Trainer):
         :param tower: return the kth predict_func
         :returns: an `OnlinePredictor`
         """
+        if not hasattr(self, 'predictor_factory'):
+            self.predictor_factory = PredictorFactory(
+                self.sess, self.model, self.predict_tower)
         return self.predictor_factory.get_predictor(input_names, output_names, tower)

     def get_predict_funcs(self, input_names, output_names, n):
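After this refactor the factory only needs a session and a model, so both trainers can create it lazily and share the 'towerp*' prediction towers. A hedged sketch of requesting predict functions from a trainer; the trainer construction, variable names, and batch are illustrative, not from this diff:

    # hypothetical: a QueueInputTrainer that predicts on the first training GPU
    trainer = QueueInputTrainer(train_config, predict_tower=[0])
    # PredictorFactory builds the predictor towers on first use
    predict = trainer.get_predict_func(['state'], ['policy', 'value'])
    policy, value = predict([batch_of_states])
    # or several round-robin predictors for async workers
    predict_funcs = trainer.get_predict_funcs(['state'], ['policy', 'value'], 4)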