Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
aaf4cc78
Commit
aaf4cc78
authored
Jun 11, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
update atari/common
parent
0485c1de
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
14 additions
and
14 deletions
+14
-14
examples/Atari2600/common.py
examples/Atari2600/common.py
+4
-3
tensorpack/dataflow/dataset/bsds500.py
tensorpack/dataflow/dataset/bsds500.py
+1
-1
tensorpack/dataflow/dataset/svhn.py
tensorpack/dataflow/dataset/svhn.py
+1
-1
tensorpack/train/base.py
tensorpack/train/base.py
+3
-1
tensorpack/utils/logger.py
tensorpack/utils/logger.py
+5
-8
No files found.
examples/Atari2600/common.py
View file @
aaf4cc78
...
@@ -75,14 +75,15 @@ def eval_model_multithread(cfg, nr_eval):
...
@@ -75,14 +75,15 @@ def eval_model_multithread(cfg, nr_eval):
logger
.
info
(
"Average Score: {}; Max Score: {}"
.
format
(
mean
,
max
))
logger
.
info
(
"Average Score: {}; Max Score: {}"
.
format
(
mean
,
max
))
class
Evaluator
(
Callback
):
class
Evaluator
(
Callback
):
def
__init__
(
self
,
nr_eval
,
output_name
):
def
__init__
(
self
,
nr_eval
,
input_names
,
output_names
):
self
.
eval_episode
=
nr_eval
self
.
eval_episode
=
nr_eval
self
.
output_name
=
output_name
self
.
input_names
=
input_names
self
.
output_names
=
output_names
def
_before_train
(
self
):
def
_before_train
(
self
):
NR_PROC
=
min
(
multiprocessing
.
cpu_count
()
//
2
,
8
)
NR_PROC
=
min
(
multiprocessing
.
cpu_count
()
//
2
,
8
)
self
.
pred_funcs
=
[
self
.
trainer
.
get_predict_func
(
self
.
pred_funcs
=
[
self
.
trainer
.
get_predict_func
(
[
'state'
],
[
self
.
output_name
]
)]
*
NR_PROC
self
.
input_names
,
self
.
output_names
)]
*
NR_PROC
def
_trigger_epoch
(
self
):
def
_trigger_epoch
(
self
):
t
=
time
.
time
()
t
=
time
.
time
()
...
...
tensorpack/dataflow/dataset/bsds500.py
View file @
aaf4cc78
...
@@ -15,7 +15,7 @@ try:
...
@@ -15,7 +15,7 @@ try:
from
scipy.io
import
loadmat
from
scipy.io
import
loadmat
__all__
=
[
'BSDS500'
]
__all__
=
[
'BSDS500'
]
except
ImportError
:
except
ImportError
:
logger
.
error
(
"Cannot import scipy. BSDS500 dataset won't be available!"
)
logger
.
warn
(
"Cannot import scipy. BSDS500 dataset won't be available!"
)
__all__
=
[]
__all__
=
[]
DATA_URL
=
"http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/BSR_bsds500.tgz"
DATA_URL
=
"http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/BSR_bsds500.tgz"
...
...
tensorpack/dataflow/dataset/svhn.py
View file @
aaf4cc78
...
@@ -15,7 +15,7 @@ try:
...
@@ -15,7 +15,7 @@ try:
import
scipy.io
import
scipy.io
__all__
=
[
'SVHNDigit'
]
__all__
=
[
'SVHNDigit'
]
except
ImportError
:
except
ImportError
:
logger
.
error
(
"Cannot import scipy. SVHNDigit dataset won't be available!"
)
logger
.
warn
(
"Cannot import scipy. SVHNDigit dataset won't be available!"
)
__all__
=
[]
__all__
=
[]
SVHN_URL
=
"http://ufldl.stanford.edu/housenumbers/"
SVHN_URL
=
"http://ufldl.stanford.edu/housenumbers/"
...
...
tensorpack/train/base.py
View file @
aaf4cc78
...
@@ -122,7 +122,9 @@ class Trainer(object):
...
@@ -122,7 +122,9 @@ class Trainer(object):
for
step
in
tqdm
.
trange
(
for
step
in
tqdm
.
trange
(
self
.
config
.
step_per_epoch
,
self
.
config
.
step_per_epoch
,
leave
=
True
,
mininterval
=
0.5
,
leave
=
True
,
mininterval
=
0.5
,
dynamic_ncols
=
True
,
ascii
=
True
):
smoothing
=
0.5
,
dynamic_ncols
=
True
,
ascii
=
True
):
#bar_format='{l_bar}{bar}|{n_fmt}/{total_fmt} [{elapsed}<{remaining},{rate_noinv_fmt}]'):
#bar_format='{l_bar}{bar}|{n_fmt}/{total_fmt} [{elapsed}<{remaining},{rate_noinv_fmt}]'):
if
self
.
coord
.
should_stop
():
if
self
.
coord
.
should_stop
():
return
return
...
...
tensorpack/utils/logger.py
View file @
aaf4cc78
...
@@ -16,7 +16,7 @@ __all__ = []
...
@@ -16,7 +16,7 @@ __all__ = []
class
MyFormatter
(
logging
.
Formatter
):
class
MyFormatter
(
logging
.
Formatter
):
def
format
(
self
,
record
):
def
format
(
self
,
record
):
date
=
colored
(
'[
%(asctime)
s
%(lineno)
d@
%(filename)
s:
%(name)
s
]'
,
'green'
)
date
=
colored
(
'[
%(asctime)
s
@
%(filename)
s:
%(lineno)
d
]'
,
'green'
)
msg
=
'
%(message)
s'
msg
=
'
%(message)
s'
if
record
.
levelno
==
logging
.
WARNING
:
if
record
.
levelno
==
logging
.
WARNING
:
fmt
=
date
+
' '
+
colored
(
'WRN'
,
'red'
,
attrs
=
[
'blink'
])
+
' '
+
msg
fmt
=
date
+
' '
+
colored
(
'WRN'
,
'red'
,
attrs
=
[
'blink'
])
+
' '
+
msg
...
@@ -27,25 +27,22 @@ class MyFormatter(logging.Formatter):
...
@@ -27,25 +27,22 @@ class MyFormatter(logging.Formatter):
if
hasattr
(
self
,
'_style'
):
if
hasattr
(
self
,
'_style'
):
# Python3 compatibilty
# Python3 compatibilty
self
.
_style
.
_fmt
=
fmt
self
.
_style
.
_fmt
=
fmt
self
.
_fmt
=
fmt
self
.
_fmt
=
fmt
else
:
self
.
_fmt
=
fmt
return
super
(
MyFormatter
,
self
)
.
format
(
record
)
return
super
(
MyFormatter
,
self
)
.
format
(
record
)
def
getlogger
():
def
getlogger
():
logger
=
logging
.
getLogger
(
'tensorpack'
)
logger
=
logging
.
getLogger
(
'tensorpack'
)
logger
.
propagate
=
False
logger
.
propagate
=
False
logger
.
setLevel
(
logging
.
INFO
)
logger
.
setLevel
(
logging
.
INFO
)
handler
=
logging
.
StreamHandler
()
handler
=
logging
.
StreamHandler
(
sys
.
stdout
)
handler
.
setFormatter
(
MyFormatter
(
datefmt
=
'
%
d
%
H:
%
M:
%
S'
))
handler
.
setFormatter
(
MyFormatter
(
datefmt
=
'
%
d
%
H:
%
M:
%
S'
))
logger
.
addHandler
(
handler
)
logger
.
addHandler
(
handler
)
return
logger
return
logger
logger
=
getlogger
()
def
get_time_str
():
def
get_time_str
():
return
datetime
.
now
()
.
strftime
(
'
%
m
%
d-
%
H
%
M
%
S'
)
return
datetime
.
now
()
.
strftime
(
'
%
m
%
d-
%
H
%
M
%
S'
)
logger
=
getlogger
()
# logger file and directory:
# logger file and directory:
global
LOG_FILE
,
LOG_DIR
global
LOG_FILE
,
LOG_DIR
def
_set_file
(
path
):
def
_set_file
(
path
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment