Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
c5da59af
Commit
c5da59af
authored
Jun 02, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
minor fix for async
parent
cc844ed4
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
32 additions
and
9 deletions
+32
-9
tensorpack/RL/atari.py
tensorpack/RL/atari.py
+8
-2
tensorpack/RL/simulator.py
tensorpack/RL/simulator.py
+12
-2
tensorpack/tfutils/gradproc.py
tensorpack/tfutils/gradproc.py
+9
-3
tensorpack/train/trainer.py
tensorpack/train/trainer.py
+3
-2
No files found.
tensorpack/RL/atari.py
View file @
c5da59af
...
...
@@ -9,7 +9,7 @@ import os
import
cv2
from
collections
import
deque
from
six.moves
import
range
from
..utils
import
get_rng
,
logger
from
..utils
import
get_rng
,
logger
,
memoized
from
..utils.stat
import
StatCounter
from
.envbase
import
RLEnvironment
...
...
@@ -21,6 +21,10 @@ except ImportError:
__all__
=
[
'AtariPlayer'
]
@
memoized
def
log_once
():
logger
.
warn
(
"https://github.com/mgbellemare/Arcade-Learning-Environment/pull/171 is not merged!"
)
class
AtariPlayer
(
RLEnvironment
):
"""
A wrapper for atari emulator.
...
...
@@ -43,10 +47,12 @@ class AtariPlayer(RLEnvironment):
self
.
ale
.
setInt
(
"random_seed"
,
self
.
rng
.
randint
(
0
,
10000
))
self
.
ale
.
setBool
(
"showinfo"
,
False
)
try
:
ALEInterface
.
setLoggerMode
(
ALEInterface
.
Logger
.
Warning
)
except
AttributeError
:
logger
.
warn
(
"https://github.com/mgbellemare/Arcade-Learning-Environment/pull/171 is not merged!"
)
log_once
()
self
.
ale
.
setInt
(
"frame_skip"
,
1
)
self
.
ale
.
setBool
(
'color_averaging'
,
False
)
# manual.pdf suggests otherwise. may need to check
...
...
tensorpack/RL/simulator.py
View file @
c5da59af
...
...
@@ -4,6 +4,7 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import
multiprocessing
import
time
import
threading
import
weakref
from
abc
import
abstractmethod
,
ABCMeta
...
...
@@ -68,10 +69,12 @@ class SimulatorMaster(threading.Thread):
class
Experience
(
object
):
""" A transition of state, or experience"""
def
__init__
(
self
,
state
,
action
,
reward
):
def
__init__
(
self
,
state
,
action
,
reward
,
misc
=
None
):
""" misc: whatever other attribute you want to save"""
self
.
state
=
state
self
.
action
=
action
self
.
reward
=
reward
self
.
misc
=
misc
def
__init__
(
self
,
server_name
):
super
(
SimulatorMaster
,
self
)
.
__init__
()
...
...
@@ -91,7 +94,14 @@ class SimulatorMaster(threading.Thread):
def
run
(
self
):
self
.
clients
=
defaultdict
(
SimulatorMaster
.
ClientState
)
while
True
:
ident
,
_
,
msg
=
self
.
socket
.
recv_multipart
()
while
True
:
# avoid the lock being acquired here forever
try
:
with
self
.
socket_lock
:
ident
,
_
,
msg
=
self
.
socket
.
recv_multipart
(
zmq
.
NOBLOCK
)
break
except
zmq
.
ZMQError
:
time
.
sleep
(
0.01
)
#assert _ == ""
client
=
self
.
clients
[
ident
]
client
.
protocol_state
=
1
-
client
.
protocol_state
# first flip the state
...
...
tensorpack/tfutils/gradproc.py
View file @
c5da59af
...
...
@@ -27,16 +27,23 @@ class GradientProcessor(object):
def
_process
(
self
,
grads
):
pass
_summaried_gradient
=
set
()
class
SummaryGradient
(
GradientProcessor
):
"""
Summary history and RMS for each graident variable
"""
def
_process
(
self
,
grads
):
for
grad
,
var
in
grads
:
tf
.
histogram_summary
(
var
.
op
.
name
+
'/grad'
,
grad
)
name
=
var
.
op
.
name
if
name
in
_summaried_gradient
:
continue
_summaried_gradient
.
add
(
name
)
tf
.
histogram_summary
(
name
+
'/grad'
,
grad
)
tf
.
add_to_collection
(
MOVING_SUMMARY_VARS_KEY
,
tf
.
sqrt
(
tf
.
reduce_mean
(
tf
.
square
(
grad
)),
name
=
var
.
op
.
name
+
'/gradRMS'
))
name
=
name
+
'/gradRMS'
))
return
grads
...
...
@@ -46,7 +53,6 @@ class CheckGradient(GradientProcessor):
"""
def
_process
(
self
,
grads
):
for
grad
,
var
in
grads
:
assert
grad
is
not
None
,
"Grad is None for variable {}"
.
format
(
var
.
name
)
# TODO make assert work
tf
.
Assert
(
tf
.
reduce_all
(
tf
.
is_finite
(
var
)),
[
var
])
return
grads
...
...
tensorpack/train/trainer.py
View file @
c5da59af
...
...
@@ -191,13 +191,13 @@ class QueueInputTrainer(Trainer):
grads
=
QueueInputTrainer
.
_average_grads
(
grad_list
)
grads
=
self
.
process_grads
(
grads
)
else
:
grad_list
=
[
self
.
process_grads
(
g
)
for
g
in
grad_list
]
# pretend to average the grads, in order to make async and
# sync have consistent effective learning rate
def
scale
(
grads
):
return
[(
grad
/
self
.
config
.
nr_tower
,
var
)
for
grad
,
var
in
grads
]
grad_list
=
map
(
scale
,
grad_list
)
grads
=
grad_list
[
0
]
# use grad from the first tower for routinely stuff
grad_list
=
[
self
.
process_grads
(
g
)
for
g
in
grad_list
]
grads
=
grad_list
[
0
]
# use grad from the first tower for the main iteration
else
:
grads
=
self
.
_single_tower_grad
()
grads
=
self
.
process_grads
(
grads
)
...
...
@@ -207,6 +207,7 @@ class QueueInputTrainer(Trainer):
summary_moving_average
())
if
self
.
async
:
# prepare train_op for the rest of the towers
self
.
threads
=
[]
for
k
in
range
(
1
,
self
.
config
.
nr_tower
):
train_op
=
self
.
config
.
optimizer
.
apply_gradients
(
grad_list
[
k
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment