Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
3080e91e
Commit
3080e91e
authored
Dec 15, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
some rename and alias
parent
42e5481a
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
57 additions
and
8 deletions
+57
-8
tensorpack/callbacks/inference_runner.py
tensorpack/callbacks/inference_runner.py
+50
-5
tensorpack/models/model_desc.py
tensorpack/models/model_desc.py
+5
-2
tensorpack/predict/base.py
tensorpack/predict/base.py
+2
-1
No files found.
tensorpack/callbacks/inference_runner.py
View file @
3080e91e
...
...
@@ -6,7 +6,7 @@
import
tensorflow
as
tf
from
collections
import
namedtuple
import
six
from
six.moves
import
zip
from
six.moves
import
zip
,
range
from
..dataflow
import
DataFlow
from
.base
import
Callback
...
...
@@ -61,7 +61,7 @@ class InferenceRunner(Callback):
def
_find_input_tensors
(
self
):
if
self
.
input_tensors
is
None
:
input_vars
=
self
.
trainer
.
model
.
get_
input_va
rs
()
input_vars
=
self
.
trainer
.
model
.
get_
reuse_placehd
rs
()
# TODO even if it works here, sparse still is unavailable
# because get_tensor_by_name doesn't work for sparse
def
get_name
(
x
):
...
...
@@ -125,13 +125,58 @@ class FeedfreeInferenceRunner(Callback):
self
.
input_tensor_names
=
input_tensors
def
_setup_graph
(
self
):
self
.
_find_input_tensors
()
# tensors
self
.
_find_output_tensors
()
# TODO build tower
def
_find_input_tensors
(
self
):
self
.
_input_data
.
_setup
(
self
.
trainer
)
# only 1 prediction tower will be used for inference
self
.
_input_tensors
=
self
.
_input_data
.
get_input_tensors
()
# TODO filter by names
self
.
_find_output_tensors
()
model_placehdrs
=
self
.
trainer
.
model
.
get_reuse_placehdrs
()
assert
len
(
self
.
_input_tensors
)
==
len
(
model_placehdrs
),
\
"FeedfreeInput doesn't produce correct number of output tensors"
if
self
.
input_tensor_names
is
not
None
:
assert
isinstance
(
self
.
input_tensor_names
,
list
)
self
.
_input_tensors
=
[
k
for
idx
,
k
in
enumerate
(
self
.
_input_tensors
)
if
model_placehdrs
[
idx
]
.
name
in
self
.
input_tensor_names
]
assert
len
(
self
.
_input_tensors
)
==
len
(
self
.
input_tensor_names
),
\
"names of input tensors are not defined in the Model"
def
_find_output_tensors
(
self
):
pass
# doesn't support output an input tensor
dispatcer
=
OutputTensorDispatcer
()
for
inf
in
self
.
infs
:
dispatcer
.
add_entry
(
inf
.
get_output_tensors
())
all_names
=
dispatcer
.
get_all_names
()
IOTensor
=
InferenceRunner
.
IOTensor
self
.
output_tensors
=
all_names
def
find_oid
(
idxs
):
ret
=
[]
for
idx
in
idxs
:
name
=
all_names
[
idx
]
ret
.
append
(
IOTensor
(
self
.
output_tensors
.
index
(
name
),
True
))
return
ret
self
.
inf_to_tensors
=
[
find_oid
(
t
)
for
t
in
dispatcer
.
get_idx_for_each_entry
()]
# list of list of (var_name: IOTensor)
def
_trigger_epoch
(
self
):
for
inf
in
self
.
infs
:
inf
.
before_inference
()
sess
=
tf
.
get_default_session
()
sz
=
self
.
_input_data
.
size
()
with
get_tqdm
(
total
=
sz
)
as
pbar
:
for
_
in
range
(
sz
):
#outputs = self.pred_func(dp)
#for inf, tensormap in zip(self.infs, self.inf_to_tensors):
#inf_output = [(outputs if k.isOutput else dp)[k.index]
#for k in tensormap]
#inf.datapoint(inf_output)
pbar
.
update
()
self
.
_write_summary_after_inference
()
def
_write_summary_after_inference
(
self
):
summary_inferencer
(
self
.
trainer
,
self
.
infs
)
tensorpack/models/model_desc.py
View file @
3080e91e
...
...
@@ -43,11 +43,14 @@ class ModelDesc(object):
"""
if
hasattr
(
self
,
'reuse_input_vars'
):
return
self
.
reuse_input_vars
ret
=
self
.
get
_placeholders
()
ret
=
self
.
build
_placeholders
()
self
.
reuse_input_vars
=
ret
return
ret
def
get_placeholders
(
self
,
prefix
=
''
):
# alias
get_reuse_placehdrs
=
get_input_vars
def
build_placeholders
(
self
,
prefix
=
''
):
""" build placeholders with optional prefix, for each InputVar
"""
input_vars
=
self
.
_get_input_vars
()
...
...
tensorpack/predict/base.py
View file @
3080e91e
...
...
@@ -151,7 +151,8 @@ class DataParallelOfflinePredictor(OnlinePredictor):
output_vars
=
[]
for
k
in
towers
:
towername
=
PREDICT_TOWER
+
str
(
k
)
input_vars
=
config
.
model
.
get_placeholders
(
prefix
=
towername
+
'-'
)
input_vars
=
config
.
model
.
build_placeholders
(
prefix
=
towername
+
'-'
)
logger
.
info
(
"Building graph for predictor tower {}..."
.
format
(
k
))
with
tf
.
device
(
'/gpu:{}'
.
format
(
k
)
if
k
>=
0
else
'/cpu:0'
),
\
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment