Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
8d3c709c
Commit
8d3c709c
authored
Jul 18, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
py3 compat & simulator speedup
parent
4fc21080
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
83 additions
and
54 deletions
+83
-54
tensorpack/RL/atari.py
tensorpack/RL/atari.py
+8
-8
tensorpack/RL/simulator.py
tensorpack/RL/simulator.py
+35
-31
tensorpack/predict/base.py
tensorpack/predict/base.py
+3
-3
tensorpack/predict/concurrency.py
tensorpack/predict/concurrency.py
+8
-6
tensorpack/train/multigpu.py
tensorpack/train/multigpu.py
+2
-2
tensorpack/train/trainer.py
tensorpack/train/trainer.py
+5
-2
tensorpack/utils/timer.py
tensorpack/utils/timer.py
+22
-2
No files found.
tensorpack/RL/atari.py
View file @
8d3c709c
...
...
@@ -65,18 +65,18 @@ class AtariPlayer(RLEnvironment):
self
.
ale
=
ALEInterface
()
self
.
rng
=
get_rng
(
self
)
self
.
ale
.
setInt
(
"random_seed"
,
self
.
rng
.
randint
(
0
,
10000
))
self
.
ale
.
setBool
(
"showinfo"
,
False
)
self
.
ale
.
setInt
(
b
"random_seed"
,
self
.
rng
.
randint
(
0
,
10000
))
self
.
ale
.
setBool
(
b
"showinfo"
,
False
)
self
.
ale
.
setInt
(
"frame_skip"
,
1
)
self
.
ale
.
setBool
(
'color_averaging'
,
False
)
self
.
ale
.
setInt
(
b
"frame_skip"
,
1
)
self
.
ale
.
setBool
(
b
'color_averaging'
,
False
)
# manual.pdf suggests otherwise.
self
.
ale
.
setFloat
(
'repeat_action_probability'
,
0.0
)
self
.
ale
.
setFloat
(
b
'repeat_action_probability'
,
0.0
)
# viz setup
if
isinstance
(
viz
,
six
.
string_types
):
assert
os
.
path
.
isdir
(
viz
),
viz
self
.
ale
.
setString
(
'record_screen_dir'
,
viz
)
self
.
ale
.
setString
(
b
'record_screen_dir'
,
viz
)
viz
=
0
if
isinstance
(
viz
,
int
):
viz
=
float
(
viz
)
...
...
@@ -86,7 +86,7 @@ class AtariPlayer(RLEnvironment):
cv2
.
startWindowThread
()
cv2
.
namedWindow
(
self
.
windowname
)
self
.
ale
.
loadROM
(
rom_file
)
self
.
ale
.
loadROM
(
rom_file
.
encode
(
'utf-8'
)
)
self
.
width
,
self
.
height
=
self
.
ale
.
getScreenDims
()
self
.
actions
=
self
.
ale
.
getMinimalActionSet
()
...
...
@@ -184,7 +184,7 @@ if __name__ == '__main__':
cnt
+=
1
if
cnt
==
5000
:
break
print
time
.
time
()
-
start
print
(
time
.
time
()
-
start
)
if
len
(
sys
.
argv
)
==
3
and
sys
.
argv
[
2
]
==
'benchmark'
:
import
threading
,
multiprocessing
...
...
tensorpack/RL/simulator.py
View file @
8d3c709c
...
...
@@ -39,33 +39,30 @@ class SimulatorProcess(multiprocessing.Process):
self
.
c2s
=
pipe_c2s
self
.
s2c
=
pipe_s2c
self
.
identity
=
u'simulator-{}'
.
format
(
self
.
idx
)
.
encode
(
'utf-8'
)
def
run
(
self
):
player
=
self
.
_build_player
()
context
=
zmq
.
Context
()
c2s_socket
=
context
.
socket
(
zmq
.
DEALER
)
c2s_socket
.
identity
=
'simulator-{}'
.
format
(
self
.
idx
)
c2s_socket
=
context
.
socket
(
zmq
.
PUSH
)
c2s_socket
.
setsockopt
(
zmq
.
IDENTITY
,
self
.
identity
)
c2s_socket
.
set_hwm
(
2
)
c2s_socket
.
connect
(
self
.
c2s
)
s2c_socket
=
context
.
socket
(
zmq
.
DEALER
)
s2c_socket
.
identity
=
'simulator-{}'
.
format
(
self
.
idx
)
s2c_socket
.
setsockopt
(
zmq
.
IDENTITY
,
self
.
identity
)
#s2c_socket.set_hwm(5)
s2c_socket
.
connect
(
self
.
s2c
)
#cnt = 0
state
=
player
.
current_state
()
reward
,
isOver
=
0
,
False
while
True
:
state
=
player
.
current_state
()
c2s_socket
.
send
(
dumps
(
state
),
copy
=
False
)
#with total_timer('client recv_action'):
data
=
s2c_socket
.
recv
(
copy
=
False
)
action
=
loads
(
data
)
c2s_socket
.
send
(
dumps
(
(
self
.
identity
,
state
,
reward
,
isOver
)),
copy
=
False
)
action
=
loads
(
s2c_socket
.
recv
(
copy
=
False
))
reward
,
isOver
=
player
.
action
(
action
)
c2s_socket
.
send
(
dumps
((
reward
,
isOver
)),
copy
=
False
)
#with total_timer('client recv_ack'):
ACK
=
s2c_socket
.
recv
(
copy
=
False
)
#cnt += 1
#if cnt % 100 == 0:
#print_total_timer()
state
=
player
.
current_state
()
@
abstractmethod
def
_build_player
(
self
):
...
...
@@ -80,7 +77,6 @@ class SimulatorMaster(threading.Thread):
class
ClientState
(
object
):
def
__init__
(
self
):
self
.
protocol_state
=
0
# state in communication
self
.
memory
=
[]
# list of Experience
class
Experience
(
object
):
...
...
@@ -95,21 +91,25 @@ class SimulatorMaster(threading.Thread):
def
__init__
(
self
,
pipe_c2s
,
pipe_s2c
):
super
(
SimulatorMaster
,
self
)
.
__init__
()
self
.
daemon
=
True
self
.
context
=
zmq
.
Context
()
self
.
c2s_socket
=
self
.
context
.
socket
(
zmq
.
ROUTER
)
self
.
c2s_socket
=
self
.
context
.
socket
(
zmq
.
PULL
)
self
.
c2s_socket
.
bind
(
pipe_c2s
)
self
.
c2s_socket
.
set_hwm
(
10
)
self
.
s2c_socket
=
self
.
context
.
socket
(
zmq
.
ROUTER
)
self
.
s2c_socket
.
bind
(
pipe_s2c
)
self
.
socket_lock
=
threading
.
Lock
()
self
.
daemon
=
True
self
.
s2c_socket
.
set_hwm
(
10
)
# queueing messages to client
self
.
send_queue
=
queue
.
Queue
(
maxsize
=
100
)
self
.
send_thread
=
LoopThread
(
lambda
:
self
.
s2c_socket
.
send_multipart
(
self
.
send_queue
.
get
()))
def
f
():
msg
=
self
.
send_queue
.
get
()
# slow
self
.
s2c_socket
.
send_multipart
(
msg
,
copy
=
False
)
self
.
send_thread
=
LoopThread
(
f
)
self
.
send_thread
.
daemon
=
True
self
.
send_thread
.
start
()
...
...
@@ -123,21 +123,25 @@ class SimulatorMaster(threading.Thread):
def
run
(
self
):
self
.
clients
=
defaultdict
(
self
.
ClientState
)
#cnt = 0
while
True
:
ident
,
msg
=
self
.
c2s_socket
.
recv_multipart
()
#cnt += 1
#if cnt % 3000 == 0:
#print_total_timer()
msg
=
loads
(
self
.
c2s_socket
.
recv
(
copy
=
False
)
.
bytes
)
ident
,
state
,
reward
,
isOver
=
msg
client
=
self
.
clients
[
ident
]
client
.
protocol_state
=
1
-
client
.
protocol_state
# first flip the state
if
not
client
.
protocol_state
==
0
:
# state-action
state
=
loads
(
msg
)
self
.
_on_state
(
state
,
ident
)
else
:
# reward-response
reward
,
isOver
=
loads
(
msg
)
# check if reward&isOver is valid
# in the first message, only state is valid
if
len
(
client
.
memory
)
>
0
:
client
.
memory
[
-
1
]
.
reward
=
reward
if
isOver
:
self
.
_on_episode_over
(
ident
)
else
:
self
.
_on_datapoint
(
ident
)
self
.
send_queue
.
put
([
ident
,
'Thanks'
])
# just an ACK
# feed state and return action
self
.
_on_state
(
state
,
ident
)
@
abstractmethod
def
_on_state
(
self
,
state
,
ident
):
...
...
tensorpack/predict/base.py
View file @
8d3c709c
...
...
@@ -39,9 +39,9 @@ class AsyncPredictorBase(PredictorBase):
"""
:param dp: A data point (list of component) as inputs.
(It should be either batched or not batched depending on the predictor implementation)
:param callback: a thread-safe callback to get called with
the list of
outputs of (inputs, outputs) pair
:return: a Future of
outpu
ts
:param callback: a thread-safe callback to get called with
either outputs or (inputs, outputs)
:return: a Future of
resul
ts
"""
@
abstractmethod
...
...
tensorpack/predict/concurrency.py
View file @
8d3c709c
...
...
@@ -82,16 +82,16 @@ class PredictorWorkerThread(threading.Thread):
self
.
id
=
id
def
run
(
self
):
#self.xxx = None
while
True
:
batched
,
futures
=
self
.
fetch_batch
()
outputs
=
self
.
func
(
batched
)
#print "batched size: ", len(batched[0]), "queuesize: ", self.queue.qsize()
#print "Worker {} batched {} Queue {}".format(
#self.id, len(futures), self.queue.qsize())
# debug, for speed testing
#if
self.xxx is None
:
#self.xxx = outputs = self.func(
[batched]
)
#if
not hasattr(self, 'xxx')
:
#self.xxx = outputs = self.func(
batched
)
#else:
#outputs = [[self.xxx[0][0]] * len(batched
), [self.xxx[1][0]] * len(batched
)]
#outputs = [[self.xxx[0][0]] * len(batched
[0]), [self.xxx[1][0]] * len(batched[0]
)]
for
idx
,
f
in
enumerate
(
futures
):
f
.
set_result
([
k
[
idx
]
for
k
in
outputs
])
...
...
@@ -125,7 +125,9 @@ class MultiThreadAsyncPredictor(AsyncPredictorBase):
""" :param predictors: a list of OnlinePredictor"""
for
k
in
predictors
:
assert
isinstance
(
k
,
OnlinePredictor
),
type
(
k
)
self
.
input_queue
=
queue
.
Queue
(
maxsize
=
len
(
predictors
)
*
10
)
# TODO use predictors.return_input here
assert
k
.
return_input
==
False
self
.
input_queue
=
queue
.
Queue
(
maxsize
=
len
(
predictors
)
*
100
)
self
.
threads
=
[
PredictorWorkerThread
(
self
.
input_queue
,
f
,
id
,
batch_size
=
batch_size
)
...
...
tensorpack/train/multigpu.py
View file @
8d3c709c
...
...
@@ -115,7 +115,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer):
train_op
=
self
.
config
.
optimizer
.
apply_gradients
(
grad_list
[
k
])
def
f
(
op
=
train_op
):
# avoid late-binding
self
.
sess
.
run
([
op
])
self
.
async_step_counter
.
next
(
)
next
(
self
.
async_step_counter
)
th
=
LoopThread
(
f
)
th
.
pause
()
th
.
start
()
...
...
@@ -127,7 +127,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer):
self
.
async_running
=
True
for
th
in
self
.
training_threads
:
# resume all threads
th
.
resume
()
self
.
async_step_counter
.
next
(
)
next
(
self
.
async_step_counter
)
super
(
AsyncMultiGPUTrainer
,
self
)
.
run_step
()
def
_trigger_epoch
(
self
):
...
...
tensorpack/train/trainer.py
View file @
8d3c709c
...
...
@@ -202,6 +202,9 @@ class QueueInputTrainer(Trainer):
self
.
config
.
optimizer
.
apply_gradients
(
grads
,
get_global_step_var
()),
summary_moving_average
(),
name
=
'train_op'
)
# skip training
#self.train_op = tf.group(*self.dequed_inputs)
self
.
main_loop
()
def
run_step
(
self
):
...
...
@@ -218,8 +221,6 @@ class QueueInputTrainer(Trainer):
#trace_file.write(trace.generate_chrome_trace_format())
#import sys; sys.exit()
#self.sess.run([self.dequed_inputs[1]])
def
_trigger_epoch
(
self
):
# need to run summary_op every epoch
# note that summary_op will take a data from the queue
...
...
@@ -234,3 +235,5 @@ class QueueInputTrainer(Trainer):
"""
return
self
.
predictor_factory
.
get_predictor
(
input_names
,
output_names
,
tower
)
def get_predict_funcs(self, input_names, output_names, n):
    """Build ``n`` predictors by calling ``get_predict_func`` once per index.

    :param input_names: forwarded unchanged to ``get_predict_func``.
    :param output_names: forwarded unchanged to ``get_predict_func``.
    :param n: number of predictors to build; indices ``0 .. n-1`` are
        passed as the third positional argument.
    :return: list with the ``n`` results, in index order.
    """
    funcs = []
    for idx in range(n):
        funcs.append(self.get_predict_func(input_names, output_names, idx))
    return funcs
tensorpack/utils/timer.py
View file @
8d3c709c
...
...
@@ -13,7 +13,27 @@ import atexit
from
.stat
import
StatCounter
from
.
import
logger
__all__
=
[
'total_timer'
,
'timed_operation'
,
'print_total_timer'
]
__all__
=
[
'total_timer'
,
'timed_operation'
,
'print_total_timer'
,
'IterSpeedCounter'
]
class IterSpeedCounter(object):
    """Count how often a piece of code is reached and periodically log its pace.

    Call the instance once per iteration. The clock is started on the
    first call; every ``print_every``-th call logs the time elapsed since
    then, the number of calls so far and the mean time per call.

    :param print_every: log once every this many calls. Converted with
        ``int()``; must not be zero.
    :param name: prefix of the log line. Falls back to ``'IterSpeed'``
        when empty or None.
    :raises ValueError: if ``print_every`` converts to zero.
    """

    def __init__(self, print_every, name=None):
        self.cnt = 0
        self.print_every = int(print_every)
        if self.print_every == 0:
            # Fail at construction instead of with a ZeroDivisionError
            # from the modulo on the first __call__.
            raise ValueError(
                "print_every must be non-zero, got {!r}".format(print_every))
        self.name = name if name else 'IterSpeed'
        # Set by reset(), which __call__ triggers on the first invocation.
        self.start = None

    def reset(self):
        """Restart the clock. The call count is left untouched."""
        self.start = time.time()

    def __call__(self):
        """Register one iteration and log statistics every ``print_every`` calls."""
        if self.cnt == 0:
            self.reset()
        self.cnt += 1
        if self.cnt % self.print_every != 0:
            return
        t = time.time() - self.start
        logger.info("{}: {:.2f} sec, {} times, {:.3g} sec/time".format(
            self.name, t, self.cnt, t / self.cnt))
@
contextmanager
def
timed_operation
(
msg
,
log_start
=
False
):
...
...
@@ -37,7 +57,7 @@ def print_total_timer():
if
len
(
_TOTAL_TIMER_DATA
)
==
0
:
return
for
k
,
v
in
six
.
iteritems
(
_TOTAL_TIMER_DATA
):
logger
.
info
(
"Total Time: {} -> {
} sec, {} times, {
} sec/time"
.
format
(
logger
.
info
(
"Total Time: {} -> {
:.2f} sec, {} times, {:.3g
} sec/time"
.
format
(
k
,
v
.
sum
,
v
.
count
,
v
.
average
))
atexit
.
register
(
print_total_timer
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment