Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
df82c65a
Commit
df82c65a
authored
Jul 18, 2019
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[A3C] code simplification
parent
d451368a
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
32 additions
and
27 deletions
+32
-27
examples/A3C-Gym/simulator.py
examples/A3C-Gym/simulator.py
+19
-19
examples/A3C-Gym/train-atari.py
examples/A3C-Gym/train-atari.py
+13
-8
No files found.
examples/A3C-Gym/simulator.py
View file @
df82c65a
...
...
@@ -18,7 +18,6 @@ from tensorpack.utils.concurrency import LoopThread, enable_death_signal, ensure
from
tensorpack.utils.serialize
import
dumps
,
loads
__all__
=
[
'SimulatorProcess'
,
'SimulatorMaster'
,
'SimulatorProcessStateExchange'
,
'TransitionExperience'
]
...
...
@@ -35,19 +34,7 @@ class TransitionExperience(object):
@
six
.
add_metaclass
(
ABCMeta
)
class
SimulatorProcessBase
(
mp
.
Process
):
def
__init__
(
self
,
idx
):
super
(
SimulatorProcessBase
,
self
)
.
__init__
()
self
.
idx
=
int
(
idx
)
self
.
name
=
u'simulator-{}'
.
format
(
self
.
idx
)
self
.
identity
=
self
.
name
.
encode
(
'utf-8'
)
@
abstractmethod
def
_build_player
(
self
):
pass
class
SimulatorProcessStateExchange
(
SimulatorProcessBase
):
class
SimulatorProcess
(
mp
.
Process
):
"""
A process that simulates a player and communicates to master to
send states and receive the next action
...
...
@@ -59,7 +46,11 @@ class SimulatorProcessStateExchange(SimulatorProcessBase):
idx: idx of this process
pipe_c2s, pipe_s2c (str): name of the pipe
"""
super
(
SimulatorProcessStateExchange
,
self
)
.
__init__
(
idx
)
super
(
SimulatorProcess
,
self
)
.
__init__
()
self
.
idx
=
int
(
idx
)
self
.
name
=
u'simulator-{}'
.
format
(
self
.
idx
)
self
.
identity
=
self
.
name
.
encode
(
'utf-8'
)
self
.
c2s
=
pipe_c2s
self
.
s2c
=
pipe_s2c
...
...
@@ -90,13 +81,14 @@ class SimulatorProcessStateExchange(SimulatorProcessBase):
if
isOver
:
state
=
player
.
reset
()
# compatibility
SimulatorProcess
=
SimulatorProcessStateExchange
@
abstractmethod
def
_build_player
(
self
):
pass
@
six
.
add_metaclass
(
ABCMeta
)
class
SimulatorMaster
(
threading
.
Thread
):
""" A base thread to communicate with all S
tateExchangeS
imulatorProcess.
""" A base thread to communicate with all SimulatorProcess.
It should produce action for each simulator, as well as
defining callbacks when a transition or an episode is finished.
"""
...
...
@@ -106,6 +98,10 @@ class SimulatorMaster(threading.Thread):
self
.
ident
=
None
def
__init__
(
self
,
pipe_c2s
,
pipe_s2c
):
"""
Args:
pipe_c2s, pipe_s2c (str): names of pipe to be used for communication
"""
super
(
SimulatorMaster
,
self
)
.
__init__
()
assert
os
.
name
!=
'nt'
,
"Doesn't support windows!"
self
.
daemon
=
True
...
...
@@ -152,6 +148,10 @@ class SimulatorMaster(threading.Thread):
except
zmq
.
ContextTerminated
:
logger
.
info
(
"[Simulator] Context was terminated."
)
@
abstractmethod
def
_process_msg
(
self
,
client
,
state
,
reward
,
isOver
):
pass
def
__del__
(
self
):
self
.
context
.
destroy
(
linger
=
0
)
...
...
examples/A3C-Gym/train-atari.py
View file @
df82c65a
...
...
@@ -139,12 +139,16 @@ class Model(ModelDesc):
class
MySimulatorMaster
(
SimulatorMaster
,
Callback
):
def
__init__
(
self
,
pipe_c2s
,
pipe_s2c
,
gpus
):
"""
Args:
gpus (list[int]): the gpus used to run inference
"""
super
(
MySimulatorMaster
,
self
)
.
__init__
(
pipe_c2s
,
pipe_s2c
)
self
.
queue
=
queue
.
Queue
(
maxsize
=
BATCH_SIZE
*
8
*
2
)
self
.
_gpus
=
gpus
def
_setup_graph
(
self
):
#
c
reate predictors on the available predictor GPUs.
#
C
reate predictors on the available predictor GPUs.
num_gpu
=
len
(
self
.
_gpus
)
predictors
=
[
self
.
trainer
.
get_predictor
(
[
'state'
],
[
'policy'
,
'pred_value'
],
...
...
@@ -155,6 +159,8 @@ class MySimulatorMaster(SimulatorMaster, Callback):
def
_before_train
(
self
):
self
.
async_predictor
.
start
()
logger
.
info
(
"Starting MySimulatorMaster ..."
)
start_proc_mask_signal
(
self
)
def
_on_state
(
self
,
state
,
client
):
"""
...
...
@@ -208,6 +214,10 @@ class MySimulatorMaster(SimulatorMaster, Callback):
else
:
client
.
memory
=
[]
def
get_training_dataflow
(
self
):
# the queue contains batched experience
return
BatchData
(
DataFromQueue
(
self
.
queue
),
BATCH_SIZE
)
def
train
():
assert
tf
.
test
.
is_gpu_available
(),
"Training requires GPUs!"
...
...
@@ -242,24 +252,19 @@ def train():
start_proc_mask_signal
(
procs
)
master
=
MySimulatorMaster
(
namec2s
,
names2c
,
predict_tower
)
dataflow
=
BatchData
(
DataFromQueue
(
master
.
queue
),
BATCH_SIZE
)
config
=
TrainConfig
(
model
=
Model
(),
dataflow
=
dataflow
,
dataflow
=
master
.
get_training_dataflow
()
,
callbacks
=
[
ModelSaver
(),
ScheduledHyperParamSetter
(
'learning_rate'
,
[(
20
,
0.0003
),
(
120
,
0.0001
)]),
ScheduledHyperParamSetter
(
'entropy_beta'
,
[(
80
,
0.005
)]),
HumanHyperParamSetter
(
'learning_rate'
),
HumanHyperParamSetter
(
'entropy_beta'
),
master
,
StartProcOrThread
(
master
),
PeriodicTrigger
(
Evaluator
(
EVAL_EPISODE
,
[
'state'
],
[
'policy'
],
get_player
),
every_k_epochs
=
3
),
],
session_creator
=
sesscreate
.
NewSessionCreator
(
config
=
get_default_sess_config
(
0.5
)),
session_creator
=
sesscreate
.
NewSessionCreator
(
config
=
get_default_sess_config
(
0.5
)),
steps_per_epoch
=
STEPS_PER_EPOCH
,
session_init
=
get_model_loader
(
args
.
load
)
if
args
.
load
else
None
,
max_epoch
=
1000
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment