Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
6c68f8aa
Commit
6c68f8aa
authored
Feb 23, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
rename some variables in inferencerunner.
parent
a47c9980
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
27 additions
and
31 deletions
+27
-31
tensorpack/callbacks/inference_runner.py
tensorpack/callbacks/inference_runner.py
+24
-28
tensorpack/train/base.py
tensorpack/train/base.py
+3
-3
No files found.
tensorpack/callbacks/inference_runner.py
View file @
6c68f8aa
...
@@ -8,13 +8,12 @@ from collections import namedtuple
...
@@ -8,13 +8,12 @@ from collections import namedtuple
import
six
import
six
from
six.moves
import
zip
,
range
from
six.moves
import
zip
,
range
from
..utils
import
logger
,
get_tqdm
from
..dataflow
import
DataFlow
from
..dataflow
import
DataFlow
from
..utils
import
logger
,
get_tqdm
,
SUMMARY_BACKUP_KEYS
from
..tfutils.common
import
get_op_tensor_name
from
..tfutils.common
import
get_op_tensor_name
from
..tfutils.collection
import
freeze_collection
from
..tfutils
import
TowerContext
from
..tfutils
import
TowerContext
from
..train.input_data
import
FeedfreeInput
from
..train.input_data
import
FeedfreeInput
from
..predict
import
build_prediction_graph
from
..predict
import
PredictorTowerBuilder
from
.base
import
Triggerable
from
.base
import
Triggerable
from
.inference
import
Inferencer
from
.inference
import
Inferencer
...
@@ -22,7 +21,7 @@ from .inference import Inferencer
...
@@ -22,7 +21,7 @@ from .inference import Inferencer
__all__
=
[
'InferenceRunner'
,
'FeedfreeInferenceRunner'
]
__all__
=
[
'InferenceRunner'
,
'FeedfreeInferenceRunner'
]
class
OutputTensorDispatcer
(
object
):
class
OutputTensorDispatc
h
er
(
object
):
def
__init__
(
self
):
def
__init__
(
self
):
self
.
_names
=
[]
self
.
_names
=
[]
self
.
_idxs
=
[]
self
.
_idxs
=
[]
...
@@ -71,7 +70,7 @@ class InferenceRunner(Triggerable):
...
@@ -71,7 +70,7 @@ class InferenceRunner(Triggerable):
_IOTensor
=
namedtuple
(
'IOTensor'
,
[
'index'
,
'isOutput'
])
_IOTensor
=
namedtuple
(
'IOTensor'
,
[
'index'
,
'isOutput'
])
def
__init__
(
self
,
ds
,
infs
,
input_tensors
=
None
):
def
__init__
(
self
,
ds
,
infs
,
input_tensor
_name
s
=
None
):
"""
"""
Args:
Args:
ds (DataFlow): the DataFlow to run inferencer on.
ds (DataFlow): the DataFlow to run inferencer on.
...
@@ -87,16 +86,16 @@ class InferenceRunner(Triggerable):
...
@@ -87,16 +86,16 @@ class InferenceRunner(Triggerable):
self
.
infs
=
infs
self
.
infs
=
infs
for
v
in
self
.
infs
:
for
v
in
self
.
infs
:
assert
isinstance
(
v
,
Inferencer
),
v
assert
isinstance
(
v
,
Inferencer
),
v
self
.
input_
tensors
=
input_tensor
s
# names actually
self
.
input_
names
=
input_tensor_name
s
# names actually
def
_setup_graph
(
self
):
def
_setup_graph
(
self
):
self
.
_find_input_tensors
()
# these are all tensor names
self
.
_find_input_tensors
()
# these are all tensor names
self
.
_find_output_tensors
()
# may be either tensor name or op name
self
.
_find_output_tensors
()
# may be either tensor name or op name
self
.
predictor
=
self
.
trainer
.
get_predictor
(
self
.
predictor
=
self
.
trainer
.
get_predictor
(
self
.
input_
tensors
,
self
.
output_tensor
s
)
self
.
input_
names
,
self
.
output_name
s
)
def
_find_input_tensors
(
self
):
def
_find_input_tensors
(
self
):
if
self
.
input_
tensor
s
is
None
:
if
self
.
input_
name
s
is
None
:
input_vars
=
self
.
trainer
.
model
.
get_reused_placehdrs
()
input_vars
=
self
.
trainer
.
model
.
get_reused_placehdrs
()
# TODO even if it works here, sparse still is unavailable
# TODO even if it works here, sparse still is unavailable
# because get_tensor_by_name doesn't work for sparse
# because get_tensor_by_name doesn't work for sparse
...
@@ -105,27 +104,27 @@ class InferenceRunner(Triggerable):
...
@@ -105,27 +104,27 @@ class InferenceRunner(Triggerable):
if
isinstance
(
x
,
tf
.
SparseTensor
):
if
isinstance
(
x
,
tf
.
SparseTensor
):
return
x
.
op
.
name
.
split
(
'/'
)[
0
]
return
x
.
op
.
name
.
split
(
'/'
)[
0
]
return
x
.
name
return
x
.
name
self
.
input_
tensor
s
=
[
get_name
(
x
)
for
x
in
input_vars
]
self
.
input_
name
s
=
[
get_name
(
x
)
for
x
in
input_vars
]
def
_find_output_tensors
(
self
):
def
_find_output_tensors
(
self
):
dispatc
er
=
OutputTensorDispatc
er
()
dispatc
her
=
OutputTensorDispatch
er
()
for
inf
in
self
.
infs
:
for
inf
in
self
.
infs
:
dispatcer
.
add_entry
(
inf
.
get_output_tensors
())
dispatc
h
er
.
add_entry
(
inf
.
get_output_tensors
())
all_names
=
dispatcer
.
get_all_names
()
all_names
=
dispatc
h
er
.
get_all_names
()
IOTensor
=
InferenceRunner
.
_IOTensor
IOTensor
=
InferenceRunner
.
_IOTensor
self
.
output_
tensor
s
=
list
(
filter
(
self
.
output_
name
s
=
list
(
filter
(
lambda
x
:
x
not
in
self
.
input_
tensor
s
,
all_names
))
lambda
x
:
x
not
in
self
.
input_
name
s
,
all_names
))
def
find_tensors
(
names
):
def
find_tensors
(
names
):
ret
=
[]
ret
=
[]
for
name
in
names
:
for
name
in
names
:
if
name
in
self
.
input_
tensor
s
:
if
name
in
self
.
input_
name
s
:
ret
.
append
(
IOTensor
(
self
.
input_
tensor
s
.
index
(
name
),
False
))
ret
.
append
(
IOTensor
(
self
.
input_
name
s
.
index
(
name
),
False
))
else
:
else
:
ret
.
append
(
IOTensor
(
self
.
output_
tensor
s
.
index
(
name
),
True
))
ret
.
append
(
IOTensor
(
self
.
output_
name
s
.
index
(
name
),
True
))
return
ret
return
ret
self
.
inf_to_tensors
=
[
find_tensors
(
t
)
for
t
in
dispatcer
.
get_names_for_each_entry
()]
self
.
inf_to_tensors
=
[
find_tensors
(
t
)
for
t
in
dispatc
h
er
.
get_names_for_each_entry
()]
# list of list of IOTensor
# list of list of IOTensor
def
_trigger
(
self
):
def
_trigger
(
self
):
...
@@ -183,14 +182,11 @@ class FeedfreeInferenceRunner(Triggerable):
...
@@ -183,14 +182,11 @@ class FeedfreeInferenceRunner(Triggerable):
def
_setup_graph
(
self
):
def
_setup_graph
(
self
):
self
.
_find_input_tensors
()
# tensors
self
.
_find_input_tensors
()
# tensors
# TODO reuse predictor code
# TODO can we reuse predictor factory?
# overwrite the FeedfreeInferenceRunner name scope
with
tf
.
variable_scope
(
tf
.
get_variable_scope
(),
reuse
=
True
):
with
tf
.
variable_scope
(
tf
.
get_variable_scope
(),
reuse
=
True
),
\
tf
.
name_scope
(
None
),
\
freeze_collection
(
SUMMARY_BACKUP_KEYS
):
def
fn
(
_
):
def
fn
(
_
):
self
.
trainer
.
model
.
build_graph
(
self
.
_input_tensors
)
self
.
trainer
.
model
.
build_graph
(
self
.
_input_tensors
)
build_prediction_graph
(
fn
,
[
0
],
prefix
=
self
.
_prefix
)
PredictorTowerBuilder
(
fn
,
self
.
_prefix
)
.
build
(
0
)
self
.
_tower_prefix
=
TowerContext
.
get_predict_tower_name
(
0
,
self
.
_prefix
)
self
.
_tower_prefix
=
TowerContext
.
get_predict_tower_name
(
0
,
self
.
_prefix
)
self
.
_find_output_tensors
()
self
.
_find_output_tensors
()
...
@@ -220,16 +216,16 @@ class FeedfreeInferenceRunner(Triggerable):
...
@@ -220,16 +216,16 @@ class FeedfreeInferenceRunner(Triggerable):
def
_find_output_tensors
(
self
):
def
_find_output_tensors
(
self
):
# TODO doesn't support output an input tensor
# TODO doesn't support output an input tensor
dispatc
er
=
OutputTensorDispatc
er
()
dispatc
her
=
OutputTensorDispatch
er
()
for
inf
in
self
.
infs
:
for
inf
in
self
.
infs
:
dispatcer
.
add_entry
(
inf
.
get_output_tensors
())
dispatc
h
er
.
add_entry
(
inf
.
get_output_tensors
())
all_names
=
dispatcer
.
get_all_names
()
all_names
=
dispatc
h
er
.
get_all_names
()
G
=
tf
.
get_default_graph
()
G
=
tf
.
get_default_graph
()
self
.
_output_tensors
=
[
G
.
get_tensor_by_name
(
self
.
_output_tensors
=
[
G
.
get_tensor_by_name
(
self
.
_tower_prefix
+
'/'
+
n
)
for
n
in
all_names
]
self
.
_tower_prefix
+
'/'
+
n
)
for
n
in
all_names
]
# list of list of id
# list of list of id
self
.
inf_to_idxs
=
dispatcer
.
get_idx_for_each_entry
()
self
.
inf_to_idxs
=
dispatc
h
er
.
get_idx_for_each_entry
()
def
_trigger
(
self
):
def
_trigger
(
self
):
sess
=
tf
.
get_default_session
()
sess
=
tf
.
get_default_session
()
...
...
tensorpack/train/base.py
View file @
6c68f8aa
...
@@ -129,8 +129,10 @@ class Trainer(object):
...
@@ -129,8 +129,10 @@ class Trainer(object):
self
.
config
.
callbacks
.
setup_graph
(
weakref
.
proxy
(
self
))
self
.
config
.
callbacks
.
setup_graph
(
weakref
.
proxy
(
self
))
self
.
config
.
session_init
.
_setup_graph
()
self
.
config
.
session_init
.
_setup_graph
()
def
after_init
(
_
,
__
):
def
after_init
(
scaffold
,
sess
):
logger
.
info
(
"Graph variables initialized."
)
logger
.
info
(
"Graph variables initialized."
)
self
.
config
.
session_init
.
_run_init
(
sess
)
scaffold
=
tf
.
train
.
Scaffold
(
scaffold
=
tf
.
train
.
Scaffold
(
init_op
=
tf
.
global_variables_initializer
(),
init_op
=
tf
.
global_variables_initializer
(),
init_fn
=
after_init
)
init_fn
=
after_init
)
...
@@ -140,9 +142,7 @@ class Trainer(object):
...
@@ -140,9 +142,7 @@ class Trainer(object):
scaffold
=
scaffold
,
config
=
self
.
config
.
session_config
),
scaffold
=
scaffold
,
config
=
self
.
config
.
session_config
),
hooks
=
self
.
config
.
callbacks
.
get_hooks
())
hooks
=
self
.
config
.
callbacks
.
get_hooks
())
self
.
hooked_sess
=
self
.
monitored_sess
# just create an alias
self
.
hooked_sess
=
self
.
monitored_sess
# just create an alias
self
.
sess
=
self
.
monitored_sess
.
_tf_sess
()
# expose the underlying session also
self
.
sess
=
self
.
monitored_sess
.
_tf_sess
()
# expose the underlying session also
self
.
config
.
session_init
.
_run_init
(
self
.
sess
)
@
abstractmethod
@
abstractmethod
def
_setup
(
self
):
def
_setup
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment