Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
be3409fb
Commit
be3409fb
authored
Sep 11, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
execute_only_once and maxsaver
parent
a4867550
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
38 additions
and
30 deletions
+38
-30
docs/update.sh
docs/update.sh
+7
-0
examples/Atari2600/atari.py
examples/Atari2600/atari.py
+4
-6
tensorpack/callbacks/common.py
tensorpack/callbacks/common.py
+12
-7
tensorpack/utils/utils.py
tensorpack/utils/utils.py
+15
-17
No files found.
docs/update.sh
0 → 100755
View file @
be3409fb
#!/bin/bash -e
# File: update.sh
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
make clean
sphinx-apidoc
-o
modules ../tensorpack
-f
-d
10
make html
examples/Atari2600/atari.py
View file @
be3409fb
...
...
@@ -10,7 +10,8 @@ from collections import deque
import
threading
import
six
from
six.moves
import
range
from
tensorpack.utils
import
get_rng
,
logger
,
memoized
,
get_dataset_path
from
tensorpack.utils
import
(
get_rng
,
logger
,
memoized
,
get_dataset_path
,
execute_only_once
)
from
tensorpack.utils.stat
import
StatCounter
from
tensorpack.RL.envbase
import
RLEnvironment
,
DiscreteActionSpace
...
...
@@ -19,10 +20,6 @@ from ale_python_interface import ALEInterface
__all__
=
[
'AtariPlayer'
]
@
memoized
def
log_once
():
logger
.
warn
(
"https://github.com/mgbellemare/Arcade-Learning-Environment/pull/171 is not merged!"
)
ROM_URL
=
"https://github.com/openai/atari-py/tree/master/atari_py/atari_roms"
_ALE_LOCK
=
threading
.
Lock
()
...
...
@@ -56,7 +53,8 @@ class AtariPlayer(RLEnvironment):
try
:
ALEInterface
.
setLoggerMode
(
ALEInterface
.
Logger
.
Warning
)
except
AttributeError
:
log_once
()
if
execute_only_once
():
logger
.
warn
(
"https://github.com/mgbellemare/Arcade-Learning-Environment/pull/171 is not merged!"
)
# avoid simulator bugs: https://github.com/mgbellemare/Arcade-Learning-Environment/issues/86
with
_ALE_LOCK
:
...
...
tensorpack/callbacks/common.py
View file @
be3409fb
...
...
@@ -88,25 +88,30 @@ class MinSaver(Callback):
self
.
min
=
None
def
_get_stat
(
self
):
return
self
.
trainer
.
stat_holder
.
get_stat_now
(
self
.
monitor_stat
)
try
:
v
=
self
.
trainer
.
stat_holder
.
get_stat_now
(
self
.
monitor_stat
)
except
KeyError
:
v
=
None
return
v
def
_need_save
(
self
):
if
self
.
reverse
:
return
self
.
_get_stat
()
>
self
.
min
else
:
return
self
.
_get_stat
()
<
self
.
min
v
=
self
.
_get_stat
()
if
not
v
:
return
False
return
v
>
self
.
min
if
self
.
reverse
else
v
<
self
.
min
def
_trigger_epoch
(
self
):
if
self
.
min
is
None
or
self
.
_need_save
():
self
.
min
=
self
.
_get_stat
()
self
.
_save
()
if
self
.
min
:
self
.
_save
()
def
_save
(
self
):
ckpt
=
tf
.
train
.
get_checkpoint_state
(
logger
.
LOG_DIR
)
if
ckpt
is
None
:
raise
RuntimeError
(
"Cannot find a checkpoint state. Do you forget to use ModelSaver?"
)
path
=
c
h
pt
.
model_checkpoint_path
path
=
c
k
pt
.
model_checkpoint_path
newname
=
os
.
path
.
join
(
logger
.
LOG_DIR
,
self
.
filename
or
(
'max-'
if
self
.
reverse
else
'min-'
+
self
.
monitor_stat
+
'.tfmodel'
))
...
...
tensorpack/utils/utils.py
View file @
be3409fb
...
...
@@ -15,23 +15,10 @@ __all__ = ['change_env',
'map_arg'
,
'get_rng'
,
'memoized'
,
'get_dataset_path'
,
'get_tqdm_kwargs'
'get_tqdm_kwargs'
,
'execute_only_once'
]
#def expand_dim_if_necessary(var, dp):
# """
# Args:
# var: a tensor
# dp: a numpy array
# Return a reshaped version of dp, if that makes it match the valid dimension of var
# """
# shape = var.get_shape().as_list()
# valid_shape = [k for k in shape if k]
# if dp.shape == tuple(valid_shape):
# new_shape = [k if k else 1 for k in shape]
# dp = dp.reshape(new_shape)
# return dp
@
contextmanager
def
change_env
(
name
,
val
):
oldval
=
os
.
environ
.
get
(
name
,
None
)
...
...
@@ -104,13 +91,24 @@ def get_rng(obj=None):
int
(
datetime
.
now
()
.
strftime
(
"
%
Y
%
m
%
d
%
H
%
M
%
S
%
f"
)))
%
4294967295
return
np
.
random
.
RandomState
(
seed
)
_EXECUTE_HISTORY
=
set
()
def
execute_only_once
():
f
=
inspect
.
currentframe
()
.
f_back
ident
=
(
f
.
f_code
.
co_filename
,
f
.
f_lineno
)
if
ident
in
_EXECUTE_HISTORY
:
return
False
_EXECUTE_HISTORY
.
add
(
ident
)
return
True
def
get_dataset_path
(
*
args
):
from
.
import
logger
d
=
os
.
environ
.
get
(
'TENSORPACK_DATASET'
,
None
)
if
d
is
None
:
d
=
os
.
path
.
abspath
(
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'..'
,
'dataflow'
,
'dataset'
))
logger
.
info
(
"TENSORPACK_DATASET not set, using {} for dataset."
.
format
(
d
))
if
execute_only_once
():
from
.
import
logger
logger
.
info
(
"TENSORPACK_DATASET not set, using {} for dataset."
.
format
(
d
))
assert
os
.
path
.
isdir
(
d
),
d
return
os
.
path
.
join
(
d
,
*
args
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment