Project: seminar-breakout (Shashank Suhas)

Commit 4e472eb5 ("misc funcs")
Authored Jun 03, 2016 by Yuxin Wu
Parent: c83f2d9f

Showing 9 changed files with 70 additions and 21 deletions (+70 -21):
examples/Atari2600/DQN.py                   +4  -12
tensorpack/RL/common.py                     +17 -1
tensorpack/RL/envbase.py                    +17 -0
tensorpack/callbacks/param.py               +1  -1
tensorpack/dataflow/common.py               +17 -2
tensorpack/predict/concurrency.py           +3  -2
tensorpack/tfutils/symbolic_functions.py    +5  -0
tensorpack/train/base.py                    +3  -3
tensorpack/utils/lut.py                     +3  -0
examples/Atari2600/DQN.py

@@ -148,22 +148,14 @@ class Model(ModelDesc):
         return self.predict_value.eval(feed_dict={'state:0': [state]})[0]

 def play_one_episode(player, func, verbose=False):
-    while True:
-        s = player.current_state()
-        outputs = func([[s]])
-        action_value = outputs[0][0]
-        act = action_value.argmax()
-        if verbose:
-            print action_value, act
+    def f(s):
+        act = func([[s]])[0][0].argmax()
         if random.random() < 0.01:
             act = random.choice(range(NUM_ACTIONS))
         if verbose:
             print(act)
-        reward, isOver = player.action(act)
-        if isOver:
-            sc = player.stats['score'][0]
-            player.reset_stat()
-            return sc
+        return act
+    return player.play_one_episode(f)

 def play_model(model_path):
     player = get_player(0.013)
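The refactored play_one_episode now just builds a per-state callback and hands the rollout to player.play_one_episode. A rough sketch of that callback's contract, with a stand-in predictor since the real predict function is not part of this diff (fake_predictor and the 84x84x4 state shape are illustrative only):

import random
import numpy as np

NUM_ACTIONS = 4

def fake_predictor(batch):
    # stands in for the real batched predict function: one output holding
    # a (batch, NUM_ACTIONS) array of Q-values
    return [np.random.rand(len(batch), NUM_ACTIONS)]

def f(s):
    # same shape handling as in the diff: wrap the state as [[s]], read the
    # Q-values from outputs[0][0], act greedily with 1% random exploration
    act = fake_predictor([[s]])[0][0].argmax()
    if random.random() < 0.01:
        act = random.choice(range(NUM_ACTIONS))
    return act

print(f(np.zeros((84, 84, 4))))   # an integer action in [0, NUM_ACTIONS)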
tensorpack/RL/common.py

@@ -8,7 +8,7 @@ import numpy as np
 from collections import deque
 from .envbase import ProxyPlayer

-__all__ = ['HistoryFramePlayer', 'PreventStuckPlayer']
+__all__ = ['HistoryFramePlayer', 'PreventStuckPlayer', 'LimitLengthPlayer']

 class HistoryFramePlayer(ProxyPlayer):
     """ Include history frames in state, or use black images"""

@@ -62,3 +62,19 @@ class PreventStuckPlayer(ProxyPlayer):
         if isOver:
             self.act_que.clear()
         return (r, isOver)
+
+class LimitLengthPlayer(ProxyPlayer):
+    """ Limit the total number of actions in an episode"""
+    def __init__(self, player, limit):
+        super(LimitLengthPlayer, self).__init__(player)
+        self.limit = limit
+        self.cnt = 0
+
+    def action(self, act):
+        r, isOver = self.player.action(act)
+        self.cnt += 1
+        if self.cnt == self.limit:
+            isOver = True
+        if isOver:
+            self.cnt == 0
+        return (r, isOver)
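A minimal sketch of wrapping a player with the new LimitLengthPlayer; the DummyPlayer below is a hypothetical stand-in, not part of tensorpack. Note that the end-of-episode reset as committed reads self.cnt == 0 (a comparison) rather than self.cnt = 0, so the counter is not actually cleared.

from collections import defaultdict
from tensorpack.RL.common import LimitLengthPlayer

class DummyPlayer(object):
    """ hypothetical player: reward 1 per step, never ends on its own """
    def __init__(self):
        self.stats = defaultdict(list)
    def current_state(self):
        return 0
    def action(self, act):
        return (1.0, False)

player = LimitLengthPlayer(DummyPlayer(), 5)   # force isOver after 5 actions
for _ in range(5):
    r, isOver = player.action(0)
print(isOver)   # True once the limit is reached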
tensorpack/RL/envbase.py

@@ -39,6 +39,20 @@ class RLEnvironment(object):
         """ reset the statistics counter"""
         self.stats = defaultdict(list)

+    def play_one_episode(self, func, stat='score'):
+        """ play one episode for eval.
+            :params func: call with the state and return an action
+            :returns: the score of this episode
+        """
+        while True:
+            s = self.current_state()
+            act = func(s)
+            r, isOver = self.action(act)
+            if isOver:
+                s = self.stats[stat]
+                self.reset_stat()
+                return s
+
 class NaiveRLEnvironment(RLEnvironment):
     """ for testing only"""
     def __init__(self):

@@ -71,3 +85,6 @@ class ProxyPlayer(RLEnvironment):
     def stats(self):
         return self.player.stats
+
+    def play_one_episode(self, func, stat='score'):
+        return self.player.play_one_episode(self, func, stat)
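The new RLEnvironment.play_one_episode runs a policy callback until the episode ends and returns the accumulated statistic. A self-contained sketch of that contract (CountdownEnv is a made-up environment; its play_one_episode body is copied from this diff so the snippet runs standalone). Also worth noting: the ProxyPlayer version as committed forwards self as an extra first argument to self.player.play_one_episode.

from collections import defaultdict

class CountdownEnv(object):
    """ hypothetical environment: ends after 3 actions and records a score """
    def __init__(self):
        self.stats = defaultdict(list)
        self.remaining = 3
    def current_state(self):
        return self.remaining
    def action(self, act):
        self.remaining -= 1
        isOver = (self.remaining == 0)
        if isOver:
            self.stats['score'].append(42)
        return (1.0, isOver)
    def reset_stat(self):
        self.stats = defaultdict(list)
    # same logic as the method added to RLEnvironment in this commit
    def play_one_episode(self, func, stat='score'):
        while True:
            s = self.current_state()
            act = func(s)
            r, isOver = self.action(act)
            if isOver:
                s = self.stats[stat]
                self.reset_stat()
                return s

print(CountdownEnv().play_one_episode(lambda s: 0))   # [42]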
tensorpack/callbacks/param.py

@@ -140,7 +140,7 @@ class HumanHyperParamSetter(HyperParamSetter):
             return ret
         except:
             logger.warn(
-                "Failed to parse {} in {}".format(
+                "Failed to find {} in {}".format(
                     self.param.readable_name, self.file_name))
             return None
tensorpack/dataflow/common.py

@@ -11,7 +11,8 @@ from ..utils import *
 __all__ = ['BatchData', 'FixedSizeData', 'FakeData', 'MapData',
            'RepeatedData', 'MapDataComponent', 'RandomChooseData',
-           'RandomMixData', 'JoinData', 'ConcatData', 'SelectComponent']
+           'RandomMixData', 'JoinData', 'ConcatData', 'SelectComponent',
+           'DataFromQueue']

 class BatchData(ProxyDataFlow):
     def __init__(self, ds, batch_size, remainder=False):

@@ -25,7 +26,11 @@ class BatchData(ProxyDataFlow):
         """
         super(BatchData, self).__init__(ds)
         if not remainder:
-            assert batch_size <= ds.size()
+            try:
+                s = ds.size()
+                assert batch_size <= ds.size()
+            except NotImplementedError:
+                pass
         self.batch_size = batch_size
         self.remainder = remainder

@@ -313,6 +318,16 @@ class JoinData(DataFlow):
         for itr in itrs:
             del itr

+class DataFromQueue(DataFlow):
+    """ provide data from a queue
+    """
+    def __init__(self, queue):
+        self.queue = queue
+
+    def get_data(self):
+        while True:
+            yield self.queue.get()
+
 def SelectComponent(ds, idxs):
     """
     :param ds: a :mod:`DataFlow` instance
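Besides letting BatchData tolerate DataFlows whose size() raises NotImplementedError, this file gains DataFromQueue, which turns a queue into a DataFlow. A minimal usage sketch; the queue contents and the tensorpack.dataflow import path are assumptions:

import multiprocessing as mp
from tensorpack.dataflow import DataFromQueue

q = mp.Queue()
for i in range(3):
    q.put([i])               # each queue item is one datapoint

df = DataFromQueue(q)
itr = df.get_data()          # infinite generator; blocks on queue.get() when empty
for _ in range(3):
    print(next(itr))         # [0], then [1], then [2]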
tensorpack/predict/concurrency.py

@@ -21,8 +21,6 @@ from .common import *
 try:
     if six.PY2:
         from tornado.concurrent import Future
-        import tornado.options as options
-        options.parse_command_line(['--logging=debug'])
     else:
         from concurrent.futures import Future
 except ImportError:

@@ -146,6 +144,9 @@ class MultiThreadAsyncPredictor(object):
             for id, f in enumerate(
                 trainer.get_predict_funcs(
                     input_names, output_names, nr_thread))]
+        # TODO XXX set logging here to avoid affecting TF logging
+        import tornado.options as options
+        options.parse_command_line(['--logging=debug'])

     def run(self):
         for t in self.threads:
tensorpack/tfutils/symbolic_functions.py

@@ -78,3 +78,8 @@ def print_stat(x):
     Use it like: x = print_stat(x)
     """
     return tf.Print(x, [tf.reduce_mean(x), x], summarize=20)
+
+def rms(x, name=None):
+    if name is None:
+        name = x.op.name + '/rms'
+    return tf.sqrt(tf.reduce_mean(tf.square(x)), name=name)
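The new rms() helper computes the root-mean-square of a tensor, i.e. sqrt(mean(x^2)). A quick numeric check of what it evaluates to (TF 1.x-style session usage assumed, matching the era of this code):

import tensorflow as tf

x = tf.constant([3.0, -4.0])
y = tf.sqrt(tf.reduce_mean(tf.square(x)))   # what rms(x) computes
with tf.Session() as sess:
    print(sess.run(y))   # ~3.5355 = sqrt((9 + 16) / 2)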
tensorpack/train/base.py

@@ -152,9 +152,9 @@ class Trainer(object):
         tf.train.start_queue_runners(
             sess=self.sess, coord=self.coord, daemon=True, start=True)

-        # avoid sigint get handled by other processes
-        start_proc_mask_signal(self.extra_threads_procs)
+        with self.sess.as_default():
+            # avoid sigint get handled by other processes
+            start_proc_mask_signal(self.extra_threads_procs)

     def process_grads(self, grads):
         g = []
tensorpack/utils/lut.py

@@ -20,3 +20,6 @@ class LookUpTable(object):
     def get_idx(self, obj):
         return self.obj2idx[obj]

+    def __str__(self):
+        return self.idx2obj.__str__()
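With the new __str__, printing a LookUpTable shows its index-to-object mapping directly. A small sketch; it assumes LookUpTable is constructed from a list of objects and keeps obj2idx / idx2obj mappings, which is how lut.py uses them here:

from tensorpack.utils.lut import LookUpTable

lut = LookUpTable(['cat', 'dog'])   # assumed constructor: a list of objects
print(lut.get_idx('dog'))           # 1
print(lut)                          # the idx2obj mapping, e.g. {0: 'cat', 1: 'dog'}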