seminar-breakout (Shashank Suhas)

Commit b2ec42a8 authored Jun 09, 2016 by Yuxin Wu
    asyncpredictor accepts multiple input var

Parent: b9e2bd1b
Showing 6 changed files with 31 additions and 21 deletions
tensorpack/RL/simulator.py          +6  -4
tensorpack/callbacks/param.py       +1  -1
tensorpack/predict/concurrency.py   +19 -12
tensorpack/tfutils/gradproc.py      +1  -1
tensorpack/train/trainer.py         +2  -1
tensorpack/utils/concurrency.py     +2  -2
tensorpack/RL/simulator.py

@@ -10,6 +10,7 @@ import weakref
 from abc import abstractmethod, ABCMeta
 from collections import defaultdict, namedtuple
 import numpy as np
+import six
 from six.moves import queue
 from ..utils.timer import *

@@ -84,12 +85,13 @@ class SimulatorMaster(threading.Thread):
     class Experience(object):
         """ A transition of state, or experience"""
-        def __init__(self, state, action, reward, misc=None):
-            """misc: whatever other attribute you want to save"""
+        def __init__(self, state, action, reward, **kwargs):
+            """kwargs: whatever other attribute you want to save"""
             self.state = state
             self.action = action
             self.reward = reward
-            self.misc = misc
+            for k, v in six.iteritems(kwargs):
+                setattr(self, k, v)
 
     def __init__(self, pipe_c2s, pipe_s2c):
         super(SimulatorMaster, self).__init__()

@@ -120,7 +122,7 @@ class SimulatorMaster(threading.Thread):
         atexit.register(clean_context, [self.c2s_socket, self.s2c_socket], self.context)
 
     def run(self):
-        self.clients = defaultdict(SimulatorMaster.ClientState)
+        self.clients = defaultdict(self.ClientState)
         while True:
             ident, msg = self.c2s_socket.recv_multipart()
             client = self.clients[ident]
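The Experience change above swaps the fixed `misc` attribute for arbitrary keyword
attributes stored with setattr. A minimal, self-contained sketch of the same pattern,
using a stand-in class rather than importing tensorpack (the attribute name `value`
is illustrative, not from the commit):

    import six

    class Experience(object):
        """Stand-in mirroring the diff above."""
        def __init__(self, state, action, reward, **kwargs):
            self.state = state
            self.action = action
            self.reward = reward
            # save any extra attribute the caller passes in
            for k, v in six.iteritems(kwargs):
                setattr(self, k, v)

    exp = Experience(state=[0.1, 0.2], action=3, reward=1.0, value=0.75)
    assert exp.value == 0.75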
tensorpack/callbacks/param.py

@@ -130,7 +130,7 @@ class HumanHyperParamSetter(HyperParamSetter):
         """
         super(HumanHyperParamSetter, self).__init__(param)
         self.file_name = os.path.join(logger.LOG_DIR, file_name)
-        logger.info("Use {} for hyperparam {}.".format(
+        logger.info("Use {} to control hyperparam {}.".format(
             self.file_name, self.param.readable_name))
 
     def _get_value_to_set(self):
tensorpack/predict/concurrency.py

@@ -81,29 +81,31 @@ class MultiProcessQueuePredictWorker(MultiProcessPredictWorker):
                 self.outqueue.put((tid, self.func(dp)))
 
 class PredictorWorkerThread(threading.Thread):
-    def __init__(self, queue, pred_func, id, batch_size=5):
+    def __init__(self, queue, pred_func, id, nr_input_var, batch_size=5):
         super(PredictorWorkerThread, self).__init__()
         self.queue = queue
         self.func = pred_func
         self.daemon = True
         self.batch_size = batch_size
+        self.nr_input_var = nr_input_var
         self.id = id
 
     def run(self):
         def fetch():
-            batched, futures = [], []
+            batched, futures = [[] for _ in range(self.nr_input_var)], []
             inp, f = self.queue.get()
-            batched.append(inp)
+            for k in range(self.nr_input_var):
+                batched[k].append(inp[k])
             futures.append(f)
-            if self.batch_size == 1:
-                return batched, futures
-            while True:
+            # fill a batch
+            cnt = 1
+            while cnt < self.batch_size:
                 try:
                     inp, f = self.queue.get_nowait()
-                    batched.append(inp)
+                    for k in range(self.nr_input_var):
+                        batched[k].append(inp[k])
                     futures.append(f)
-                    if len(batched) == self.batch_size:
-                        break
+                    cnt += 1
                 except queue.Empty:
                     break
             return batched, futures
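The rewritten fetch() groups queued datapoints per input variable instead of
appending whole datapoints, so the batch handed to the predict function is a list
with one component batch per input. A short standalone sketch of that transposition
(the component values and the value of nr_input_var are illustrative):

    # each queued datapoint is a list of components, one per input variable
    nr_input_var = 2
    datapoints = [
        ["state_a", "history_a"],
        ["state_b", "history_b"],
    ]

    batched = [[] for _ in range(nr_input_var)]
    for dp in datapoints:
        for k in range(nr_input_var):
            batched[k].append(dp[k])

    print(batched)
    # [['state_a', 'state_b'], ['history_a', 'history_b']]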
@@ -111,7 +113,7 @@ class PredictorWorkerThread(threading.Thread):
         while True:
             batched, futures = fetch()
             #print "batched size: ", len(batched), "queuesize: ", self.queue.qsize()
-            outputs = self.func([batched])
+            outputs = self.func(batched)
             # debug, for speed testing
             #if self.xxx is None:
                 #outputs = self.func([batched])
@@ -135,7 +137,9 @@ class MultiThreadAsyncPredictor(object):
         """
         self.input_queue = queue.Queue(maxsize=nr_thread*10)
         self.threads = [
-            PredictorWorkerThread(self.input_queue, f, id, batch_size)
+            PredictorWorkerThread(self.input_queue, f, id,
+                                  len(input_names),
+                                  batch_size=batch_size)
             for id, f in enumerate(
                 trainer.get_predict_funcs(
                     input_names, output_names, nr_thread))]
@@ -148,7 +152,10 @@ class MultiThreadAsyncPredictor(object):
             t.start()
 
     def put_task(self, inputs, callback=None):
-        """ return a Future of output."""
+        """
+        :params inputs: a data point (list of component) matching input_names (not batched)
+        :params callback: a callback to get called with the list of outputs
+        :returns: a Future of output."""
         f = Future()
         if callback is not None:
             f.add_done_callback(callback)
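Taken together with the constructor hunk above, a hedged sketch of how the
multi-input async predictor might be used after this commit. The argument order,
the keyword names, the tensor names 'state' and 'policy', the input shape, and the
assumption that the callback receives a finished Future (standard
concurrent.futures behaviour) are illustrative, not taken verbatim from the
repository; `trainer` is assumed to be an already-built tensorpack trainer that
provides get_predict_funcs().

    import numpy as np
    from tensorpack.predict.concurrency import MultiThreadAsyncPredictor

    predictor = MultiThreadAsyncPredictor(
        trainer, ['state'], ['policy'], nr_thread=4, batch_size=5)

    def on_done(future):
        # assuming the callback is handed the completed Future
        print(future.result())

    # one unbatched datapoint: a list with one component per name in input_names
    state = np.zeros((84, 84, 4), dtype='float32')   # illustrative shape
    predictor.put_task([state], callback=on_done)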
tensorpack/tfutils/gradproc.py

@@ -92,7 +92,7 @@ class MapGradient(GradientProcessor):
     def __init__(self, func, regex='.*'):
         """
         :param func: takes a tensor and returns a tensor
-        ;param regex: used to match variables. default to match all variables.
+        :param regex: used to match variables. default to match all variables.
         """
         self.func = func
         if not regex.endswith('$'):
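The corrected docstring above states the MapGradient contract: `func` maps a
gradient tensor to a tensor and `regex` selects which variables it applies to. A
hedged illustration of that contract; the clipping choice and the '.*/W' pattern
are examples, not something this commit prescribes:

    import tensorflow as tf
    from tensorpack.tfutils.gradproc import MapGradient

    # clip every gradient whose variable name matches '.*/W' to [-0.1, 0.1]
    clip_proc = MapGradient(lambda grad: tf.clip_by_value(grad, -0.1, 0.1),
                            regex='.*/W')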
tensorpack/train/trainer.py

@@ -107,6 +107,7 @@ class QueueInputTrainer(Trainer):
         :param input_queue: a `tf.QueueBase` instance to be used to buffer datapoints.
             Defaults to a FIFO queue of size 100.
         :param predict_tower: list of gpu idx to run prediction. default to be [0].
+            Use -1 for cpu.
         """
         super(QueueInputTrainer, self).__init__(config)
         self.input_vars = self.model.get_input_vars()

@@ -136,7 +137,7 @@ class QueueInputTrainer(Trainer):
         tf.get_variable_scope().reuse_variables()
         for k in self.predict_tower:
             logger.info("Building graph for predict towerp{}...".format(k))
-            with tf.device('/gpu:{}'.format(k)), \
+            with tf.device('/gpu:{}'.format(k) if k >= 0 else '/cpu:0'), \
                     tf.name_scope('towerp{}'.format(k)):
                 self.model.build_graph(inputs, False)
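The trainer change documents -1 as a CPU placement for prediction towers and wires
it into the device string. A hedged sketch of requesting CPU prediction, assuming
QueueInputTrainer accepts predict_tower as a keyword argument as its docstring
suggests; the `config` object (a prepared train configuration) is not shown in
this commit:

    from tensorpack.train.trainer import QueueInputTrainer

    # predict_tower=[-1] places the prediction graph on /cpu:0 per the new code path
    trainer = QueueInputTrainer(config, predict_tower=[-1])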
tensorpack/utils/concurrency.py

@@ -122,10 +122,10 @@ def subproc_call(cmd, timeout=None):
                 shell=True, timeout=timeout)
         return output
     except subprocess.TimeoutExpired as e:
-        logger.warn("Timeout in evaluation!")
+        logger.warn("Command timeout!")
         logger.warn(e.output)
     except subprocess.CalledProcessError as e:
-        logger.warn("Evaluation script failed: {}".format(e.returncode))
+        logger.warn("Commnad failed: {}".format(e.returncode))
         logger.warn(e.output)
 
 class OrderedContainer(object):
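The hunk above shows subproc_call's contract: it runs a shell command with an
optional timeout, returns the captured output, and logs a warning on
TimeoutExpired or CalledProcessError. A brief hedged usage sketch (the command
and timeout value are arbitrary examples):

    from tensorpack.utils.concurrency import subproc_call

    # returns the command's output, or logs a warning if it times out or fails
    output = subproc_call("nvidia-smi", timeout=5)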