seminar-breakout (Shashank Suhas) · Commits

Commit 8d3c709c, authored Jul 18, 2016 by Yuxin Wu
    py3 compat & simulator speedup

Parent: 4fc21080
Showing 7 changed files with 83 additions and 54 deletions (+83 / -54)
Files changed:

  tensorpack/RL/atari.py             +8   -8
  tensorpack/RL/simulator.py         +35  -31
  tensorpack/predict/base.py         +3   -3
  tensorpack/predict/concurrency.py  +8   -6
  tensorpack/train/multigpu.py       +2   -2
  tensorpack/train/trainer.py        +5   -2
  tensorpack/utils/timer.py          +22  -2
tensorpack/RL/atari.py

@@ -65,18 +65,18 @@ class AtariPlayer(RLEnvironment):
         self.ale = ALEInterface()
         self.rng = get_rng(self)
-        self.ale.setInt("random_seed", self.rng.randint(0, 10000))
+        self.ale.setInt(b"random_seed", self.rng.randint(0, 10000))
-        self.ale.setBool("showinfo", False)
+        self.ale.setBool(b"showinfo", False)
-        self.ale.setInt("frame_skip", 1)
+        self.ale.setInt(b"frame_skip", 1)
-        self.ale.setBool('color_averaging', False)
+        self.ale.setBool(b'color_averaging', False)
         # manual.pdf suggests otherwise.
-        self.ale.setFloat('repeat_action_probability', 0.0)
+        self.ale.setFloat(b'repeat_action_probability', 0.0)

         # viz setup
         if isinstance(viz, six.string_types):
             assert os.path.isdir(viz), viz
-            self.ale.setString('record_screen_dir', viz)
+            self.ale.setString(b'record_screen_dir', viz)
             viz = 0
         if isinstance(viz, int):
             viz = float(viz)

@@ -86,7 +86,7 @@ class AtariPlayer(RLEnvironment):
         cv2.startWindowThread()
         cv2.namedWindow(self.windowname)
-        self.ale.loadROM(rom_file)
+        self.ale.loadROM(rom_file.encode('utf-8'))
         self.width, self.height = self.ale.getScreenDims()
         self.actions = self.ale.getMinimalActionSet()

@@ -184,7 +184,7 @@ if __name__ == '__main__':
             cnt += 1
             if cnt == 5000:
                 break
-        print time.time() - start
+        print(time.time() - start)

     if len(sys.argv) == 3 and sys.argv[2] == 'benchmark':
         import threading, multiprocessing
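Note: the atari.py changes are the Python 3 str/bytes split. The ALE bindings pass option names and ROM paths down to C as bytes, so the literals gain a b prefix and rom_file is encoded. A minimal sketch of a helper that keeps such call sites 2/3-compatible (the helper name is illustrative, not part of tensorpack):

    # Sketch only: `ensure_bytes` is an illustrative helper, not a tensorpack function.
    # ALE option names and ROM paths must be bytes under Python 3.
    def ensure_bytes(s, encoding='utf-8'):
        """Return s as bytes, encoding text strings if necessary."""
        return s if isinstance(s, bytes) else s.encode(encoding)

    # usage, mirroring the diff above:
    # ale.setInt(ensure_bytes("random_seed"), rng.randint(0, 10000))
    # ale.loadROM(ensure_bytes(rom_file))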
tensorpack/RL/simulator.py

@@ -39,33 +39,30 @@ class SimulatorProcess(multiprocessing.Process):
         self.c2s = pipe_c2s
         self.s2c = pipe_s2c
+        self.identity = u'simulator-{}'.format(self.idx).encode('utf-8')

     def run(self):
         player = self._build_player()
         context = zmq.Context()
-        c2s_socket = context.socket(zmq.DEALER)
-        c2s_socket.identity = 'simulator-{}'.format(self.idx)
+        c2s_socket = context.socket(zmq.PUSH)
+        c2s_socket.setsockopt(zmq.IDENTITY, self.identity)
         c2s_socket.set_hwm(2)
         c2s_socket.connect(self.c2s)

         s2c_socket = context.socket(zmq.DEALER)
-        s2c_socket.identity = 'simulator-{}'.format(self.idx)
+        s2c_socket.setsockopt(zmq.IDENTITY, self.identity)
         #s2c_socket.set_hwm(5)
         s2c_socket.connect(self.s2c)

+        #cnt = 0
+        state = player.current_state()
+        reward, isOver = 0, False
         while True:
-            state = player.current_state()
-            c2s_socket.send(dumps(state), copy=False)
-            #with total_timer('client recv_action'):
-            action = loads(s2c_socket.recv(copy=False))
+            c2s_socket.send(dumps(
+                (self.identity, state, reward, isOver)),
+                copy=False)
+            data = s2c_socket.recv(copy=False)
+            action = loads(data)
             reward, isOver = player.action(action)
-            c2s_socket.send(dumps((reward, isOver)), copy=False)
-            #with total_timer('client recv_ack'):
-            ACK = s2c_socket.recv(copy=False)
+            state = player.current_state()
+            #cnt += 1
+            #if cnt % 100 == 0:
+            #    print_total_timer()

     @abstractmethod
     def _build_player(self):
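Note: the speedup here replaces the old two-round-trip exchange (send state, receive action; send reward/isOver, receive an ACK) with a single PUSH message per step carrying (identity, state, reward, isOver), so the worker only ever blocks waiting for the next action. A standalone sketch of that worker-side protocol, assuming pyzmq, with pickle in place of tensorpack's dumps/loads and illustrative socket addresses:

    # Worker-side sketch of the new protocol; pyzmq is assumed, pickle stands in for
    # tensorpack's serializer, and the ipc:// addresses are illustrative.
    import zmq
    from pickle import dumps, loads

    def worker_loop(player, idx, c2s_addr='ipc://sim-c2s', s2c_addr='ipc://sim-s2c'):
        identity = u'simulator-{}'.format(idx).encode('utf-8')
        context = zmq.Context()

        c2s = context.socket(zmq.PUSH)            # one-way stream of experience to the master
        c2s.setsockopt(zmq.IDENTITY, identity)
        c2s.set_hwm(2)                            # buffer at most 2 unsent steps
        c2s.connect(c2s_addr)

        s2c = context.socket(zmq.DEALER)          # master routes the chosen action back by identity
        s2c.setsockopt(zmq.IDENTITY, identity)
        s2c.connect(s2c_addr)

        state, reward, isOver = player.current_state(), 0, False
        while True:
            # one message per step: who I am, what I see, and the outcome of the last action
            c2s.send(dumps((identity, state, reward, isOver)), copy=False)
            action = loads(s2c.recv())            # block only for the next action; no extra ACK
            reward, isOver = player.action(action)
            state = player.current_state()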
@@ -80,7 +77,6 @@ class SimulatorMaster(threading.Thread):
     class ClientState(object):
         def __init__(self):
-            self.protocol_state = 0  # state in communication
             self.memory = []    # list of Experience

     class Experience(object):
@@ -95,21 +91,25 @@ class SimulatorMaster(threading.Thread):
     def __init__(self, pipe_c2s, pipe_s2c):
         super(SimulatorMaster, self).__init__()
+        self.daemon = True
         self.context = zmq.Context()

-        self.c2s_socket = self.context.socket(zmq.ROUTER)
+        self.c2s_socket = self.context.socket(zmq.PULL)
         self.c2s_socket.bind(pipe_c2s)
+        self.c2s_socket.set_hwm(10)

         self.s2c_socket = self.context.socket(zmq.ROUTER)
         self.s2c_socket.bind(pipe_s2c)
-        self.socket_lock = threading.Lock()
-        self.daemon = True
+        self.s2c_socket.set_hwm(10)

         # queueing messages to client
         self.send_queue = queue.Queue(maxsize=100)
-        self.send_thread = LoopThread(lambda:
-            self.s2c_socket.send_multipart(self.send_queue.get()))
+        def f():
+            msg = self.send_queue.get()
+            # slow
+            self.s2c_socket.send_multipart(msg, copy=False)
+        self.send_thread = LoopThread(f)
         self.send_thread.daemon = True
         self.send_thread.start()
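Note: on the master side the collector socket becomes a plain PULL (no routing frames to peel off), replies still leave through a ROUTER socket whose first frame must be the worker's identity, and a background thread keeps draining send_queue. A minimal sketch of that layout, assuming pyzmq, with illustrative addresses and a plain daemon thread standing in for tensorpack's LoopThread:

    # Master-side socket layout, sketched with pyzmq; addresses are illustrative and a
    # plain daemon thread stands in for tensorpack's LoopThread.
    import threading
    import zmq
    from six.moves import queue

    context = zmq.Context()

    c2s = context.socket(zmq.PULL)        # fan-in of (identity, state, reward, isOver) tuples
    c2s.bind('ipc://sim-c2s')
    c2s.set_hwm(10)

    s2c = context.socket(zmq.ROUTER)      # replies addressed by first frame = worker identity
    s2c.bind('ipc://sim-s2c')
    s2c.set_hwm(10)

    send_queue = queue.Queue(maxsize=100)

    def sender():
        while True:
            msg = send_queue.get()        # msg is [identity, serialized_action]
            s2c.send_multipart(msg, copy=False)

    send_thread = threading.Thread(target=sender)
    send_thread.daemon = True
    send_thread.start()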
@@ -123,21 +123,25 @@ class SimulatorMaster(threading.Thread):
     def run(self):
         self.clients = defaultdict(self.ClientState)
+        #cnt = 0
         while True:
-            ident, msg = self.c2s_socket.recv_multipart()
+            #cnt += 1
+            #if cnt % 3000 == 0:
+            #    print_total_timer()
+            msg = loads(self.c2s_socket.recv(copy=False).bytes)
+            ident, state, reward, isOver = msg
             client = self.clients[ident]
-            client.protocol_state = 1 - client.protocol_state  # first flip the state
-            if not client.protocol_state == 0:  # state-action
-                state = loads(msg)
-                self._on_state(state, ident)
-            else:  # reward-response
-                reward, isOver = loads(msg)
+
+            # check if reward&isOver is valid
+            # in the first message, only state is valid
+            if len(client.memory) > 0:
                 client.memory[-1].reward = reward
                 if isOver:
                     self._on_episode_over(ident)
                 else:
                     self._on_datapoint(ident)
-                self.send_queue.put([ident, 'Thanks'])  # just an ACK
+
+            # feed state and return action
+            self._on_state(state, ident)

     @abstractmethod
     def _on_state(self, state, ident):
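Note: with the ACK removed, the only message a worker ever receives back is its next action, which subclasses push into send_queue as a two-frame ROUTER message [identity, payload]. A hedged sketch of a minimal subclass (the random policy and pickle serializer are placeholders, not tensorpack code):

    # Sketch of a SimulatorMaster subclass; the random policy and pickle-based dumps
    # are placeholders for a real inference pipeline.
    import random
    from pickle import dumps

    class RandomActionMaster(SimulatorMaster):
        def __init__(self, pipe_c2s, pipe_s2c, num_actions):
            super(RandomActionMaster, self).__init__(pipe_c2s, pipe_s2c)
            self.num_actions = num_actions

        def _on_state(self, state, ident):
            action = random.randrange(self.num_actions)
            # first frame addresses the ROUTER socket, second frame is the payload
            self.send_queue.put([ident, dumps(action)])

        def _on_datapoint(self, ident):
            pass    # e.g. move client.memory into a training queue

        def _on_episode_over(self, ident):
            self.clients[ident].memory = []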
tensorpack/predict/base.py

@@ -39,9 +39,9 @@ class AsyncPredictorBase(PredictorBase):
         """
         :param dp: A data point (list of component) as inputs.
             (It should be either batched or not batched depending on the predictor implementation)
-        :param callback: a thread-safe callback to get called with
-            the list of outputs of (inputs, outputs) pair
-        :return: a Future of outputs
+        :param callback: a thread-safe callback to get called with
+            either outputs or (inputs, outputs)
+        :return: a Future of results
         """

     @abstractmethod
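Note: the updated docstring says put_task returns a Future and that the callback may receive either the outputs alone or an (inputs, outputs) pair. A hedged usage sketch; async_pred stands for any AsyncPredictorBase implementation and datapoint for one input data point:

    # Illustrative only: `async_pred` is some AsyncPredictorBase implementation and
    # `datapoint` is one data point (a list of input components).
    def on_prediction(outputs):
        # runs on a worker thread, so it must be thread-safe
        print("got prediction:", outputs)

    future = async_pred.put_task(datapoint, on_prediction)
    # ...or ignore the callback and block on the Future instead:
    results = future.result()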
tensorpack/predict/concurrency.py

@@ -82,16 +82,16 @@ class PredictorWorkerThread(threading.Thread):
         self.id = id

     def run(self):
-        #self.xxx = None
         while True:
             batched, futures = self.fetch_batch()
             outputs = self.func(batched)
-            #print "batched size: ", len(batched[0]), "queuesize: ", self.queue.qsize()
+            #print "Worker {} batched {} Queue {}".format(
+                    #self.id, len(futures), self.queue.qsize())
             # debug, for speed testing
-            #if self.xxx is None:
-                #self.xxx = outputs = self.func([batched])
+            #if not hasattr(self, 'xxx'):
+                #self.xxx = outputs = self.func(batched)
             #else:
-                #outputs = [[self.xxx[0][0]] * len(batched), [self.xxx[1][0]] * len(batched)]
+                #outputs = [[self.xxx[0][0]] * len(batched[0]), [self.xxx[1][0]] * len(batched[0])]
             for idx, f in enumerate(futures):
                 f.set_result([k[idx] for k in outputs])

@@ -125,7 +125,9 @@ class MultiThreadAsyncPredictor(AsyncPredictorBase):
         """ :param predictors: a list of OnlinePredictor"""
         for k in predictors:
             assert isinstance(k, OnlinePredictor), type(k)
-        self.input_queue = queue.Queue(maxsize=len(predictors)*10)
+            # TODO use predictors.return_input here
+            assert k.return_input == False
+        self.input_queue = queue.Queue(maxsize=len(predictors)*100)
         self.threads = [
             PredictorWorkerThread(
                 self.input_queue, f, id, batch_size=batch_size)
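Note: PredictorWorkerThread.run() evaluates the model once per fetched batch; outputs is a list of batched output tensors with one row per queued request, and [k[idx] for k in outputs] slices out request idx's share before resolving its Future. A standalone illustration of that fan-out (the values are made up):

    # Standalone illustration of the fan-out in PredictorWorkerThread.run();
    # the values are made up, real outputs would be numpy arrays.
    from concurrent.futures import Future

    futures = [Future() for _ in range(3)]       # one future per queued request
    outputs = [                                  # one entry per model output, batched
        [0.1, 0.2, 0.3],                         # e.g. value predictions for requests 0..2
        [[1, 0], [0, 1], [1, 1]],                # e.g. policy vectors for requests 0..2
    ]

    for idx, f in enumerate(futures):
        f.set_result([k[idx] for k in outputs])  # request idx gets one slice of every output

    print(futures[1].result())                   # [0.2, [0, 1]]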
tensorpack/train/multigpu.py

@@ -115,7 +115,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer):
             train_op = self.config.optimizer.apply_gradients(grad_list[k])
             def f(op=train_op):  # avoid late-binding
                 self.sess.run([op])
-                self.async_step_counter.next()
+                next(self.async_step_counter)
             th = LoopThread(f)
             th.pause()
             th.start()

@@ -127,7 +127,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer):
             self.async_running = True
             for th in self.training_threads:  # resume all threads
                 th.resume()
-        self.async_step_counter.next()
+        next(self.async_step_counter)
         super(AsyncMultiGPUTrainer, self).run_step()

     def _trigger_epoch(self):
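Note: another py3 fix. Python 3 iterators have no .next() method, while the builtin next() works under both 2 and 3, hence the rewrite of the step-counter calls. A quick illustration with itertools.count, a plausible stand-in for async_step_counter (its construction is not part of this diff):

    # itertools.count() is a plausible stand-in for async_step_counter, whose
    # construction is outside this diff.
    import itertools

    counter = itertools.count()
    print(next(counter))    # 0, works on Python 2 and 3
    print(next(counter))    # 1
    # counter.next()        # Python 2 only; AttributeError on Python 3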
tensorpack/train/trainer.py

@@ -202,6 +202,9 @@ class QueueInputTrainer(Trainer):
             self.config.optimizer.apply_gradients(grads, get_global_step_var()),
             summary_moving_average(), name='train_op')
+
+        # skip training
+        #self.train_op = tf.group(*self.dequed_inputs)
         self.main_loop()

     def run_step(self):

@@ -218,8 +221,6 @@ class QueueInputTrainer(Trainer):
         #trace_file.write(trace.generate_chrome_trace_format())
         #import sys; sys.exit()
-        #self.sess.run([self.dequed_inputs[1]])

     def _trigger_epoch(self):
         # need to run summary_op every epoch
         # note that summary_op will take a data from the queue

@@ -234,3 +235,5 @@ class QueueInputTrainer(Trainer):
         """
         return self.predictor_factory.get_predictor(input_names, output_names, tower)
+
+    def get_predict_funcs(self, input_names, output_names, n):
+        return [self.get_predict_func(input_names, output_names, k) for k in range(n)]
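Note: the new get_predict_funcs builds n predict functions by calling get_predict_func for k = 0..n-1; a natural consumer is MultiThreadAsyncPredictor from concurrency.py above, which expects a list of OnlinePredictor. A hedged sketch (the trainer instance, tensor names, and constructor keywords are illustrative):

    # Illustrative sketch: `trainer` is a QueueInputTrainer and the tensor names are
    # placeholders for whatever the model defines; constructor keywords may differ.
    predictors = trainer.get_predict_funcs(['state'], ['policy', 'value'], n=4)

    # hand them to the async batching predictor from tensorpack/predict/concurrency.py
    async_pred = MultiThreadAsyncPredictor(predictors, batch_size=16)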
tensorpack/utils/timer.py

@@ -13,7 +13,27 @@ import atexit
 from .stat import StatCounter
 from . import logger

-__all__ = ['total_timer', 'timed_operation', 'print_total_timer']
+__all__ = ['total_timer', 'timed_operation', 'print_total_timer', 'IterSpeedCounter']
+
+class IterSpeedCounter(object):
+    def __init__(self, print_every, name=None):
+        self.cnt = 0
+        self.print_every = int(print_every)
+        self.name = name if name else 'IterSpeed'
+
+    def reset(self):
+        self.start = time.time()
+
+    def __call__(self):
+        if self.cnt == 0:
+            self.reset()
+        self.cnt += 1
+        if self.cnt % self.print_every != 0:
+            return
+        t = time.time() - self.start
+        logger.info("{}: {:.2f} sec, {} times, {:.3g} sec/time".format(
+            self.name, t, self.cnt, t / self.cnt))

 @contextmanager
 def timed_operation(msg, log_start=False):

@@ -37,7 +57,7 @@ def print_total_timer():
     if len(_TOTAL_TIMER_DATA) == 0:
         return
     for k, v in six.iteritems(_TOTAL_TIMER_DATA):
-        logger.info("Total Time: {} -> {} sec, {} times, {} sec/time".format(
+        logger.info("Total Time: {} -> {:.2f} sec, {} times, {:.3g} sec/time".format(
             k, v.sum, v.count, v.average))

 atexit.register(print_total_timer)
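Note: the new IterSpeedCounter is a callable you invoke once per iteration; timing starts at the first call, and every print_every calls it logs total seconds, the call count, and seconds per call. A usage sketch (the loop body is a placeholder):

    # Usage sketch for IterSpeedCounter; the loop body is a placeholder.
    from tensorpack.utils.timer import IterSpeedCounter

    speed = IterSpeedCounter(print_every=100, name='train-iter')
    for step in range(1000):
        run_one_training_step()    # placeholder for real work
        speed()                    # every 100 calls: logs sec, count, sec/call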