Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
194cda0b
Commit
194cda0b
authored
Jun 10, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
remove get_stat; more general evaluator
parent
b2ec42a8
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
24 additions
and
34 deletions
+24
-34
examples/Atari2600/DQN.py
examples/Atari2600/DQN.py
+1
-1
tensorpack/RL/atari.py
tensorpack/RL/atari.py
+0
-7
tensorpack/RL/envbase.py
tensorpack/RL/envbase.py
+2
-11
tensorpack/RL/expreplay.py
tensorpack/RL/expreplay.py
+7
-3
tensorpack/RL/history.py
tensorpack/RL/history.py
+4
-2
tensorpack/callbacks/summary.py
tensorpack/callbacks/summary.py
+2
-2
tensorpack/predict/concurrency.py
tensorpack/predict/concurrency.py
+2
-2
tensorpack/tfutils/symbolic_functions.py
tensorpack/tfutils/symbolic_functions.py
+6
-6
No files found.
examples/Atari2600/DQN.py
View file @
194cda0b
...
@@ -168,7 +168,7 @@ def get_config():
...
@@ -168,7 +168,7 @@ def get_config():
HumanHyperParamSetter
(
ObjAttrParam
(
dataset_train
,
'exploration'
),
'hyper.txt'
),
HumanHyperParamSetter
(
ObjAttrParam
(
dataset_train
,
'exploration'
),
'hyper.txt'
),
RunOp
(
lambda
:
M
.
update_target_param
()),
RunOp
(
lambda
:
M
.
update_target_param
()),
dataset_train
,
dataset_train
,
PeriodicCallback
(
Evaluator
(
EVAL_EPISODE
,
'fct/output:0'
),
2
),
PeriodicCallback
(
Evaluator
(
EVAL_EPISODE
,
[
'state'
],
[
'fct/output'
]
),
2
),
]),
]),
# save memory for multiprocess evaluator
# save memory for multiprocess evaluator
session_config
=
get_default_sess_config
(
0.6
),
session_config
=
get_default_sess_config
(
0.6
),
...
...
tensorpack/RL/atari.py
View file @
194cda0b
...
@@ -156,13 +156,6 @@ class AtariPlayer(RLEnvironment):
...
@@ -156,13 +156,6 @@ class AtariPlayer(RLEnvironment):
isOver
=
isOver
or
newlives
<
oldlives
isOver
=
isOver
or
newlives
<
oldlives
return
(
r
,
isOver
)
return
(
r
,
isOver
)
def
get_stat
(
self
):
try
:
return
{
'avg_score'
:
np
.
mean
(
self
.
stats
[
'score'
]),
'max_score'
:
float
(
np
.
max
(
self
.
stats
[
'score'
]))
}
except
ValueError
:
return
{}
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
import
sys
import
sys
import
time
import
time
...
...
tensorpack/RL/envbase.py
View file @
194cda0b
...
@@ -28,7 +28,7 @@ class RLEnvironment(object):
...
@@ -28,7 +28,7 @@ class RLEnvironment(object):
def
action
(
self
,
act
):
def
action
(
self
,
act
):
"""
"""
Perform an action. Will automatically start a new episode if isOver==True
Perform an action. Will automatically start a new episode if isOver==True
:param
s
act: the action
:param act: the action
:returns: (reward, isOver)
:returns: (reward, isOver)
"""
"""
...
@@ -40,19 +40,13 @@ class RLEnvironment(object):
...
@@ -40,19 +40,13 @@ class RLEnvironment(object):
""" return an `ActionSpace` instance"""
""" return an `ActionSpace` instance"""
raise
NotImplementedError
()
raise
NotImplementedError
()
def
get_stat
(
self
):
"""
return a dict of statistics (e.g., score) for all the episodes since last call to reset_stat
"""
return
{}
def
reset_stat
(
self
):
def
reset_stat
(
self
):
""" reset all statistics counter"""
""" reset all statistics counter"""
self
.
stats
=
defaultdict
(
list
)
self
.
stats
=
defaultdict
(
list
)
def
play_one_episode
(
self
,
func
,
stat
=
'score'
):
def
play_one_episode
(
self
,
func
,
stat
=
'score'
):
""" play one episode for eval.
""" play one episode for eval.
:param
s
func: call with the state and return an action
:param func: call with the state and return an action
:returns: the score of this episode
:returns: the score of this episode
"""
"""
while
True
:
while
True
:
...
@@ -102,9 +96,6 @@ class ProxyPlayer(RLEnvironment):
...
@@ -102,9 +96,6 @@ class ProxyPlayer(RLEnvironment):
def
__init__
(
self
,
player
):
def
__init__
(
self
,
player
):
self
.
player
=
player
self
.
player
=
player
def
get_stat
(
self
):
return
self
.
player
.
get_stat
()
def
reset_stat
(
self
):
def
reset_stat
(
self
):
self
.
player
.
reset_stat
()
self
.
player
.
reset_stat
()
...
...
tensorpack/RL/expreplay.py
View file @
194cda0b
...
@@ -170,10 +170,14 @@ class ExpReplay(DataFlow, Callback):
...
@@ -170,10 +170,14 @@ class ExpReplay(DataFlow, Callback):
self
.
exploration
-=
self
.
exploration_epoch_anneal
self
.
exploration
-=
self
.
exploration_epoch_anneal
logger
.
info
(
"Exploration changed to {}"
.
format
(
self
.
exploration
))
logger
.
info
(
"Exploration changed to {}"
.
format
(
self
.
exploration
))
# log player statistics
# log player statistics
stats
=
self
.
player
.
get_stat
()
stats
=
self
.
player
.
stats
for
k
,
v
in
six
.
iteritems
(
stats
):
for
k
,
v
in
six
.
iteritems
(
stats
):
if
isinstance
(
v
,
float
):
try
:
self
.
trainer
.
write_scalar_summary
(
'expreplay/'
+
k
,
v
)
mean
,
max
=
np
.
mean
(
v
),
np
.
max
(
v
)
self
.
trainer
.
write_scalar_summary
(
'expreplay/mean_'
+
k
,
mean
)
self
.
trainer
.
write_scalar_summary
(
'expreplay/max_'
+
k
,
max
)
except
:
pass
self
.
player
.
reset_stat
()
self
.
player
.
reset_stat
()
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
tensorpack/RL/history.py
View file @
194cda0b
...
@@ -10,9 +10,11 @@ from .envbase import ProxyPlayer
...
@@ -10,9 +10,11 @@ from .envbase import ProxyPlayer
__all__
=
[
'HistoryFramePlayer'
]
__all__
=
[
'HistoryFramePlayer'
]
class
HistoryFramePlayer
(
ProxyPlayer
):
class
HistoryFramePlayer
(
ProxyPlayer
):
""" Include history frames in state, or use black images"""
""" Include history frames in state, or use black images
Assume player will do auto-restart.
"""
def
__init__
(
self
,
player
,
hist_len
):
def
__init__
(
self
,
player
,
hist_len
):
""" :param
s
hist_len: total length of the state, including the current
""" :param hist_len: total length of the state, including the current
and `hist_len-1` history"""
and `hist_len-1` history"""
super
(
HistoryFramePlayer
,
self
)
.
__init__
(
player
)
super
(
HistoryFramePlayer
,
self
)
.
__init__
(
player
)
self
.
history
=
deque
(
maxlen
=
hist_len
)
self
.
history
=
deque
(
maxlen
=
hist_len
)
...
...
tensorpack/callbacks/summary.py
View file @
194cda0b
...
@@ -47,11 +47,11 @@ class StatHolder(object):
...
@@ -47,11 +47,11 @@ class StatHolder(object):
"""
"""
self
.
print_tag
=
None
if
print_tag
is
None
else
set
(
print_tag
)
self
.
print_tag
=
None
if
print_tag
is
None
else
set
(
print_tag
)
def
get_stat_now
(
self
,
k
):
def
get_stat_now
(
self
,
k
ey
):
"""
"""
Return the value of a stat in the current epoch.
Return the value of a stat in the current epoch.
"""
"""
return
self
.
stat_now
[
k
]
return
self
.
stat_now
[
k
ey
]
def
finalize
(
self
):
def
finalize
(
self
):
"""
"""
...
...
tensorpack/predict/concurrency.py
View file @
194cda0b
...
@@ -153,8 +153,8 @@ class MultiThreadAsyncPredictor(object):
...
@@ -153,8 +153,8 @@ class MultiThreadAsyncPredictor(object):
def
put_task
(
self
,
inputs
,
callback
=
None
):
def
put_task
(
self
,
inputs
,
callback
=
None
):
"""
"""
:param
s
inputs: a data point (list of component) matching input_names (not batched)
:param inputs: a data point (list of component) matching input_names (not batched)
:param
s
callback: a callback to get called with the list of outputs
:param callback: a callback to get called with the list of outputs
:returns: a Future of output."""
:returns: a Future of output."""
f
=
Future
()
f
=
Future
()
if
callback
is
not
None
:
if
callback
is
not
None
:
...
...
tensorpack/tfutils/symbolic_functions.py
View file @
194cda0b
...
@@ -44,10 +44,8 @@ def logSoftmax(x):
...
@@ -44,10 +44,8 @@ def logSoftmax(x):
:param x: NxC tensor.
:param x: NxC tensor.
:returns: NxC tensor.
:returns: NxC tensor.
"""
"""
with
tf
.
op_scope
([
x
],
'logSoftmax'
):
logger
.
warn
(
"symbf.logSoftmax is deprecated in favor of tf.nn.log_softmax"
)
z
=
x
-
tf
.
reduce_max
(
x
,
1
,
keep_dims
=
True
)
return
tf
.
nn
.
log_softmax
(
x
)
logprob
=
z
-
tf
.
log
(
tf
.
reduce_sum
(
tf
.
exp
(
z
),
1
,
keep_dims
=
True
))
return
logprob
def
class_balanced_binary_class_cross_entropy
(
pred
,
label
,
name
=
'cross_entropy_loss'
):
def
class_balanced_binary_class_cross_entropy
(
pred
,
label
,
name
=
'cross_entropy_loss'
):
"""
"""
...
@@ -73,11 +71,13 @@ def class_balanced_binary_class_cross_entropy(pred, label, name='cross_entropy_l
...
@@ -73,11 +71,13 @@ def class_balanced_binary_class_cross_entropy(pred, label, name='cross_entropy_l
cost
=
tf
.
reduce_mean
(
cost
,
name
=
name
)
cost
=
tf
.
reduce_mean
(
cost
,
name
=
name
)
return
cost
return
cost
def
print_stat
(
x
):
def
print_stat
(
x
,
message
=
None
):
""" a simple print op.
""" a simple print op.
Use it like: x = print_stat(x)
Use it like: x = print_stat(x)
"""
"""
return
tf
.
Print
(
x
,
[
tf
.
reduce_mean
(
x
),
x
],
summarize
=
20
)
if
message
is
None
:
message
=
x
.
op
.
name
return
tf
.
Print
(
x
,
[
tf
.
reduce_mean
(
x
),
x
],
summarize
=
20
,
message
=
message
)
def
rms
(
x
,
name
=
None
):
def
rms
(
x
,
name
=
None
):
if
name
is
None
:
if
name
is
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment