Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
86ec2d15
Commit
86ec2d15
authored
Jul 16, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
offline predictor
parent
4af48399
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
75 additions
and
24 deletions
+75
-24
tensorpack/models/model_desc.py
tensorpack/models/model_desc.py
+1
-2
tensorpack/predict/base.py
tensorpack/predict/base.py
+63
-0
tensorpack/predict/common.py
tensorpack/predict/common.py
+11
-22
No files found.
tensorpack/models/model_desc.py
View file @
86ec2d15
...
...
@@ -39,8 +39,7 @@ class ModelDesc(object):
def
reuse_input_vars
(
self
):
""" Find and return already-defined input_vars in default graph"""
input_var_names
=
[
k
.
name
for
k
in
self
.
_get_input_vars
()]
g
=
tf
.
get_default_graph
()
return
[
g
.
get_tensor_by_name
(
name
+
":0"
)
for
name
in
input_var_names
]
return
get_vars_by_names
(
input_var_names
)
def
get_input_vars_desc
(
self
):
""" return a list of `InputVar` instance"""
...
...
tensorpack/predict/base.py
0 → 100644
View file @
86ec2d15
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: base.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from
abc
import
abstractmethod
,
ABCMeta
,
abstractproperty
import
tensorflow
as
tf
from
..tfutils
import
get_vars_by_names
class
PredictorBase
(
object
):
__metaclass__
=
ABCMeta
@
abstractproperty
def
session
(
self
):
""" return the session the predictor is running on"""
pass
def
__call__
(
self
,
dp
):
assert
len
(
dp
)
==
len
(
self
.
input_var_names
),
\
"{} != {}"
.
format
(
len
(
dp
),
len
(
self
.
input_var_names
))
output
=
self
.
_do_call
(
dp
)
if
self
.
return_input
:
return
(
dp
,
output
)
else
:
return
output
@
abstractmethod
def
_do_call
(
self
,
dp
):
"""
:param dp: input datapoint. must have the same length as input_var_names
:return: output as defined by the config
"""
pass
class
OfflinePredictor
(
PredictorBase
):
""" Build a predictor from a given config, in an independent graph"""
def
__init__
(
self
,
config
):
self
.
graph
=
tf
.
Graph
()
with
self
.
graph
.
as_default
():
input_vars
=
config
.
model
.
get_input_vars
()
config
.
model
.
_build_graph
(
input_vars
,
False
)
self
.
input_var_names
=
config
.
input_var_names
self
.
output_var_names
=
config
.
output_var_names
self
.
return_input
=
config
.
return_input
self
.
input_vars
=
get_vars_by_names
(
self
.
input_var_names
)
self
.
output_vars
=
get_vars_by_names
(
self
.
output_var_names
)
sess
=
tf
.
Session
(
config
=
config
.
session_config
)
config
.
session_init
.
init
(
sess
)
self
.
_session
=
sess
@
property
def
session
(
self
):
return
self
.
_session
def
_do_call
(
self
,
dp
):
feed
=
dict
(
zip
(
self
.
input_vars
,
dp
))
output
=
self
.
session
.
run
(
self
.
output_vars
,
feed_dict
=
feed
)
return
output
tensorpack/predict/common.py
View file @
86ec2d15
...
...
@@ -4,11 +4,13 @@
import
tensorflow
as
tf
from
collections
import
namedtuple
import
six
from
six.moves
import
zip
from
tensorpack.models
import
ModelDesc
from
..utils
import
logger
from
..tfutils
import
*
from
.base
import
OfflinePredictor
import
multiprocessing
...
...
@@ -29,20 +31,20 @@ class PredictConfig(object):
:param output_var_names: a list of names of the output tensors to predict, the
variables can be any computable tensor in the graph.
Predict specific output might not require all input variables.
:param return_input: whether to produce (input, output) pair or just output. default to False.
It's only effective for `DatasetPredictorBase`.
:param return_input: whether to return (input, output) pair or just output. default to False.
"""
def
assert_type
(
v
,
tp
):
assert
isinstance
(
v
,
tp
),
v
.
__class__
# XXX does it work? start with minimal memory, but allow growth.
# allow_growth doesn't seem to work very well in TF.
self
.
session_config
=
kwargs
.
pop
(
'session_config'
,
get_default_sess_config
(
0.
3
))
self
.
session_config
=
kwargs
.
pop
(
'session_config'
,
get_default_sess_config
(
0.
4
))
self
.
session_init
=
kwargs
.
pop
(
'session_init'
,
JustCurrentSession
())
assert_type
(
self
.
session_init
,
SessionInit
)
self
.
model
=
kwargs
.
pop
(
'model'
)
assert_type
(
self
.
model
,
ModelDesc
)
self
.
input_var_names
=
kwargs
.
pop
(
'input_var_names'
,
None
)
# inputs & outputs
self
.
input_var_names
=
kwargs
.
pop
(
'input_var_names'
,
None
)
input_mapping
=
kwargs
.
pop
(
'input_data_mapping'
,
None
)
if
input_mapping
:
raw_vars
=
self
.
model
.
get_input_vars_desc
()
...
...
@@ -55,32 +57,19 @@ Use \'input_var_names=[{}]\' instead'.format(', '.join(self.input_var_names)))
self
.
input_var_names
=
[
k
.
name
for
k
in
raw_vars
]
self
.
output_var_names
=
kwargs
.
pop
(
'output_var_names'
)
assert
len
(
self
.
input_var_names
),
self
.
input_var_names
for
v
in
self
.
input_var_names
:
assert_type
(
v
,
six
.
string_types
)
assert
len
(
self
.
output_var_names
),
self
.
output_var_names
self
.
return_input
=
kwargs
.
pop
(
'return_input'
,
False
)
assert
len
(
kwargs
)
==
0
,
'Unknown arguments: {}'
.
format
(
str
(
kwargs
.
keys
()))
def
get_predict_func
(
config
):
"""
Produce a
simple predictor function
run inside a new session.
Produce a
offline predictor
run inside a new session.
:param config: a `PredictConfig` instance.
:returns: A
prediction function
that takes a list of input values, and return
:returns: A
callable predictor
that takes a list of input values, and return
a list of output values defined in ``config.output_var_names``.
"""
# build graph
input_vars
=
config
.
model
.
get_input_vars
()
config
.
model
.
_build_graph
(
input_vars
,
False
)
input_vars
=
get_vars_by_names
(
config
.
input_var_names
)
output_vars
=
get_vars_by_names
(
config
.
output_var_names
)
sess
=
tf
.
Session
(
config
=
config
.
session_config
)
config
.
session_init
.
init
(
sess
)
return
OfflinePredictor
(
config
)
def
run_input
(
dp
):
assert
len
(
input_vars
)
==
len
(
dp
),
"{} != {}"
.
format
(
len
(
input_vars
),
len
(
dp
))
feed
=
dict
(
zip
(
input_vars
,
dp
))
return
sess
.
run
(
output_vars
,
feed_dict
=
feed
)
# XXX hack. so the caller can get access to the session.
run_input
.
session
=
sess
return
run_input
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment