Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
ff0a4528
Commit
ff0a4528
authored
May 27, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
separate predict related code
parent
d7a85f44
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
266 additions
and
5 deletions
+266
-5
examples/Atari2600/DQN.py
examples/Atari2600/DQN.py
+0
-1
tensorpack/dataflow/dataset/atari.py
tensorpack/dataflow/dataset/atari.py
+2
-2
tensorpack/predict/__init__.py
tensorpack/predict/__init__.py
+20
-0
tensorpack/predict/common.py
tensorpack/predict/common.py
+90
-0
tensorpack/predict/concurrency.py
tensorpack/predict/concurrency.py
+67
-0
tensorpack/predict/dataset.py
tensorpack/predict/dataset.py
+76
-0
tensorpack/tfutils/sessinit.py
tensorpack/tfutils/sessinit.py
+11
-2
No files found.
examples/Atari2600/DQN.py
View file @
ff0a4528
...
...
@@ -266,7 +266,6 @@ if __name__ == '__main__':
if
args
.
task
!=
'train'
:
assert
args
.
load
is
not
None
global
ROM_FILE
ROM_FILE
=
args
.
rom
if
args
.
task
==
'play'
:
...
...
tensorpack/dataflow/dataset/atari.py
View file @
ff0a4528
...
...
@@ -36,7 +36,7 @@ class AtariPlayer(RLEnvironment):
self
.
ale
=
ALEInterface
()
self
.
rng
=
get_rng
(
self
)
self
.
ale
.
setInt
(
"random_seed"
,
self
.
rng
.
randint
(
self
.
rng
.
randint
(
0
,
1000
)
))
self
.
ale
.
setInt
(
"random_seed"
,
self
.
rng
.
randint
(
0
,
1000
))
self
.
ale
.
setInt
(
"frame_skip"
,
frame_skip
)
self
.
ale
.
setBool
(
'color_averaging'
,
True
)
self
.
ale
.
loadROM
(
rom_file
)
...
...
@@ -125,7 +125,7 @@ if __name__ == '__main__':
#im = a.grab_image()
#cv2.imshow(a.romname, im)
act
=
rng
.
choice
(
range
(
num
))
print
act
print
(
act
)
r
,
o
=
a
.
action
(
act
)
a
.
current_state
()
#time.sleep(0.1)
...
...
tensorpack/predict/__init__.py
0 → 100644
View file @
ff0a4528
# -*- coding: UTF-8 -*-
# File: __init__.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
from
pkgutil
import
walk_packages
import
os
import
os.path
def
global_import
(
name
):
p
=
__import__
(
name
,
globals
(),
locals
(),
level
=
1
)
lst
=
p
.
__all__
if
'__all__'
in
dir
(
p
)
else
dir
(
p
)
for
k
in
lst
:
globals
()[
k
]
=
p
.
__dict__
[
k
]
del
globals
()[
name
]
for
_
,
module_name
,
_
in
walk_packages
(
[
os
.
path
.
dirname
(
__file__
)]):
if
not
module_name
.
startswith
(
'_'
):
global_import
(
module_name
)
tensorpack/predict.py
→
tensorpack/predict
/common
.py
View file @
ff0a4528
# -*- coding: UTF-8 -*-
# File:
predict
.py
# File:
common
.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import
tensorflow
as
tf
import
numpy
as
np
from
collections
import
namedtuple
from
tqdm
import
tqdm
from
six.moves
import
zip
,
range
from
six.moves
import
zip
import
multiprocessing
from
.utils.concurrency
import
ensure_proc_terminate
,
OrderedResultGatherProc
,
DIE
from
..tfutils
import
*
from
.tfutils
import
*
from
.utils
import
logger
from
.tfutils.modelutils
import
describe_model
from
.dataflow
import
DataFlow
,
BatchData
from
.dataflow.dftools
import
dataflow_to_process_queue
import
multiprocessing
__all__
=
[
'PredictConfig'
,
'DatasetPredictor'
,
'get_predict_func'
,
'ParallelPredictWorker'
]
__all__
=
[
'PredictConfig'
,
'get_predict_func'
,
'PredictResult'
]
PredictResult
=
namedtuple
(
'PredictResult'
,
[
'input'
,
'output'
])
...
...
@@ -27,7 +19,6 @@ class PredictConfig(object):
"""
The config used by `get_predict_func`.
:param session_config: a `tf.ConfigProto` instance to instantiate the session.
:param session_init: a `utils.sessinit.SessionInit` instance to
initialize variables of a session.
:param input_data_mapping: Decide the mapping from each component in data
...
...
@@ -68,6 +59,7 @@ class PredictConfig(object):
def
get_predict_func
(
config
):
"""
Produce a simple predictor function in a newly-created session without any parallelism.
:param config: a `PredictConfig` instance.
:returns: A prediction function that takes a list of input values, and return
a list of output values defined in ``config.output_var_names``.
...
...
@@ -86,9 +78,6 @@ def get_predict_func(config):
output_vars
=
[
tf
.
get_default_graph
()
.
get_tensor_by_name
(
get_op_var_name
(
n
)[
1
])
for
n
in
output_var_names
]
if
config
.
session_config
:
sess
=
tf
.
Session
(
config
=
config
.
session_config
)
else
:
sess
=
tf
.
Session
()
config
.
session_init
.
init
(
sess
)
...
...
@@ -99,119 +88,3 @@ def get_predict_func(config):
feed
=
dict
(
zip
(
input_map
,
dp
))
return
sess
.
run
(
output_vars
,
feed_dict
=
feed
)
return
run_input
class
ParallelPredictWorker
(
multiprocessing
.
Process
):
def
__init__
(
self
,
idx
,
gpuid
,
config
):
"""
:param idx: index of the worker. the 0th worker will print log.
:param gpuid: id of the GPU to be used. set to -1 to use CPU.
:param config: a `PredictConfig`
"""
super
(
ParallelPredictWorker
,
self
)
.
__init__
()
self
.
idx
=
idx
self
.
gpuid
=
gpuid
self
.
config
=
config
def
_init_runtime
(
self
):
if
self
.
gpuid
>=
0
:
logger
.
info
(
"Worker {} uses GPU {}"
.
format
(
self
.
idx
,
self
.
gpuid
))
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
self
.
gpuid
else
:
logger
.
info
(
"Worker {} uses CPU"
.
format
(
self
.
idx
))
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
''
G
=
tf
.
Graph
()
# build a graph for each process, because they don't need to share anything
with
G
.
as_default
(),
tf
.
device
(
'/gpu:0'
if
self
.
gpuid
>=
0
else
'/cpu:0'
):
if
self
.
idx
!=
0
:
from
tensorpack.models._common
import
disable_layer_logging
disable_layer_logging
()
self
.
func
=
get_predict_func
(
self
.
config
)
if
self
.
idx
==
0
:
describe_model
()
class
QueuePredictWorker
(
ParallelPredictWorker
):
""" A worker process to run predictor on one GPU """
def
__init__
(
self
,
idx
,
gpuid
,
inqueue
,
outqueue
,
config
):
"""
:param idx: index of the worker. the 0th worker will print log.
:param gpuid: id of the GPU to be used. set to -1 to use CPU.
:param inqueue: input queue to get data point
:param outqueue: output queue put result
:param config: a `PredictConfig`
"""
super
(
QueuePredictWorker
,
self
)
.
__init__
(
idx
,
gpuid
,
config
)
self
.
inqueue
=
inqueue
self
.
outqueue
=
outqueue
def
run
(
self
):
self
.
_init_runtime
()
while
True
:
tid
,
dp
=
self
.
inqueue
.
get
()
if
tid
==
DIE
:
self
.
outqueue
.
put
((
DIE
,
None
))
return
else
:
res
=
PredictResult
(
dp
,
self
.
func
(
dp
))
self
.
outqueue
.
put
((
tid
,
res
))
class
DatasetPredictor
(
object
):
"""
Run the predict_config on a given `DataFlow`.
"""
def
__init__
(
self
,
config
,
dataset
):
"""
:param config: a `PredictConfig` instance.
:param dataset: a `DataFlow` instance.
"""
assert
isinstance
(
dataset
,
DataFlow
)
self
.
ds
=
dataset
self
.
nr_gpu
=
config
.
nr_gpu
if
self
.
nr_gpu
>
1
:
self
.
inqueue
,
self
.
inqueue_proc
=
dataflow_to_process_queue
(
self
.
ds
,
10
,
self
.
nr_gpu
)
self
.
outqueue
=
multiprocessing
.
Queue
()
try
:
gpus
=
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
.
split
(
','
)
except
KeyError
:
gpus
=
list
(
range
(
self
.
nr_gpu
))
self
.
workers
=
[
QueuePredictWorker
(
i
,
gpus
[
i
],
self
.
inqueue
,
self
.
outqueue
,
config
)
for
i
in
range
(
self
.
nr_gpu
)]
self
.
result_queue
=
OrderedResultGatherProc
(
self
.
outqueue
)
# setup all the procs
self
.
inqueue_proc
.
start
()
for
p
in
self
.
workers
:
p
.
start
()
self
.
result_queue
.
start
()
ensure_proc_terminate
(
self
.
workers
)
ensure_proc_terminate
([
self
.
result_queue
,
self
.
inqueue_proc
])
else
:
self
.
func
=
get_predict_func
(
config
)
def
get_result
(
self
):
""" A generator to produce prediction for each data"""
with
tqdm
(
total
=
self
.
ds
.
size
())
as
pbar
:
if
self
.
nr_gpu
==
1
:
for
dp
in
self
.
ds
.
get_data
():
yield
PredictResult
(
dp
,
self
.
func
(
dp
))
pbar
.
update
()
else
:
die_cnt
=
0
while
True
:
res
=
self
.
result_queue
.
get
()
pbar
.
update
()
if
res
[
0
]
!=
DIE
:
yield
res
[
1
]
else
:
die_cnt
+=
1
if
die_cnt
==
self
.
nr_gpu
:
break
self
.
inqueue_proc
.
join
()
self
.
inqueue_proc
.
terminate
()
for
p
in
self
.
workers
:
p
.
join
();
p
.
terminate
()
def
get_all_result
(
self
):
"""
Run over the dataset and return a list of all predictions.
"""
return
list
(
self
.
get_result
())
tensorpack/predict/concurrency.py
0 → 100644
View file @
ff0a4528
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: concurrency.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import
multiprocessing
import
tensorflow
as
tf
from
..utils.concurrency
import
DIE
from
..tfutils.modelutils
import
describe_model
from
..utils
import
logger
from
..tfutils
import
*
from
.common
import
*
__all__
=
[
'ParallelPredictWorker'
,
'QueuePredictWorker'
]
class
ParallelPredictWorker
(
multiprocessing
.
Process
):
def
__init__
(
self
,
idx
,
gpuid
,
config
):
"""
:param idx: index of the worker. the 0th worker will print log.
:param gpuid: absolute id of the GPU to be used. set to -1 to use CPU.
:param config: a `PredictConfig`
"""
super
(
ParallelPredictWorker
,
self
)
.
__init__
()
self
.
idx
=
idx
self
.
gpuid
=
gpuid
self
.
config
=
config
def
_init_runtime
(
self
):
if
self
.
gpuid
>=
0
:
logger
.
info
(
"Worker {} uses GPU {}"
.
format
(
self
.
idx
,
self
.
gpuid
))
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
str
(
self
.
gpuid
)
else
:
logger
.
info
(
"Worker {} uses CPU"
.
format
(
self
.
idx
))
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
''
G
=
tf
.
Graph
()
# build a graph for each process, because they don't need to share anything
with
G
.
as_default
():
if
self
.
idx
!=
0
:
from
tensorpack.models._common
import
disable_layer_logging
disable_layer_logging
()
self
.
func
=
get_predict_func
(
self
.
config
)
if
self
.
idx
==
0
:
describe_model
()
class
QueuePredictWorker
(
ParallelPredictWorker
):
""" A worker process to run predictor on one GPU """
def
__init__
(
self
,
idx
,
gpuid
,
inqueue
,
outqueue
,
config
):
"""
:param idx: index of the worker. the 0th worker will print log.
:param gpuid: id of the GPU to be used. set to -1 to use CPU.
:param inqueue: input queue to get data point
:param outqueue: output queue put result
:param config: a `PredictConfig`
"""
super
(
QueuePredictWorker
,
self
)
.
__init__
(
idx
,
gpuid
,
config
)
self
.
inqueue
=
inqueue
self
.
outqueue
=
outqueue
def
run
(
self
):
self
.
_init_runtime
()
while
True
:
tid
,
dp
=
self
.
inqueue
.
get
()
if
tid
==
DIE
:
self
.
outqueue
.
put
((
DIE
,
None
))
return
else
:
res
=
PredictResult
(
dp
,
self
.
func
(
dp
))
self
.
outqueue
.
put
((
tid
,
res
))
tensorpack/predict/dataset.py
0 → 100644
View file @
ff0a4528
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: dataset.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from
six.moves
import
range
from
tqdm
import
tqdm
from
..dataflow
import
DataFlow
,
BatchData
from
..dataflow.dftools
import
dataflow_to_process_queue
from
..utils.concurrency
import
ensure_proc_terminate
,
OrderedResultGatherProc
,
DIE
from
.concurrency
import
*
__all__
=
[
'DatasetPredictor'
]
class
DatasetPredictor
(
object
):
"""
Run the predict_config on a given `DataFlow`.
"""
def
__init__
(
self
,
config
,
dataset
):
"""
:param config: a `PredictConfig` instance.
:param dataset: a `DataFlow` instance.
"""
assert
isinstance
(
dataset
,
DataFlow
)
self
.
ds
=
dataset
self
.
nr_gpu
=
config
.
nr_gpu
if
self
.
nr_gpu
>
1
:
self
.
inqueue
,
self
.
inqueue_proc
=
dataflow_to_process_queue
(
self
.
ds
,
10
,
self
.
nr_gpu
)
self
.
outqueue
=
multiprocessing
.
Queue
()
try
:
gpus
=
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
.
split
(
','
)
except
KeyError
:
gpus
=
list
(
range
(
self
.
nr_gpu
))
self
.
workers
=
[
QueuePredictWorker
(
i
,
gpus
[
i
],
self
.
inqueue
,
self
.
outqueue
,
config
)
for
i
in
range
(
self
.
nr_gpu
)]
self
.
result_queue
=
OrderedResultGatherProc
(
self
.
outqueue
)
# setup all the procs
self
.
inqueue_proc
.
start
()
for
p
in
self
.
workers
:
p
.
start
()
self
.
result_queue
.
start
()
ensure_proc_terminate
(
self
.
workers
)
ensure_proc_terminate
([
self
.
result_queue
,
self
.
inqueue_proc
])
else
:
self
.
func
=
get_predict_func
(
config
)
def
get_result
(
self
):
""" A generator to produce prediction for each data"""
with
tqdm
(
total
=
self
.
ds
.
size
())
as
pbar
:
if
self
.
nr_gpu
==
1
:
for
dp
in
self
.
ds
.
get_data
():
yield
PredictResult
(
dp
,
self
.
func
(
dp
))
pbar
.
update
()
else
:
die_cnt
=
0
while
True
:
res
=
self
.
result_queue
.
get
()
pbar
.
update
()
if
res
[
0
]
!=
DIE
:
yield
res
[
1
]
else
:
die_cnt
+=
1
if
die_cnt
==
self
.
nr_gpu
:
break
self
.
inqueue_proc
.
join
()
self
.
inqueue_proc
.
terminate
()
for
p
in
self
.
workers
:
p
.
join
();
p
.
terminate
()
def
get_all_result
(
self
):
"""
Run over the dataset and return a list of all predictions.
"""
return
list
(
self
.
get_result
())
tensorpack/tfutils/sessinit.py
View file @
ff0a4528
...
...
@@ -12,9 +12,13 @@ import six
from
..utils
import
logger
__all__
=
[
'SessionInit'
,
'NewSession'
,
'SaverRestore'
,
'ParamRestore'
,
__all__
=
[
'SessionInit'
,
'NewSession'
,
'SaverRestore'
,
'ParamRestore'
,
'JustCurrentSession'
,
'dump_session_params'
]
# TODO they initialize_all at the beginning by default.
class
SessionInit
(
object
):
""" Base class for utilities to initialize a session"""
__metaclass__
=
ABCMeta
...
...
@@ -30,6 +34,11 @@ class SessionInit(object):
def
_init
(
self
,
sess
):
pass
class
JustCurrentSession
(
SessionInit
):
""" Just use the current default session. This is a no-op placeholder"""
def
_init
(
self
,
sess
):
logger
.
info
(
"Using the current running session .."
)
class
NewSession
(
SessionInit
):
"""
Create a new session. All variables will be initialized by their
...
...
@@ -139,7 +148,7 @@ class ParamRestore(SessionInit):
def
dump_session_params
(
path
):
""" Dump value of all trainable variables to a dict and save to `path` as
npy format
.
npy format
, loadable by ParamRestore
"""
var
=
tf
.
get_collection
(
tf
.
GraphKeys
.
TRAINABLE_VARIABLES
)
result
=
{}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment