Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
a55d81ca
Commit
a55d81ca
authored
Oct 19, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
use sess.make_callable for predictors
parent
f363d2e8
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
39 additions
and
9 deletions
+39
-9
examples/DynamicFilterNetwork/steering-filter.py
examples/DynamicFilterNetwork/steering-filter.py
+0
-1
examples/boilerplate.py
examples/boilerplate.py
+1
-2
tensorpack/predict/base.py
tensorpack/predict/base.py
+30
-4
tensorpack/tfutils/tower.py
tensorpack/tfutils/tower.py
+8
-2
No files found.
examples/DynamicFilterNetwork/steering-filter.py
View file @
a55d81ca
...
...
@@ -12,7 +12,6 @@ import multiprocessing
from
tensorpack
import
*
from
tensorpack.utils
import
logger
from
tensorpack.utils.gpu
import
get_nr_gpu
from
tensorpack.utils.viz
import
*
from
tensorpack.utils.argtools
import
shape2d
,
shape4d
from
tensorpack.dataflow
import
dataset
...
...
examples/boilerplate.py
View file @
a55d81ca
...
...
@@ -4,9 +4,8 @@
import
os
import
argparse
from
tensorpack
import
*
from
tensorpack.utils.gpu
import
get_nr_gpu
import
tensorflow
as
tf
from
tensorpack
import
*
"""
This is a boiler-plate template.
...
...
tensorpack/predict/base.py
View file @
a55d81ca
...
...
@@ -7,10 +7,11 @@ from abc import abstractmethod, ABCMeta
import
tensorflow
as
tf
import
six
from
..tfutils.common
import
get_tensors_by_names
from
..tfutils.common
import
get_tensors_by_names
,
get_tf_version_number
from
..tfutils.tower
import
TowerContext
from
..input_source
import
PlaceholderInput
from
..utils.develop
import
log_deprecated
from
..utils.argtools
import
log_once
__all__
=
[
'PredictorBase'
,
'AsyncPredictorBase'
,
'OnlinePredictor'
,
'OfflinePredictor'
,
...
...
@@ -106,15 +107,40 @@ class OnlinePredictor(PredictorBase):
self
.
input_tensors
=
input_tensors
self
.
output_tensors
=
output_tensors
self
.
sess
=
sess
self
.
_use_callable
=
get_tf_version_number
()
>=
1.2
if
self
.
_use_callable
:
if
sess
is
not
None
:
self
.
_callable
=
sess
.
make_callable
(
fetches
=
output_tensors
,
feed_list
=
input_tensors
)
else
:
log_once
(
"TF>=1.2 is recommended for better performance of predictor!"
,
'warn'
)
self
.
_callable
=
None
def
_do_call_old
(
self
,
dp
):
feed
=
dict
(
zip
(
self
.
input_tensors
,
dp
))
output
=
self
.
sess
.
run
(
self
.
output_tensors
,
feed_dict
=
feed
)
return
output
def
_do_call_new
(
self
,
dp
):
if
self
.
_callable
is
None
:
self
.
_callable
=
self
.
sess
.
make_callable
(
fetches
=
self
.
output_tensors
,
feed_list
=
self
.
input_tensors
)
return
self
.
_callable
(
*
dp
)
def
_do_call
(
self
,
dp
):
assert
len
(
dp
)
==
len
(
self
.
input_tensors
),
\
"{} != {}"
.
format
(
len
(
dp
),
len
(
self
.
input_tensors
))
feed
=
dict
(
zip
(
self
.
input_tensors
,
dp
))
if
self
.
sess
is
None
:
self
.
sess
=
tf
.
get_default_session
()
output
=
self
.
sess
.
run
(
self
.
output_tensors
,
feed_dict
=
feed
)
return
output
if
self
.
_use_callable
:
return
self
.
_do_call_new
(
dp
)
else
:
return
self
.
_do_call_old
(
dp
)
class
OfflinePredictor
(
OnlinePredictor
):
...
...
tensorpack/tfutils/tower.py
View file @
a55d81ca
...
...
@@ -278,7 +278,13 @@ class TowerTensorHandle(object):
"""
return
self
.
_output
# def make_callable(self, input_names, output_names):
# should move to somewhere else.
# def get_predictor(self, input_names, output_names):
# """
# Get a predictor with tensors inside this tower.
# """
# input_tensors = self.get_tensors(input_names)
# output_tensors = self.get_tensors(output_names)
# pass
# # TODO sort out the import order
# from ..predict.base import OnlinePredictor # noqa
# return OnlinePredictor(input_tensors, output_tensors)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment