Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
f74ba9a1
Commit
f74ba9a1
authored
Aug 10, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
fix some atari settings
parent
162f2db0
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
19 additions
and
6 deletions
+19
-6
examples/Atari2600/atari.py
examples/Atari2600/atari.py
+6
-5
examples/Atari2600/common.py
examples/Atari2600/common.py
+2
-1
tensorpack/RL/simulator.py
tensorpack/RL/simulator.py
+8
-0
tensorpack/train/trainer.py
tensorpack/train/trainer.py
+3
-0
No files found.
examples/Atari2600/atari.py
View file @
f74ba9a1
...
...
@@ -62,8 +62,7 @@ class AtariPlayer(RLEnvironment):
with
_ALE_LOCK
:
self
.
ale
=
ALEInterface
()
self
.
rng
=
get_rng
(
self
)
self
.
ale
.
setInt
(
b
"random_seed"
,
self
.
rng
.
randint
(
0
,
10000
))
self
.
ale
.
setInt
(
b
"random_seed"
,
self
.
rng
.
randint
(
0
,
30000
))
self
.
ale
.
setBool
(
b
"showinfo"
,
False
)
self
.
ale
.
setInt
(
b
"frame_skip"
,
1
)
...
...
@@ -132,7 +131,8 @@ class AtariPlayer(RLEnvironment):
def
restart_episode
(
self
):
self
.
current_episode_score
.
reset
()
self
.
ale
.
reset_game
()
with
_ALE_LOCK
:
self
.
ale
.
reset_game
()
# random null-ops start
n
=
self
.
rng
.
randint
(
self
.
nullop_start
)
...
...
@@ -160,11 +160,12 @@ class AtariPlayer(RLEnvironment):
self
.
current_episode_score
.
feed
(
r
)
isOver
=
self
.
ale
.
game_over
()
if
self
.
live_lost_as_eoe
:
isOver
=
isOver
or
newlives
<
oldlives
if
isOver
:
self
.
finish_episode
()
if
self
.
ale
.
game_over
():
self
.
restart_episode
()
if
self
.
live_lost_as_eoe
:
isOver
=
isOver
or
newlives
<
oldlives
return
(
r
,
isOver
)
if
__name__
==
'__main__'
:
...
...
examples/Atari2600/common.py
View file @
f74ba9a1
...
...
@@ -48,10 +48,11 @@ def eval_with_funcs(predict_funcs, nr_eval):
return
self
.
_func
(
*
args
,
**
kwargs
)
def
run
(
self
):
player
=
get_player
()
player
=
get_player
(
train
=
False
)
while
not
self
.
stopped
():
try
:
score
=
play_one_episode
(
player
,
self
.
func
)
#print "Score, ", score
except
RuntimeError
:
return
self
.
queue_put_stoppable
(
self
.
q
,
score
)
...
...
tensorpack/RL/simulator.py
View file @
f74ba9a1
...
...
@@ -14,6 +14,7 @@ import numpy as np
import
six
from
six.moves
import
queue
from
..models._common
import
disable_layer_logging
from
..callbacks
import
Callback
from
..tfutils.varmanip
import
SessionUpdate
from
..predict
import
OfflinePredictor
...
...
@@ -221,6 +222,7 @@ class SimulatorProcessSharedWeight(SimulatorProcessDF):
self
.
pred_config
=
pred_config
def
_prepare
(
self
):
disable_layer_logging
()
self
.
predictor
=
OfflinePredictor
(
self
.
pred_config
)
with
self
.
predictor
.
graph
.
as_default
():
vars_to_update
=
self
.
_params_to_update
()
...
...
@@ -244,6 +246,7 @@ class SimulatorProcessSharedWeight(SimulatorProcessDF):
def
_trigger_evt
(
self
):
with
self
.
weight_lock
:
self
.
sess_updater
.
update
(
self
.
shared_dic
[
'params'
])
logger
.
info
(
"Updated."
)
def
_params_to_update
(
self
):
# can be overwritten to update more params
...
...
@@ -262,7 +265,12 @@ class WeightSync(Callback):
# can be overwritten to update more params
return
tf
.
trainable_variables
()
def
_before_train
(
self
):
self
.
_sync
()
def
_trigger_epoch
(
self
):
self
.
_sync
()
def
_sync
(
self
):
logger
.
info
(
"Updating weights ..."
)
dic
=
{
v
.
name
:
v
.
eval
()
for
v
in
self
.
vars
}
self
.
shared_dic
[
'params'
]
=
dic
...
...
tensorpack/train/trainer.py
View file @
f74ba9a1
...
...
@@ -117,6 +117,9 @@ class EnqueueThread(threading.Thread):
try
:
while
True
:
for
dp
in
self
.
dataflow
.
get_data
():
#import IPython;
#IPython.embed(config=IPython.terminal.ipapp.load_default_config())
if
self
.
coord
.
should_stop
():
return
feed
=
dict
(
zip
(
self
.
input_vars
,
dp
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment