Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
86ec2d15
Commit
86ec2d15
authored
Jul 16, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
offline predictor
parent
4af48399
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
75 additions
and
24 deletions
+75
-24
tensorpack/models/model_desc.py
tensorpack/models/model_desc.py
+1
-2
tensorpack/predict/base.py
tensorpack/predict/base.py
+63
-0
tensorpack/predict/common.py
tensorpack/predict/common.py
+11
-22
No files found.
tensorpack/models/model_desc.py
View file @
86ec2d15
...
@@ -39,8 +39,7 @@ class ModelDesc(object):
...
@@ -39,8 +39,7 @@ class ModelDesc(object):
def
reuse_input_vars
(
self
):
def
reuse_input_vars
(
self
):
""" Find and return already-defined input_vars in default graph"""
""" Find and return already-defined input_vars in default graph"""
input_var_names
=
[
k
.
name
for
k
in
self
.
_get_input_vars
()]
input_var_names
=
[
k
.
name
for
k
in
self
.
_get_input_vars
()]
g
=
tf
.
get_default_graph
()
return
get_vars_by_names
(
input_var_names
)
return
[
g
.
get_tensor_by_name
(
name
+
":0"
)
for
name
in
input_var_names
]
def
get_input_vars_desc
(
self
):
def
get_input_vars_desc
(
self
):
""" return a list of `InputVar` instance"""
""" return a list of `InputVar` instance"""
...
...
tensorpack/predict/base.py
0 → 100644
View file @
86ec2d15
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: base.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from
abc
import
abstractmethod
,
ABCMeta
,
abstractproperty
import
tensorflow
as
tf
from
..tfutils
import
get_vars_by_names
class
PredictorBase
(
object
):
__metaclass__
=
ABCMeta
@
abstractproperty
def
session
(
self
):
""" return the session the predictor is running on"""
pass
def
__call__
(
self
,
dp
):
assert
len
(
dp
)
==
len
(
self
.
input_var_names
),
\
"{} != {}"
.
format
(
len
(
dp
),
len
(
self
.
input_var_names
))
output
=
self
.
_do_call
(
dp
)
if
self
.
return_input
:
return
(
dp
,
output
)
else
:
return
output
@
abstractmethod
def
_do_call
(
self
,
dp
):
"""
:param dp: input datapoint. must have the same length as input_var_names
:return: output as defined by the config
"""
pass
class
OfflinePredictor
(
PredictorBase
):
""" Build a predictor from a given config, in an independent graph"""
def
__init__
(
self
,
config
):
self
.
graph
=
tf
.
Graph
()
with
self
.
graph
.
as_default
():
input_vars
=
config
.
model
.
get_input_vars
()
config
.
model
.
_build_graph
(
input_vars
,
False
)
self
.
input_var_names
=
config
.
input_var_names
self
.
output_var_names
=
config
.
output_var_names
self
.
return_input
=
config
.
return_input
self
.
input_vars
=
get_vars_by_names
(
self
.
input_var_names
)
self
.
output_vars
=
get_vars_by_names
(
self
.
output_var_names
)
sess
=
tf
.
Session
(
config
=
config
.
session_config
)
config
.
session_init
.
init
(
sess
)
self
.
_session
=
sess
@
property
def
session
(
self
):
return
self
.
_session
def
_do_call
(
self
,
dp
):
feed
=
dict
(
zip
(
self
.
input_vars
,
dp
))
output
=
self
.
session
.
run
(
self
.
output_vars
,
feed_dict
=
feed
)
return
output
tensorpack/predict/common.py
View file @
86ec2d15
...
@@ -4,11 +4,13 @@
...
@@ -4,11 +4,13 @@
import
tensorflow
as
tf
import
tensorflow
as
tf
from
collections
import
namedtuple
from
collections
import
namedtuple
import
six
from
six.moves
import
zip
from
six.moves
import
zip
from
tensorpack.models
import
ModelDesc
from
tensorpack.models
import
ModelDesc
from
..utils
import
logger
from
..utils
import
logger
from
..tfutils
import
*
from
..tfutils
import
*
from
.base
import
OfflinePredictor
import
multiprocessing
import
multiprocessing
...
@@ -29,20 +31,20 @@ class PredictConfig(object):
...
@@ -29,20 +31,20 @@ class PredictConfig(object):
:param output_var_names: a list of names of the output tensors to predict, the
:param output_var_names: a list of names of the output tensors to predict, the
variables can be any computable tensor in the graph.
variables can be any computable tensor in the graph.
Predict specific output might not require all input variables.
Predict specific output might not require all input variables.
:param return_input: whether to produce (input, output) pair or just output. default to False.
:param return_input: whether to return (input, output) pair or just output. default to False.
It's only effective for `DatasetPredictorBase`.
"""
"""
def
assert_type
(
v
,
tp
):
def
assert_type
(
v
,
tp
):
assert
isinstance
(
v
,
tp
),
v
.
__class__
assert
isinstance
(
v
,
tp
),
v
.
__class__
# XXX does it work? start with minimal memory, but allow growth.
# XXX does it work? start with minimal memory, but allow growth.
# allow_growth doesn't seem to work very well in TF.
# allow_growth doesn't seem to work very well in TF.
self
.
session_config
=
kwargs
.
pop
(
'session_config'
,
get_default_sess_config
(
0.
3
))
self
.
session_config
=
kwargs
.
pop
(
'session_config'
,
get_default_sess_config
(
0.
4
))
self
.
session_init
=
kwargs
.
pop
(
'session_init'
,
JustCurrentSession
())
self
.
session_init
=
kwargs
.
pop
(
'session_init'
,
JustCurrentSession
())
assert_type
(
self
.
session_init
,
SessionInit
)
assert_type
(
self
.
session_init
,
SessionInit
)
self
.
model
=
kwargs
.
pop
(
'model'
)
self
.
model
=
kwargs
.
pop
(
'model'
)
assert_type
(
self
.
model
,
ModelDesc
)
assert_type
(
self
.
model
,
ModelDesc
)
self
.
input_var_names
=
kwargs
.
pop
(
'input_var_names'
,
None
)
# inputs & outputs
self
.
input_var_names
=
kwargs
.
pop
(
'input_var_names'
,
None
)
input_mapping
=
kwargs
.
pop
(
'input_data_mapping'
,
None
)
input_mapping
=
kwargs
.
pop
(
'input_data_mapping'
,
None
)
if
input_mapping
:
if
input_mapping
:
raw_vars
=
self
.
model
.
get_input_vars_desc
()
raw_vars
=
self
.
model
.
get_input_vars_desc
()
...
@@ -55,32 +57,19 @@ Use \'input_var_names=[{}]\' instead'.format(', '.join(self.input_var_names)))
...
@@ -55,32 +57,19 @@ Use \'input_var_names=[{}]\' instead'.format(', '.join(self.input_var_names)))
self
.
input_var_names
=
[
k
.
name
for
k
in
raw_vars
]
self
.
input_var_names
=
[
k
.
name
for
k
in
raw_vars
]
self
.
output_var_names
=
kwargs
.
pop
(
'output_var_names'
)
self
.
output_var_names
=
kwargs
.
pop
(
'output_var_names'
)
assert
len
(
self
.
input_var_names
),
self
.
input_var_names
assert
len
(
self
.
input_var_names
),
self
.
input_var_names
for
v
in
self
.
input_var_names
:
assert_type
(
v
,
six
.
string_types
)
assert
len
(
self
.
output_var_names
),
self
.
output_var_names
assert
len
(
self
.
output_var_names
),
self
.
output_var_names
self
.
return_input
=
kwargs
.
pop
(
'return_input'
,
False
)
self
.
return_input
=
kwargs
.
pop
(
'return_input'
,
False
)
assert
len
(
kwargs
)
==
0
,
'Unknown arguments: {}'
.
format
(
str
(
kwargs
.
keys
()))
assert
len
(
kwargs
)
==
0
,
'Unknown arguments: {}'
.
format
(
str
(
kwargs
.
keys
()))
def
get_predict_func
(
config
):
def
get_predict_func
(
config
):
"""
"""
Produce a
simple predictor function
run inside a new session.
Produce a
offline predictor
run inside a new session.
:param config: a `PredictConfig` instance.
:param config: a `PredictConfig` instance.
:returns: A
prediction function
that takes a list of input values, and return
:returns: A
callable predictor
that takes a list of input values, and return
a list of output values defined in ``config.output_var_names``.
a list of output values defined in ``config.output_var_names``.
"""
"""
# build graph
return
OfflinePredictor
(
config
)
input_vars
=
config
.
model
.
get_input_vars
()
config
.
model
.
_build_graph
(
input_vars
,
False
)
input_vars
=
get_vars_by_names
(
config
.
input_var_names
)
output_vars
=
get_vars_by_names
(
config
.
output_var_names
)
sess
=
tf
.
Session
(
config
=
config
.
session_config
)
config
.
session_init
.
init
(
sess
)
def
run_input
(
dp
):
assert
len
(
input_vars
)
==
len
(
dp
),
"{} != {}"
.
format
(
len
(
input_vars
),
len
(
dp
))
feed
=
dict
(
zip
(
input_vars
,
dp
))
return
sess
.
run
(
output_vars
,
feed_dict
=
feed
)
# XXX hack. so the caller can get access to the session.
run_input
.
session
=
sess
return
run_input
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment