Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
58a41fca
Commit
58a41fca
authored
Aug 09, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
simulator improved
parent
b2fd9b0d
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
93 additions
and
47 deletions
+93
-47
examples/Atari2600/common.py
examples/Atari2600/common.py
+1
-1
tensorpack/RL/simulator.py
tensorpack/RL/simulator.py
+68
-36
tensorpack/dataflow/raw.py
tensorpack/dataflow/raw.py
+24
-1
tensorpack/tfutils/varmanip.py
tensorpack/tfutils/varmanip.py
+0
-9
No files found.
examples/Atari2600/common.py
View file @
58a41fca
...
@@ -88,7 +88,7 @@ class Evaluator(Callback):
...
@@ -88,7 +88,7 @@ class Evaluator(Callback):
self
.
input_names
=
input_names
self
.
input_names
=
input_names
self
.
output_names
=
output_names
self
.
output_names
=
output_names
def
_
before_train
(
self
):
def
_
setup_graph
(
self
):
NR_PROC
=
min
(
multiprocessing
.
cpu_count
()
//
2
,
8
)
NR_PROC
=
min
(
multiprocessing
.
cpu_count
()
//
2
,
8
)
self
.
pred_funcs
=
[
self
.
trainer
.
get_predict_func
(
self
.
pred_funcs
=
[
self
.
trainer
.
get_predict_func
(
self
.
input_names
,
self
.
output_names
)]
*
NR_PROC
self
.
input_names
,
self
.
output_names
)]
*
NR_PROC
...
...
tensorpack/RL/simulator.py
View file @
58a41fca
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
# File: simulator.py
# File: simulator.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import
tensorflow
as
tf
import
multiprocessing
as
mp
import
multiprocessing
as
mp
import
time
import
time
import
threading
import
threading
...
@@ -13,12 +14,17 @@ import numpy as np
...
@@ -13,12 +14,17 @@ import numpy as np
import
six
import
six
from
six.moves
import
queue
from
six.moves
import
queue
from
..callbacks
import
Callback
from
..tfutils.varmanip
import
SessionUpdate
from
..predict
import
OfflinePredictor
from
..utils
import
logger
from
..utils.timer
import
*
from
..utils.timer
import
*
from
..utils.serialize
import
*
from
..utils.serialize
import
*
from
..utils.concurrency
import
*
from
..utils.concurrency
import
*
__all__
=
[
'SimulatorProcess'
,
'SimulatorMaster'
,
__all__
=
[
'SimulatorProcess'
,
'SimulatorMaster'
,
'StateExchangeSimulatorProcess'
,
'SimulatorProcessSharedWeight'
]
'SimulatorProcessStateExchange'
,
'SimulatorProcessSharedWeight'
,
'TransitionExperience'
,
'WeightSync'
]
try
:
try
:
import
zmq
import
zmq
...
@@ -26,6 +32,16 @@ except ImportError:
...
@@ -26,6 +32,16 @@ except ImportError:
logger
.
warn
(
"Error in 'import zmq'. RL simulator won't be available."
)
logger
.
warn
(
"Error in 'import zmq'. RL simulator won't be available."
)
__all__
=
[]
__all__
=
[]
class
TransitionExperience
(
object
):
""" A transition of state, or experience"""
def
__init__
(
self
,
state
,
action
,
reward
,
**
kwargs
):
""" kwargs: whatever other attribute you want to save"""
self
.
state
=
state
self
.
action
=
action
self
.
reward
=
reward
for
k
,
v
in
six
.
iteritems
(
kwargs
):
setattr
(
self
,
k
,
v
)
class
SimulatorProcessBase
(
mp
.
Process
):
class
SimulatorProcessBase
(
mp
.
Process
):
__metaclass__
=
ABCMeta
__metaclass__
=
ABCMeta
...
@@ -39,7 +55,7 @@ class SimulatorProcessBase(mp.Process):
...
@@ -39,7 +55,7 @@ class SimulatorProcessBase(mp.Process):
pass
pass
class
S
tateExchangeSimulatorProcess
(
SimulatorProcessBase
):
class
S
imulatorProcessStateExchange
(
SimulatorProcessBase
):
"""
"""
A process that simulates a player and communicates to master to
A process that simulates a player and communicates to master to
send states and receive the next action
send states and receive the next action
...
@@ -50,7 +66,7 @@ class StateExchangeSimulatorProcess(SimulatorProcessBase):
...
@@ -50,7 +66,7 @@ class StateExchangeSimulatorProcess(SimulatorProcessBase):
"""
"""
:param idx: idx of this process
:param idx: idx of this process
"""
"""
super
(
S
tateExchangeSimulatorProcess
,
self
)
.
__init__
(
idx
)
super
(
S
imulatorProcessStateExchange
,
self
)
.
__init__
(
idx
)
self
.
c2s
=
pipe_c2s
self
.
c2s
=
pipe_c2s
self
.
s2c
=
pipe_s2c
self
.
s2c
=
pipe_s2c
...
@@ -78,7 +94,7 @@ class StateExchangeSimulatorProcess(SimulatorProcessBase):
...
@@ -78,7 +94,7 @@ class StateExchangeSimulatorProcess(SimulatorProcessBase):
state
=
player
.
current_state
()
state
=
player
.
current_state
()
# compatibility
# compatibility
SimulatorProcess
=
S
tateExchangeSimulatorProcess
SimulatorProcess
=
S
imulatorProcessStateExchange
class
SimulatorMaster
(
threading
.
Thread
):
class
SimulatorMaster
(
threading
.
Thread
):
""" A base thread to communicate with all StateExchangeSimulatorProcess.
""" A base thread to communicate with all StateExchangeSimulatorProcess.
...
@@ -91,16 +107,6 @@ class SimulatorMaster(threading.Thread):
...
@@ -91,16 +107,6 @@ class SimulatorMaster(threading.Thread):
def
__init__
(
self
):
def
__init__
(
self
):
self
.
memory
=
[]
# list of Experience
self
.
memory
=
[]
# list of Experience
class
Experience
(
object
):
""" A transition of state, or experience"""
def
__init__
(
self
,
state
,
action
,
reward
,
**
kwargs
):
""" kwargs: whatever other attribute you want to save"""
self
.
state
=
state
self
.
action
=
action
self
.
reward
=
reward
for
k
,
v
in
six
.
iteritems
(
kwargs
):
setattr
(
self
,
k
,
v
)
def
__init__
(
self
,
pipe_c2s
,
pipe_s2c
):
def
__init__
(
self
,
pipe_c2s
,
pipe_s2c
):
super
(
SimulatorMaster
,
self
)
.
__init__
()
super
(
SimulatorMaster
,
self
)
.
__init__
()
self
.
daemon
=
True
self
.
daemon
=
True
...
@@ -170,8 +176,7 @@ class SimulatorMaster(threading.Thread):
...
@@ -170,8 +176,7 @@ class SimulatorMaster(threading.Thread):
"""
"""
def
__del__
(
self
):
def
__del__
(
self
):
self
.
socket
.
close
()
self
.
context
.
destroy
(
linger
=
0
)
self
.
context
.
term
()
class
SimulatorProcessDF
(
SimulatorProcessBase
):
class
SimulatorProcessDF
(
SimulatorProcessBase
):
...
@@ -191,18 +196,15 @@ class SimulatorProcessDF(SimulatorProcessBase):
...
@@ -191,18 +196,15 @@ class SimulatorProcessDF(SimulatorProcessBase):
self
.
c2s_socket
.
connect
(
self
.
pipe_c2s
)
self
.
c2s_socket
.
connect
(
self
.
pipe_c2s
)
self
.
_prepare
()
self
.
_prepare
()
while
True
:
for
dp
in
self
.
get_data
():
dp
=
self
.
_produce_datapoint
()
self
.
c2s_socket
.
send
(
dumps
(
dp
),
copy
=
False
)
self
.
c2s_socket
.
send
(
dumps
(
(
self
.
identity
,
dp
)
),
copy
=
False
)
@
abstractmethod
@
abstractmethod
def
_prepare
(
self
):
def
_prepare
(
self
):
pass
pass
@
abstractmethod
@
abstractmethod
def
_produce_datapoint
(
self
):
def
get_data
(
self
):
pass
pass
...
@@ -212,31 +214,61 @@ class SimulatorProcessSharedWeight(SimulatorProcessDF):
...
@@ -212,31 +214,61 @@ class SimulatorProcessSharedWeight(SimulatorProcessDF):
Start me under some CUDA_VISIBLE_DEVICES set!
Start me under some CUDA_VISIBLE_DEVICES set!
"""
"""
def
__init__
(
self
,
idx
,
pipe_c2s
,
evt
,
shared_dic
):
def
__init__
(
self
,
idx
,
pipe_c2s
,
condvar
,
shared_dic
,
pred_config
):
super
(
SimulatorProcessSharedWeight
,
self
)
.
__init__
(
idx
,
pipe_c2s
)
super
(
SimulatorProcessSharedWeight
,
self
)
.
__init__
(
idx
,
pipe_c2s
)
self
.
evt
=
evt
self
.
condvar
=
condvar
self
.
shared_dic
=
shared_dic
self
.
shared_dic
=
shared_dic
self
.
pred_config
=
pred_config
def
_prepare
(
self
):
def
_prepare
(
self
):
self
.
_build_session
()
self
.
predictor
=
OfflinePredictor
(
self
.
pred_config
)
with
self
.
predictor
.
graph
.
as_default
():
vars_to_update
=
self
.
_params_to_update
()
self
.
sess_updater
=
SessionUpdate
(
self
.
predictor
.
session
,
vars_to_update
)
# TODO setup callback for explore?
self
.
predictor
.
graph
.
finalize
()
# start a thread to wait for evt
self
.
weight_lock
=
threading
.
Lock
()
# start a thread to wait for notification
def
func
():
def
func
():
self
.
evt
.
wait
()
self
.
condvar
.
acquire
()
self
.
_trigger_evt
()
while
True
:
self
.
evt_th
=
LoopThread
(
func
,
pausable
=
False
)
self
.
condvar
.
wait
()
self
.
_trigger_evt
()
self
.
evt_th
=
threading
.
Thread
(
target
=
func
)
self
.
evt_th
.
daemon
=
True
self
.
evt_th
.
start
()
self
.
evt_th
.
start
()
@
abstractmethod
def
_trigger_evt
(
self
):
def
_trigger_evt
(
self
):
pass
with
self
.
weight_lock
:
#
self.sess_updater.update(self.shared_dic['params'])
self
.
sess_updater
.
update
(
self
.
shared_dic
[
'params'
])
@
abstractmethod
def
_params_to_update
(
self
):
def
_build_session
(
self
):
# can be overwritten to update more params
# build session and self.sess_updaer
return
tf
.
trainable_variables
()
pass
class
WeightSync
(
Callback
):
""" Sync weight from main process to shared_dic and notify"""
def
__init__
(
self
,
condvar
,
shared_dic
):
self
.
condvar
=
condvar
self
.
shared_dic
=
shared_dic
def
_setup_graph
(
self
):
self
.
vars
=
self
.
_params_to_update
()
def
_params_to_update
(
self
):
# can be overwritten to update more params
return
tf
.
trainable_variables
()
def
_trigger_epoch
(
self
):
logger
.
info
(
"Updating weights ..."
)
dic
=
{
v
.
name
:
v
.
eval
()
for
v
in
self
.
vars
}
self
.
shared_dic
[
'params'
]
=
dic
self
.
condvar
.
acquire
()
self
.
condvar
.
notify_all
()
self
.
condvar
.
release
()
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
import
random
import
random
...
...
tensorpack/dataflow/raw.py
View file @
58a41fca
...
@@ -6,8 +6,15 @@
...
@@ -6,8 +6,15 @@
import
numpy
as
np
import
numpy
as
np
from
six.moves
import
range
from
six.moves
import
range
from
.base
import
DataFlow
,
RNGDataFlow
from
.base
import
DataFlow
,
RNGDataFlow
from
..utils.serialize
import
loads
__all__
=
[
'FakeData'
,
'DataFromQueue'
,
'DataFromList'
]
__all__
=
[
'FakeData'
,
'DataFromQueue'
,
'DataFromList'
]
try
:
import
zmq
except
:
pass
else
:
__all__
.
append
(
'DataFromSocket'
)
class
FakeData
(
RNGDataFlow
):
class
FakeData
(
RNGDataFlow
):
""" Generate fake fixed data of given shapes"""
""" Generate fake fixed data of given shapes"""
...
@@ -43,7 +50,6 @@ class DataFromQueue(DataFlow):
...
@@ -43,7 +50,6 @@ class DataFromQueue(DataFlow):
while
True
:
while
True
:
yield
self
.
queue
.
get
()
yield
self
.
queue
.
get
()
class
DataFromList
(
RNGDataFlow
):
class
DataFromList
(
RNGDataFlow
):
""" Produce data from a list"""
""" Produce data from a list"""
def
__init__
(
self
,
lst
,
shuffle
=
True
):
def
__init__
(
self
,
lst
,
shuffle
=
True
):
...
@@ -63,3 +69,20 @@ class DataFromList(RNGDataFlow):
...
@@ -63,3 +69,20 @@ class DataFromList(RNGDataFlow):
for
k
in
idxs
:
for
k
in
idxs
:
yield
self
.
lst
[
k
]
yield
self
.
lst
[
k
]
class
DataFromSocket
(
DataFlow
):
""" Produce data from a zmq socket"""
def
__init__
(
self
,
socket_name
):
self
.
_name
=
socket_name
def
get_data
(
self
):
try
:
ctx
=
zmq
.
Context
()
socket
=
ctx
.
socket
(
zmq
.
PULL
)
socket
.
bind
(
self
.
_name
)
while
True
:
dp
=
loads
(
socket
.
recv
(
copy
=
False
))
yield
dp
finally
:
ctx
.
destroy
(
linger
=
0
)
tensorpack/tfutils/varmanip.py
View file @
58a41fca
...
@@ -61,9 +61,6 @@ class SessionUpdate(object):
...
@@ -61,9 +61,6 @@ class SessionUpdate(object):
for
name
,
value
in
six
.
iteritems
(
prms
):
for
name
,
value
in
six
.
iteritems
(
prms
):
assert
name
in
self
.
assign_ops
assert
name
in
self
.
assign_ops
for
p
,
v
,
op
in
self
.
assign_ops
[
name
]:
for
p
,
v
,
op
in
self
.
assign_ops
[
name
]:
if
'fc0/W'
in
name
:
import
IPython
as
IP
;
IP
.
embed
(
config
=
IP
.
terminal
.
ipapp
.
load_default_config
())
varshape
=
tuple
(
v
.
get_shape
()
.
as_list
())
varshape
=
tuple
(
v
.
get_shape
()
.
as_list
())
if
varshape
!=
value
.
shape
:
if
varshape
!=
value
.
shape
:
# TODO only allow reshape when shape different by empty axis
# TODO only allow reshape when shape different by empty axis
...
@@ -71,13 +68,7 @@ class SessionUpdate(object):
...
@@ -71,13 +68,7 @@ class SessionUpdate(object):
"{}: {}!={}"
.
format
(
name
,
varshape
,
value
.
shape
)
"{}: {}!={}"
.
format
(
name
,
varshape
,
value
.
shape
)
logger
.
warn
(
"Param {} is reshaped during assigning"
.
format
(
name
))
logger
.
warn
(
"Param {} is reshaped during assigning"
.
format
(
name
))
value
=
value
.
reshape
(
varshape
)
value
=
value
.
reshape
(
varshape
)
if
'fc0/W'
in
name
:
import
IPython
as
IP
;
IP
.
embed
(
config
=
IP
.
terminal
.
ipapp
.
load_default_config
())
self
.
sess
.
run
(
op
,
feed_dict
=
{
p
:
value
})
self
.
sess
.
run
(
op
,
feed_dict
=
{
p
:
value
})
if
'fc0/W'
in
name
:
import
IPython
as
IP
;
IP
.
embed
(
config
=
IP
.
terminal
.
ipapp
.
load_default_config
())
def
dump_session_params
(
path
):
def
dump_session_params
(
path
):
""" Dump value of all trainable + to_save variables to a dict and save to `path` as
""" Dump value of all trainable + to_save variables to a dict and save to `path` as
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment