Commit 55098813 authored Mar 15, 2018 by Yuxin Wu
Upgrade gym
parent 4bc0c748
Showing 8 changed files with 79 additions and 59 deletions
docs/conf.py                               +1  −1
examples/A3C-Gym/README.md                 +28 −25
examples/A3C-Gym/train-atari.py            +3  −4
examples/DeepQNetwork/DQNModel.py          +3  −3
examples/DeepQNetwork/atari.py             +2  −2
examples/DeepQNetwork/atari_wrapper.py     +17 −10
examples/DeepQNetwork/common.py            +13 −8
tensorpack/graph_builder/model_desc.py     +12 −6
docs/conf.py

@@ -362,7 +362,7 @@ def autodoc_skip_member(app, what, name, obj, skip, options):
         # include_init_with_doc doesn't work well for decorated init
         # https://github.com/sphinx-doc/sphinx/issues/4258
         return False
-    # hide deprecated stuff
+    # Hide some names that are deprecated or not intended to be used
     if name in [
         # deprecated stuff:
         'GaussianDeform',
examples/A3C-Gym/README.md

This diff is collapsed.
examples/A3C-Gym/train-atari.py

@@ -54,7 +54,7 @@ ENV_NAME = None
 def get_player(train=False, dumpdir=None):
     env = gym.make(ENV_NAME)
     if dumpdir:
-        env = gym.wrappers.Monitor(env, dumpdir)
+        env = gym.wrappers.Monitor(env, dumpdir, video_callable=lambda _: True)
     env = FireResetEnv(env)
     env = MapState(env, lambda im: cv2.resize(im, IMAGE_SIZE))
     env = FrameStack(env, 4)
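Note on the Monitor change above: by default gym.wrappers.Monitor records only a subset of episodes, so passing video_callable=lambda _: True makes it dump a video for every episode, which is what the new `dump_video` task below relies on. A minimal standalone sketch of the same call (the environment id and output directory are placeholders, not taken from this repo):

```python
import gym

# Record a video for every episode rather than gym's default sampling schedule.
env = gym.make("Breakout-v0")                        # placeholder env id
env = gym.wrappers.Monitor(env, "/tmp/gym-videos",   # placeholder output dir
                           video_callable=lambda episode_id: True)

obs = env.reset()
done = False
while not done:
    obs, reward, done, info = env.step(env.action_space.sample())
env.close()
```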
@@ -272,7 +272,7 @@ if __name__ == '__main__':
     parser.add_argument('--load', help='load model')
     parser.add_argument('--env', help='env', required=True)
     parser.add_argument('--task', help='task to perform',
-                        choices=['play', 'eval', 'train', 'gen_submit'], default='train')
+                        choices=['play', 'eval', 'train', 'dump_video'], default='train')
     parser.add_argument('--output', help='output directory for submission', default='output_dir')
     parser.add_argument('--episode', help='number of episode to eval', default=100, type=int)
     args = parser.parse_args()
@@ -297,10 +297,9 @@ if __name__ == '__main__':
                           args.episode, render=True)
     elif args.task == 'eval':
         eval_model_multithread(pred, args.episode, get_player)
-    elif args.task == 'gen_submit':
+    elif args.task == 'dump_video':
         play_n_episodes(get_player(train=False, dumpdir=args.output), pred, args.episode)
-        # gym.upload(args.output, api_key='xxx')
     else:
         train()
examples/DeepQNetwork/DQNModel.py

@@ -9,7 +9,7 @@ import tensorpack
 from tensorpack import ModelDesc, InputDesc
 from tensorpack.utils import logger
 from tensorpack.tfutils import (
-    summary, get_current_tower_context, optimizer, gradproc)
+    varreplace, summary, get_current_tower_context, optimizer, gradproc)
 from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope

 assert tensorpack.tfutils.common.get_tf_version_number() >= 1.2

@@ -60,7 +60,7 @@ class Model(ModelDesc):
             self.predict_value, 1), name='predict_reward')
         summary.add_moving_summary(max_pred_reward)

-        with tf.variable_scope('target'):
+        with tf.variable_scope('target'), varreplace.freeze_variables(skip_collection=True):
             targetQ_predict_value = self.get_DQN_prediction(next_state)    # NxA

         if self.method != 'Double':

@@ -96,6 +96,6 @@ class Model(ModelDesc):
             target_name = v.op.name
             if target_name.startswith('target'):
                 new_name = target_name.replace('target/', '')
-                logger.info("{} <- {}".format(target_name, new_name))
+                logger.info("Target Network Update: {} <- {}".format(target_name, new_name))
                 ops.append(v.assign(G.get_tensor_by_name(new_name + ':0')))
         return tf.group(*ops, name='update_target_network')
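The update op built above copies every online-network variable into its counterpart under the 'target/' scope and groups the assignments into one op. A minimal TF 1.x sketch of that pattern, independent of tensorpack (the scope names and the pairing-by-suffix rule are assumptions for illustration, not the exact code in DQNModel.py):

```python
import tensorflow as tf

def build_target_update_op(main_scope="main", target_scope="target"):
    """Assign each 'main/...' variable to the matching 'target/...' variable."""
    main_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=main_scope)
    target_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=target_scope)
    main_by_suffix = {v.op.name.split("/", 1)[1]: v for v in main_vars}
    ops = []
    for t in target_vars:
        suffix = t.op.name.split("/", 1)[1]
        ops.append(t.assign(main_by_suffix[suffix]))   # target <- main
    return tf.group(*ops, name="update_target_network")
```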
examples/DeepQNetwork/atari.py

@@ -138,12 +138,12 @@ class AtariPlayer(gym.Env):
             self.last_raw_screen = self._grab_raw_image()
             self.ale.act(0)

-    def _reset(self):
+    def reset(self):
         if self.ale.game_over():
             self._restart_episode()
         return self._current_state()

-    def _step(self, act):
+    def step(self, act):
         oldlives = self.ale.lives()
         r = 0
         for k in range(self.frame_skip):
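The _reset/_step renames here (and in atari_wrapper.py below) track gym's API change: newer gym versions drop the underscore-prefixed hook methods on gym.Env and its wrappers in favor of plain reset/step/observation, which is why this commit requires gym >= 0.10. A toy environment written against the new convention (this class is illustrative only, not part of the repo):

```python
import gym
import numpy as np
from gym import spaces

class CountdownEnv(gym.Env):
    """Toy env: state starts at 10; action 0 decrements it; episode ends at 0."""

    def __init__(self):
        self.action_space = spaces.Discrete(2)
        self.observation_space = spaces.Box(low=0, high=10, shape=(1,), dtype=np.int64)
        self._state = 10

    def reset(self):            # older gym used `_reset`
        self._state = 10
        return np.array([self._state])

    def step(self, action):     # older gym used `_step`
        if action == 0:
            self._state -= 1
        done = self._state <= 0
        return np.array([self._state]), float(done), done, {}
```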
examples/DeepQNetwork/atari_wrapper.py

@@ -8,6 +8,9 @@ from collections import deque
 import gym
 from gym import spaces

+_v0, _v1 = gym.__version__.split('.')[:2]
+assert int(_v0) > 0 or int(_v1) >= 10, gym.__version__
+

 """
 The following wrappers are copied or modified from openai/baselines:
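The two added lines are a hand-rolled assertion that the installed gym is at least 0.10 (major version above 0, or minor version at least 10). An equivalent sketch using the standard library's version parsing, shown only as an alternative, not what the file actually does:

```python
from distutils.version import LooseVersion

import gym

# Fail fast with a readable message if gym is older than 0.10.
assert LooseVersion(gym.__version__) >= LooseVersion("0.10"), \
    "gym>=0.10 is required, found {}".format(gym.__version__)
```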
@@ -20,7 +23,7 @@ class MapState(gym.ObservationWrapper):
         gym.ObservationWrapper.__init__(self, env)
         self._func = map_func

-    def _observation(self, obs):
+    def observation(self, obs):
         return self._func(obs)
@@ -32,22 +35,23 @@ class FrameStack(gym.Wrapper):
         self.frames = deque([], maxlen=k)
         shp = env.observation_space.shape
         chan = 1 if len(shp) == 2 else shp[2]
-        self.observation_space = spaces.Box(low=0, high=255, shape=(shp[0], shp[1], chan * k))
+        self.observation_space = spaces.Box(
+            low=0, high=255, shape=(shp[0], shp[1], chan * k), dtype=np.uint8)

-    def _reset(self):
+    def reset(self):
         """Clear buffer and re-fill by duplicating the first observation."""
         ob = self.env.reset()
         for _ in range(self.k - 1):
             self.frames.append(np.zeros_like(ob))
         self.frames.append(ob)
-        return self._observation()
+        return self.observation()

-    def _step(self, action):
+    def step(self, action):
         ob, reward, done, info = self.env.step(action)
         self.frames.append(ob)
-        return self._observation(), reward, done, info
+        return self.observation(), reward, done, info

-    def _observation(self):
+    def observation(self):
         assert len(self.frames) == self.k
         if self.frames[-1].ndim == 2:
             return np.stack(self.frames, axis=-1)
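Since the wrappers now expose the standard reset/step/observation interface, they compose with any stock gym environment or wrapper. A small usage sketch of FrameStack on a raw Atari env (it assumes it is run next to atari_wrapper.py; the env id and the printed shape are illustrative):

```python
import gym
from atari_wrapper import FrameStack   # assumes examples/DeepQNetwork/ is on the path

env = FrameStack(gym.make("Breakout-v0"), 4)   # stack the last 4 frames

obs = env.reset()
print(obs.shape)   # e.g. (210, 160, 12): 3 channels x 4 stacked frames
obs, reward, done, info = env.step(env.action_space.sample())
```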
@@ -62,7 +66,7 @@ class _FireResetEnv(gym.Wrapper):
         assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
         assert len(env.unwrapped.get_action_meanings()) >= 3

-    def _reset(self):
+    def reset(self):
         self.env.reset()
         obs, _, done, _ = self.env.step(1)
         if done:
@@ -72,6 +76,9 @@ class _FireResetEnv(gym.Wrapper):
             self.env.reset()
         return obs

+    def step(self, action):
+        return self.env.step(action)
+

 def FireResetEnv(env):
     if isinstance(env, gym.Wrapper):
@@ -88,7 +95,7 @@ class LimitLength(gym.Wrapper):
         gym.Wrapper.__init__(self, env)
         self.k = k

-    def _reset(self):
+    def reset(self):
         # This assumes that reset() will really reset the env.
         # If the underlying env tries to be smart about reset
         # (e.g. end-of-life), the assumption doesn't hold.
@@ -96,7 +103,7 @@ class LimitLength(gym.Wrapper):
         self.cnt = 0
         return ob

-    def _step(self, action):
+    def step(self, action):
         ob, r, done, info = self.env.step(action)
         self.cnt += 1
         if self.cnt == self.k:
examples/DeepQNetwork/common.py

@@ -18,7 +18,7 @@ from tensorpack.utils.utils import get_tqdm_kwargs
 def play_one_episode(env, func, render=False):
     def predict(s):
         """
-        Map from observation to action, with 0.001 greedy.
+        Map from observation to action, with 0.01 greedy.
         """
         act = func(s[None, :, :, :])[0][0].argmax()
         if random.random() < 0.01:
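The corrected docstring describes an ε-greedy policy with ε = 0.01: act on the Q-value argmax, but pick a random action 1% of the time. A tiny standalone sketch of that rule (names are illustrative, not taken from common.py):

```python
import random
import numpy as np

def epsilon_greedy(q_values, epsilon=0.01):
    """Return the argmax action, but a random action with probability epsilon."""
    if random.random() < epsilon:
        return random.randrange(len(q_values))
    return int(np.argmax(q_values))
```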
@@ -45,7 +45,7 @@ def play_n_episodes(player, predfunc, nr, render=False):
         print("{}/{}, score={}".format(k, nr, score))


-def eval_with_funcs(predictors, nr_eval, get_player_fn):
+def eval_with_funcs(predictors, nr_eval, get_player_fn, verbose=False):
     """
     Args:
         predictors ([PredictorBase])
@@ -67,7 +67,6 @@ def eval_with_funcs(predictors, nr_eval, get_player_fn):
             while not self.stopped():
                 try:
                     score = play_one_episode(player, self.func)
-                    # print("Score, ", score)
                 except RuntimeError:
                     return
                 self.queue_put_stoppable(self.q, score)
@@ -80,17 +79,21 @@ def eval_with_funcs(predictors, nr_eval, get_player_fn):
         time.sleep(0.1)  # avoid simulator bugs
     stat = StatCounter()
-    for _ in tqdm(range(nr_eval), **get_tqdm_kwargs()):
+
+    def fetch():
         r = q.get()
         stat.feed(r)
+        if verbose:
+            logger.info("Score: {}".format(r))
+
+    for _ in tqdm(range(nr_eval), **get_tqdm_kwargs()):
+        fetch()
     logger.info("Waiting for all the workers to finish the last run...")
     for k in threads:
         k.stop()
     for k in threads:
         k.join()
     while q.qsize():
-        r = q.get()
-        stat.feed(r)
+        fetch()
     if stat.count > 0:
         return (stat.average, stat.max)
@@ -100,11 +103,13 @@ def eval_with_funcs(predictors, nr_eval, get_player_fn):
 def eval_model_multithread(pred, nr_eval, get_player_fn):
     """
     Args:
-        pred (OfflinePredictor): state -> Qvalue
+        pred (OfflinePredictor): state -> [#action]
     """
     NR_PROC = min(multiprocessing.cpu_count() // 2, 8)
     with pred.sess.as_default():
-        mean, max = eval_with_funcs([pred] * NR_PROC, nr_eval, get_player_fn)
+        mean, max = eval_with_funcs(
+            [pred] * NR_PROC, nr_eval, get_player_fn, verbose=True)
     logger.info("Average Score: {}; Max Score: {}".format(mean, max))
tensorpack/graph_builder/model_desc.py

@@ -95,6 +95,11 @@ class ModelDescBase(object):
         Args:
             args ([tf.Tensor]): tensors that matches the list of
                 :class:`InputDesc` defined by ``_get_inputs``.
+
+        Returns:
+            In general it returns nothing, but a subclass (e.g.
+            :class:`ModelDesc`) may require it to return necessary information
+            to build the trainer.
         """
         if len(args) == 1:
             arg = args[0]
@@ -124,18 +129,16 @@ class ModelDescBase(object):
 class ModelDesc(ModelDescBase):
     """
-    A ModelDesc with single cost and single optimizer.
+    A ModelDesc with **single cost** and **single optimizer**.
     It contains information about InputDesc, how to get cost, and how to get optimizer.
     """

     def get_cost(self):
         """
-        Return the cost tensor in the graph.
-        It calls :meth:`ModelDesc._get_cost()` which by default returns
-        ``self.cost``. You can override :meth:`_get_cost()` if needed.
+        Return the cost tensor to optimize on.

-        This function also applies the collection
+        This function takes the cost tensor defined by :meth:`build_graph`,
+        and applies the collection
         ``tf.GraphKeys.REGULARIZATION_LOSSES`` to the cost automatically.
         """
         cost = self._get_cost()
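The get_cost docstring refers to TensorFlow's regularization-loss collection: penalties that layers register under tf.GraphKeys.REGULARIZATION_LOSSES get added to the cost returned by the model. A standalone TF 1.x sketch of that mechanism, independent of tensorpack and purely for illustration:

```python
import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 16])
y = tf.placeholder(tf.float32, [None, 1])

# A layer with a kernel regularizer registers its L2 penalty in the collection.
out = tf.layers.dense(x, 1,
                      kernel_regularizer=tf.contrib.layers.l2_regularizer(1e-4))

data_cost = tf.reduce_mean(tf.squared_difference(out, y))
reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
total_cost = tf.add_n([data_cost] + reg_losses, name="total_cost")
```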
@@ -165,6 +168,9 @@ class ModelDesc(ModelDescBase):
         raise NotImplementedError()

     def _build_graph_get_cost(self, *inputs):
+        """
+        Used by trainers to get the final cost for optimization.
+        """
         self.build_graph(*inputs)
         return self.get_cost()
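_build_graph_get_cost is the hook trainers call: build the graph from the input tensors, then hand back the regularized cost for the optimizer. A hedged sketch of how a trainer-like caller could use it; the helper below and its use of get_optimizer() are assumptions for illustration, not tensorpack's actual trainer code:

```python
import tensorflow as tf

def build_train_op(model, input_tensors):
    """Hypothetical helper: turn a ModelDesc into a single training op."""
    cost = model._build_graph_get_cost(*input_tensors)   # build graph, get cost
    opt = model.get_optimizer()                          # assumed ModelDesc method
    return opt.minimize(cost, name="train_op")
```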