Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
58a41fca
Commit
58a41fca
authored
Aug 09, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
simulator improved
parent
b2fd9b0d
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
93 additions
and
47 deletions
+93
-47
examples/Atari2600/common.py
examples/Atari2600/common.py
+1
-1
tensorpack/RL/simulator.py
tensorpack/RL/simulator.py
+68
-36
tensorpack/dataflow/raw.py
tensorpack/dataflow/raw.py
+24
-1
tensorpack/tfutils/varmanip.py
tensorpack/tfutils/varmanip.py
+0
-9
No files found.
examples/Atari2600/common.py
View file @
58a41fca
...
...
@@ -88,7 +88,7 @@ class Evaluator(Callback):
self
.
input_names
=
input_names
self
.
output_names
=
output_names
def
_
before_train
(
self
):
def
_
setup_graph
(
self
):
NR_PROC
=
min
(
multiprocessing
.
cpu_count
()
//
2
,
8
)
self
.
pred_funcs
=
[
self
.
trainer
.
get_predict_func
(
self
.
input_names
,
self
.
output_names
)]
*
NR_PROC
...
...
tensorpack/RL/simulator.py
View file @
58a41fca
...
...
@@ -3,6 +3,7 @@
# File: simulator.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import
tensorflow
as
tf
import
multiprocessing
as
mp
import
time
import
threading
...
...
@@ -13,12 +14,17 @@ import numpy as np
import
six
from
six.moves
import
queue
from
..callbacks
import
Callback
from
..tfutils.varmanip
import
SessionUpdate
from
..predict
import
OfflinePredictor
from
..utils
import
logger
from
..utils.timer
import
*
from
..utils.serialize
import
*
from
..utils.concurrency
import
*
__all__
=
[
'SimulatorProcess'
,
'SimulatorMaster'
,
'StateExchangeSimulatorProcess'
,
'SimulatorProcessSharedWeight'
]
'SimulatorProcessStateExchange'
,
'SimulatorProcessSharedWeight'
,
'TransitionExperience'
,
'WeightSync'
]
try
:
import
zmq
...
...
@@ -26,6 +32,16 @@ except ImportError:
logger
.
warn
(
"Error in 'import zmq'. RL simulator won't be available."
)
__all__
=
[]
class TransitionExperience(object):
    """A single transition (experience) recorded from a simulator.

    Stores the mandatory ``state``, ``action`` and ``reward`` values as
    attributes, plus any extra keyword arguments the caller wants to keep.
    """

    def __init__(self, state, action, reward, **kwargs):
        """
        :param state: value stored as ``self.state``
        :param action: value stored as ``self.action``
        :param reward: value stored as ``self.reward``
        :param kwargs: whatever other attribute you want to save; each
            key becomes an attribute of the same name on this instance.
        """
        self.state = state
        self.action = action
        self.reward = reward
        # dict.items() behaves the same on Python 2 and 3 for this small
        # mapping, so no compatibility helper is needed here.
        for name, value in kwargs.items():
            setattr(self, name, value)
class
SimulatorProcessBase
(
mp
.
Process
):
__metaclass__
=
ABCMeta
...
...
@@ -39,7 +55,7 @@ class SimulatorProcessBase(mp.Process):
pass
class
S
tateExchangeSimulatorProcess
(
SimulatorProcessBase
):
class
S
imulatorProcessStateExchange
(
SimulatorProcessBase
):
"""
A process that simulates a player and communicates to master to
send states and receive the next action
...
...
@@ -50,7 +66,7 @@ class StateExchangeSimulatorProcess(SimulatorProcessBase):
"""
:param idx: idx of this process
"""
super
(
S
tateExchangeSimulatorProcess
,
self
)
.
__init__
(
idx
)
super
(
S
imulatorProcessStateExchange
,
self
)
.
__init__
(
idx
)
self
.
c2s
=
pipe_c2s
self
.
s2c
=
pipe_s2c
...
...
@@ -78,7 +94,7 @@ class StateExchangeSimulatorProcess(SimulatorProcessBase):
state
=
player
.
current_state
()
# compatibility
SimulatorProcess
=
S
tateExchangeSimulatorProcess
SimulatorProcess
=
S
imulatorProcessStateExchange
class
SimulatorMaster
(
threading
.
Thread
):
""" A base thread to communicate with all StateExchangeSimulatorProcess.
...
...
@@ -91,16 +107,6 @@ class SimulatorMaster(threading.Thread):
def
__init__
(
self
):
self
.
memory
=
[]
# list of Experience
class
Experience
(
object
):
""" A transition of state, or experience"""
def
__init__
(
self
,
state
,
action
,
reward
,
**
kwargs
):
""" kwargs: whatever other attribute you want to save"""
self
.
state
=
state
self
.
action
=
action
self
.
reward
=
reward
for
k
,
v
in
six
.
iteritems
(
kwargs
):
setattr
(
self
,
k
,
v
)
def
__init__
(
self
,
pipe_c2s
,
pipe_s2c
):
super
(
SimulatorMaster
,
self
)
.
__init__
()
self
.
daemon
=
True
...
...
@@ -170,8 +176,7 @@ class SimulatorMaster(threading.Thread):
"""
def
__del__
(
self
):
self
.
socket
.
close
()
self
.
context
.
term
()
self
.
context
.
destroy
(
linger
=
0
)
class
SimulatorProcessDF
(
SimulatorProcessBase
):
...
...
@@ -191,18 +196,15 @@ class SimulatorProcessDF(SimulatorProcessBase):
self
.
c2s_socket
.
connect
(
self
.
pipe_c2s
)
self
.
_prepare
()
while
True
:
dp
=
self
.
_produce_datapoint
()
self
.
c2s_socket
.
send
(
dumps
(
(
self
.
identity
,
dp
)
),
copy
=
False
)
for
dp
in
self
.
get_data
():
self
.
c2s_socket
.
send
(
dumps
(
dp
),
copy
=
False
)
@
abstractmethod
def
_prepare
(
self
):
pass
@
abstractmethod
def
_produce_datapoint
(
self
):
def
get_data
(
self
):
pass
...
...
@@ -212,31 +214,61 @@ class SimulatorProcessSharedWeight(SimulatorProcessDF):
Start me under some CUDA_VISIBLE_DEVICES set!
"""
def
__init__
(
self
,
idx
,
pipe_c2s
,
evt
,
shared_dic
):
def
__init__
(
self
,
idx
,
pipe_c2s
,
condvar
,
shared_dic
,
pred_config
):
super
(
SimulatorProcessSharedWeight
,
self
)
.
__init__
(
idx
,
pipe_c2s
)
self
.
evt
=
evt
self
.
condvar
=
condvar
self
.
shared_dic
=
shared_dic
self
.
pred_config
=
pred_config
def
_prepare
(
self
):
self
.
_build_session
()
self
.
predictor
=
OfflinePredictor
(
self
.
pred_config
)
with
self
.
predictor
.
graph
.
as_default
():
vars_to_update
=
self
.
_params_to_update
()
self
.
sess_updater
=
SessionUpdate
(
self
.
predictor
.
session
,
vars_to_update
)
# TODO setup callback for explore?
self
.
predictor
.
graph
.
finalize
()
# start a thread to wait for evt
self
.
weight_lock
=
threading
.
Lock
()
# start a thread to wait for notification
def
func
():
self
.
evt
.
wait
()
self
.
condvar
.
acquire
()
while
True
:
self
.
condvar
.
wait
()
self
.
_trigger_evt
()
self
.
evt_th
=
LoopThread
(
func
,
pausable
=
False
)
self
.
evt_th
=
threading
.
Thread
(
target
=
func
)
self
.
evt_th
.
daemon
=
True
self
.
evt_th
.
start
()
@
abstractmethod
def
_trigger_evt
(
self
):
pass
#
self.sess_updater.update(self.shared_dic['params'])
with
self
.
weight_lock
:
self
.
sess_updater
.
update
(
self
.
shared_dic
[
'params'
])
@
abstractmethod
def
_build_session
(
self
):
# build session and self.sess_updaer
pass
def
_params_to_update
(
self
):
# can be overwritten to update more params
return
tf
.
trainable_variables
()
class WeightSync(Callback):
    """ Sync weight from main process to shared_dic and notify"""

    def __init__(self, condvar, shared_dic):
        """
        :param condvar: condition object; every waiter on it is notified
            after the weights have been written to ``shared_dic``.
        :param shared_dic: dict-like object; the weights are stored under
            the key ``'params'``.
        """
        self.condvar = condvar
        self.shared_dic = shared_dic

    def _setup_graph(self):
        # Resolve the variable list once; it is reused on every trigger.
        self.vars = self._params_to_update()

    def _params_to_update(self):
        # can be overwritten to update more params
        return tf.trainable_variables()

    def _trigger_epoch(self):
        """Publish the current value of every tracked variable and notify."""
        logger.info("Updating weights ...")
        dic = {v.name: v.eval() for v in self.vars}
        self.shared_dic['params'] = dic
        # Use the context manager so the condition is always released,
        # even if notify_all() raises.
        with self.condvar:
            self.condvar.notify_all()
if
__name__
==
'__main__'
:
import
random
...
...
tensorpack/dataflow/raw.py
View file @
58a41fca
...
...
@@ -6,8 +6,15 @@
import
numpy
as
np
from
six.moves
import
range
from
.base
import
DataFlow
,
RNGDataFlow
from
..utils.serialize
import
loads
__all__
=
[
'FakeData'
,
'DataFromQueue'
,
'DataFromList'
]
try:
    import zmq
except ImportError:
    # zmq is optional: when it is missing, DataFromSocket is simply not
    # added to __all__. Catch only ImportError so that unrelated failures
    # (and KeyboardInterrupt / SystemExit) are not silently swallowed.
    pass
else:
    __all__.append('DataFromSocket')
class
FakeData
(
RNGDataFlow
):
""" Generate fake fixed data of given shapes"""
...
...
@@ -43,7 +50,6 @@ class DataFromQueue(DataFlow):
while
True
:
yield
self
.
queue
.
get
()
class
DataFromList
(
RNGDataFlow
):
""" Produce data from a list"""
def
__init__
(
self
,
lst
,
shuffle
=
True
):
...
...
@@ -63,3 +69,20 @@ class DataFromList(RNGDataFlow):
for
k
in
idxs
:
yield
self
.
lst
[
k
]
class DataFromSocket(DataFlow):
    """ Produce data from a zmq socket"""

    def __init__(self, socket_name):
        """
        :param socket_name: address passed to ``socket.bind`` when
            :meth:`get_data` starts iterating.
        """
        self._name = socket_name

    def get_data(self):
        """Bind a PULL socket and yield every deserialized message forever.

        The zmq context is destroyed (with ``linger=0``) when the generator
        is closed or when an error interrupts the loop.
        """
        # Create the context before entering the try block: if creation
        # fails there is nothing to clean up, and the finally clause must
        # never reference an unbound name.
        ctx = zmq.Context()
        try:
            socket = ctx.socket(zmq.PULL)
            socket.bind(self._name)

            while True:
                # NOTE(review): loads() deserializes whatever arrives on
                # the socket -- confirm the peers are trusted.
                dp = loads(socket.recv(copy=False))
                yield dp
        finally:
            ctx.destroy(linger=0)
tensorpack/tfutils/varmanip.py
View file @
58a41fca
...
...
@@ -61,9 +61,6 @@ class SessionUpdate(object):
for
name
,
value
in
six
.
iteritems
(
prms
):
assert
name
in
self
.
assign_ops
for
p
,
v
,
op
in
self
.
assign_ops
[
name
]:
if
'fc0/W'
in
name
:
import
IPython
as
IP
;
IP
.
embed
(
config
=
IP
.
terminal
.
ipapp
.
load_default_config
())
varshape
=
tuple
(
v
.
get_shape
()
.
as_list
())
if
varshape
!=
value
.
shape
:
# TODO only allow reshape when shape different by empty axis
...
...
@@ -71,13 +68,7 @@ class SessionUpdate(object):
"{}: {}!={}"
.
format
(
name
,
varshape
,
value
.
shape
)
logger
.
warn
(
"Param {} is reshaped during assigning"
.
format
(
name
))
value
=
value
.
reshape
(
varshape
)
if
'fc0/W'
in
name
:
import
IPython
as
IP
;
IP
.
embed
(
config
=
IP
.
terminal
.
ipapp
.
load_default_config
())
self
.
sess
.
run
(
op
,
feed_dict
=
{
p
:
value
})
if
'fc0/W'
in
name
:
import
IPython
as
IP
;
IP
.
embed
(
config
=
IP
.
terminal
.
ipapp
.
load_default_config
())
def
dump_session_params
(
path
):
""" Dump value of all trainable + to_save variables to a dict and save to `path` as
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment