Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
9d3cf419
Commit
9d3cf419
authored
Jun 02, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
total timer
parent
bc1ba816
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
80 additions
and
24 deletions
+80
-24
examples/Atari2600/DQN.py
examples/Atari2600/DQN.py
+1
-1
tensorpack/RL/atari.py
tensorpack/RL/atari.py
+39
-13
tensorpack/dataflow/dataset/visualqa.py
tensorpack/dataflow/dataset/visualqa.py
+1
-0
tensorpack/train/base.py
tensorpack/train/base.py
+1
-0
tensorpack/utils/timer.py
tensorpack/utils/timer.py
+37
-0
tensorpack/utils/utils.py
tensorpack/utils/utils.py
+1
-10
No files found.
examples/Atari2600/DQN.py
View file @
9d3cf419
...
@@ -189,7 +189,7 @@ def eval_with_funcs(predict_funcs):
...
@@ -189,7 +189,7 @@ def eval_with_funcs(predict_funcs):
score
=
play_one_episode
(
player
,
self
.
func
)
score
=
play_one_episode
(
player
,
self
.
func
)
self
.
queue_put_stoppable
(
self
.
q
,
score
)
self
.
queue_put_stoppable
(
self
.
q
,
score
)
q
=
queue
.
Queue
()
q
=
queue
.
Queue
(
maxsize
=
3
)
threads
=
[
Worker
(
f
,
q
)
for
f
in
predict_funcs
]
threads
=
[
Worker
(
f
,
q
)
for
f
in
predict_funcs
]
for
k
in
threads
:
for
k
in
threads
:
...
...
tensorpack/RL/atari.py
View file @
9d3cf419
...
@@ -158,18 +158,44 @@ class AtariPlayer(RLEnvironment):
...
@@ -158,18 +158,44 @@ class AtariPlayer(RLEnvironment):
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
import
sys
import
sys
a
=
AtariPlayer
(
sys
.
argv
[
1
],
viz
=
0.03
,
height_range
=
(
28
,
-
8
))
num
=
a
.
get_num_actions
()
rng
=
get_rng
(
num
)
import
time
import
time
while
True
:
#im = a.grab_image()
def
benchmark
():
#cv2.imshow(a.romname, im)
a
=
AtariPlayer
(
sys
.
argv
[
1
],
viz
=
False
,
height_range
=
(
28
,
-
8
))
act
=
rng
.
choice
(
range
(
num
))
num
=
a
.
get_num_actions
()
print
(
act
)
rng
=
get_rng
(
num
)
r
,
o
=
a
.
action
(
act
)
start
=
time
.
time
()
a
.
current_state
()
cnt
=
0
#time.sleep(0.1)
while
True
:
print
(
r
,
o
)
act
=
rng
.
choice
(
range
(
num
))
r
,
o
=
a
.
action
(
act
)
a
.
current_state
()
cnt
+=
1
if
cnt
==
5000
:
break
print
time
.
time
()
-
start
if
len
(
sys
.
argv
)
==
3
and
sys
.
argv
[
2
]
==
'benchmark'
:
import
threading
,
multiprocessing
for
k
in
range
(
3
):
#th = multiprocessing.Process(target=benchmark)
th
=
threading
.
Thread
(
target
=
benchmark
)
th
.
start
()
time
.
sleep
(
0.02
)
benchmark
()
else
:
a
=
AtariPlayer
(
sys
.
argv
[
1
],
viz
=
0.03
,
height_range
=
(
28
,
-
8
))
num
=
a
.
get_num_actions
()
rng
=
get_rng
(
num
)
import
time
while
True
:
#im = a.grab_image()
#cv2.imshow(a.romname, im)
act
=
rng
.
choice
(
range
(
num
))
print
(
act
)
r
,
o
=
a
.
action
(
act
)
a
.
current_state
()
#time.sleep(0.1)
print
(
r
,
o
)
tensorpack/dataflow/dataset/visualqa.py
View file @
9d3cf419
...
@@ -5,6 +5,7 @@
...
@@ -5,6 +5,7 @@
from
..base
import
DataFlow
from
..base
import
DataFlow
from
...utils
import
*
from
...utils
import
*
from
...utils.timer
import
*
from
six.moves
import
zip
,
map
from
six.moves
import
zip
,
map
from
collections
import
Counter
from
collections
import
Counter
import
json
import
json
...
...
tensorpack/train/base.py
View file @
9d3cf419
...
@@ -11,6 +11,7 @@ import tqdm
...
@@ -11,6 +11,7 @@ import tqdm
import
tensorflow
as
tf
import
tensorflow
as
tf
from
.config
import
TrainConfig
from
.config
import
TrainConfig
from
..utils
import
*
from
..utils
import
*
from
..utils.timer
import
*
from
..utils.concurrency
import
start_proc_mask_signal
from
..utils.concurrency
import
start_proc_mask_signal
from
..callbacks
import
StatHolder
from
..callbacks
import
StatHolder
from
..tfutils
import
*
from
..tfutils
import
*
...
...
tensorpack/utils/timer.py
0 → 100644
View file @
9d3cf419
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: timer.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from
contextlib
import
contextmanager
import
time
from
collections
import
defaultdict
import
six
from
.stat
import
StatCounter
from
.
import
logger
__all__
=
[
'total_timer'
,
'timed_operation'
,
'print_total_timer'
]
@
contextmanager
def
timed_operation
(
msg
,
log_start
=
False
):
if
log_start
:
logger
.
info
(
'start {} ...'
.
format
(
msg
))
start
=
time
.
time
()
yield
logger
.
info
(
'{} finished, time={:.2f}sec.'
.
format
(
msg
,
time
.
time
()
-
start
))
_TOTAL_TIMER_DATA
=
defaultdict
(
StatCounter
)
@
contextmanager
def
total_timer
(
msg
):
start
=
time
.
time
()
yield
t
=
time
.
time
()
-
start
_TOTAL_TIMER_DATA
[
msg
]
.
feed
(
t
)
def
print_total_timer
():
for
k
,
v
in
six
.
iteritems
(
_TOTAL_TIMER_DATA
):
logger
.
info
(
"Total Time: {} -> {} sec"
.
format
(
k
,
v
.
sum
))
tensorpack/utils/utils.py
View file @
9d3cf419
...
@@ -11,7 +11,7 @@ import numpy as np
...
@@ -11,7 +11,7 @@ import numpy as np
from
.
import
logger
from
.
import
logger
__all__
=
[
'
timed_operation'
,
'
change_env'
,
__all__
=
[
'change_env'
,
'get_rng'
,
'memoized'
,
'get_nr_gpu'
,
'get_gpus'
]
'get_rng'
,
'memoized'
,
'get_nr_gpu'
,
'get_gpus'
]
#def expand_dim_if_necessary(var, dp):
#def expand_dim_if_necessary(var, dp):
...
@@ -28,15 +28,6 @@ __all__ = ['timed_operation', 'change_env',
...
@@ -28,15 +28,6 @@ __all__ = ['timed_operation', 'change_env',
# dp = dp.reshape(new_shape)
# dp = dp.reshape(new_shape)
# return dp
# return dp
@
contextmanager
def
timed_operation
(
msg
,
log_start
=
False
):
if
log_start
:
logger
.
info
(
'start {} ...'
.
format
(
msg
))
start
=
time
.
time
()
yield
logger
.
info
(
'{} finished, time={:.2f}sec.'
.
format
(
msg
,
time
.
time
()
-
start
))
@
contextmanager
@
contextmanager
def
change_env
(
name
,
val
):
def
change_env
(
name
,
val
):
oldval
=
os
.
environ
.
get
(
name
,
None
)
oldval
=
os
.
environ
.
get
(
name
,
None
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment