Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
9a711e72
Commit
9a711e72
authored
Oct 20, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Trainerv2] use v2 inference interface in v1 trainer
parent
e121701a
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
61 additions
and
73 deletions
+61
-73
tensorpack/callbacks/inference_runner.py
tensorpack/callbacks/inference_runner.py
+23
-52
tensorpack/tfutils/tower.py
tensorpack/tfutils/tower.py
+0
-4
tensorpack/train/base.py
tensorpack/train/base.py
+33
-13
tensorpack/train/distributed.py
tensorpack/train/distributed.py
+1
-1
tensorpack/trainv2/base.py
tensorpack/trainv2/base.py
+4
-3
No files found.
tensorpack/callbacks/inference_runner.py
View file @
9a711e72
...
...
@@ -108,7 +108,7 @@ class InferenceRunner(InferenceRunnerBase):
infs (list): a list of :class:`Inferencer` instances.
tower_name (str): the name scope of the tower to build. Need to set a
different one if multiple InferenceRunner are used.
gpu
(int): the device to use
device
(int): the device to use
"""
if
isinstance
(
input
,
DataFlow
):
input
=
FeedInput
(
input
,
infinite
=
False
)
...
...
@@ -124,32 +124,17 @@ class InferenceRunner(InferenceRunnerBase):
return
InferencerToHook
(
inf
,
fetches
)
def
_setup_graph
(
self
):
if
self
.
trainer
.
_API_VERSION
==
1
:
# old Trainer API
assert
self
.
trainer
.
model
is
not
None
# Use predict_tower in train config. either gpuid or -1
if
self
.
trainer
.
_config
.
predict_tower
is
not
None
:
if
self
.
trainer
.
_API_VERSION
==
1
and
self
.
trainer
.
_config
.
predict_tower
is
not
None
:
device
=
self
.
trainer
.
_config
.
predict_tower
[
0
]
else
:
device
=
self
.
_device
device
=
'/gpu:{}'
.
format
(
device
)
if
device
>=
0
else
'/cpu:0'
input_callbacks
=
self
.
_input_source
.
setup
(
self
.
trainer
.
model
.
get_inputs_desc
())
with
tf
.
variable_scope
(
tf
.
get_variable_scope
(),
reuse
=
True
):
self
.
_tower_handle
=
self
.
trainer
.
predictor_factory
.
build
(
self
.
_tower_name
,
device
,
self
.
_input_source
)
else
:
# new Trainer API
from
..trainv2
import
TowerTrainer
assert
isinstance
(
self
.
trainer
,
TowerTrainer
),
self
.
trainer
assert
self
.
trainer
.
tower_func
is
not
None
,
"You must set tower_func of the trainer to use InferenceRunner!"
input_callbacks
=
self
.
_input_source
.
setup
(
self
.
trainer
.
inputs_desc
)
with
tf
.
variable_scope
(
tf
.
get_variable_scope
(),
reuse
=
True
):
SimplePredictBuilder
(
ns_name
=
self
.
_tower_name
,
vs_name
=
self
.
trainer
.
_main_tower_vs_name
,
device
=
0
)
.
build
(
vs_name
=
self
.
trainer
.
_main_tower_vs_name
,
device
=
device
)
.
build
(
self
.
_input_source
,
self
.
trainer
.
tower_func
)
self
.
_tower_handle
=
self
.
trainer
.
tower_func
.
towers
[
-
1
]
...
...
@@ -202,21 +187,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
def
_setup_graph
(
self
):
self
.
_handles
=
[]
if
self
.
trainer
.
_API_VERSION
==
1
:
# old Trainer API
input_callbacks
=
self
.
_input_source
.
setup
(
self
.
trainer
.
model
.
get_inputs_desc
())
# build each predict tower
with
tf
.
variable_scope
(
tf
.
get_variable_scope
(),
reuse
=
True
):
for
idx
,
t
in
enumerate
(
self
.
_gpus
):
tower_name
=
self
.
_tower_names
[
idx
]
device
=
'/gpu:{}'
.
format
(
t
)
self
.
_handles
.
append
(
self
.
trainer
.
predictor_factory
.
build
(
tower_name
,
device
,
self
.
_input_source
))
else
:
# new Trainer API
from
..trainv2
import
TowerTrainer
assert
isinstance
(
self
.
trainer
,
TowerTrainer
),
self
.
trainer
assert
self
.
trainer
.
tower_func
is
not
None
,
"You must set tower_func of the trainer to use InferenceRunner!"
input_callbacks
=
self
.
_input_source
.
setup
(
self
.
trainer
.
inputs_desc
)
with
tf
.
variable_scope
(
tf
.
get_variable_scope
(),
reuse
=
True
):
...
...
tensorpack/tfutils/tower.py
View file @
9a711e72
...
...
@@ -60,9 +60,6 @@ class TowerContext(object):
(
self
.
is_training
and
len
(
self
.
_vs_name
)
>
0
)
or
\
(
not
self
.
is_training
and
len
(
self
.
_vs_name
)
>
0
and
not
self
.
_initial_vs_reuse
)
# TODO clarify the interface on name/vs_name/ns_name.
# TODO in inference, vs_name may need to be different from ns_name.
# How to deal with this?
@
property
def
name
(
self
):
return
self
.
_name
...
...
@@ -151,7 +148,6 @@ class TowerContext(object):
def
get_current_tower_context
():
global
_CurrentTowerContext
return
_CurrentTowerContext
...
...
tensorpack/train/base.py
View file @
9a711e72
...
...
@@ -18,8 +18,11 @@ from ..tfutils import get_global_step_value
from
..tfutils.model_utils
import
describe_trainable_vars
from
..tfutils.sesscreate
import
ReuseSessionCreator
from
..tfutils.sessinit
import
JustCurrentSession
from
..tfutils.tower
import
TowerFuncWrapper
from
..graph_builder.predictor_factory
import
PredictorFactory
from
..input_source
import
PlaceholderInput
from
..graph_builder.predictor_factory
import
SimplePredictBuilder
from
..predict.base
import
OnlinePredictor
from
..callbacks.steps
import
MaintainStepCounter
__all__
=
[
'Trainer'
,
'StopTraining'
]
...
...
@@ -117,6 +120,16 @@ class Trainer(object):
assert
isinstance
(
config
,
TrainConfig
),
type
(
config
)
self
.
_config
=
config
self
.
model
=
config
.
model
if
self
.
model
is
not
None
:
def
f
(
*
inputs
):
self
.
model
.
build_graph
(
inputs
)
"""
Only to mimic new trainer interface on inference.
"""
self
.
inputs_desc
=
self
.
model
.
get_inputs_desc
()
self
.
tower_func
=
TowerFuncWrapper
(
f
,
self
.
inputs_desc
)
self
.
_callbacks
=
[]
self
.
_monitors
=
[]
...
...
@@ -268,8 +281,7 @@ class Trainer(object):
def
get_predictor
(
self
,
input_names
,
output_names
,
tower
=
0
):
"""
Returns a callable predictor built under ``is_training=False`` tower context.
Note that this method is only valid when this trainer has a ``ModelDesc``.
Returns a callable predictor built under ``TowerContext(is_training=False)``.
Args:
input_names (list), output_names(list): list of names
...
...
@@ -278,19 +290,27 @@ class Trainer(object):
Returns:
an :class:`OnlinePredictor`.
"""
return
self
.
predictor_factory
.
get_predictor
(
input_names
,
output_names
,
tower
)
device
=
tower
assert
self
.
tower_func
is
not
None
,
"Must set tower_func on the trainer to use get_predictor()!"
tower_name
=
'tower-pred-{}'
.
format
(
device
)
if
device
>=
0
else
'tower-pred-cpu'
@
property
def
predictor_factory
(
self
):
assert
self
.
model
is
not
None
,
\
"Predictor can only be built one Trainer has ModelDesc!"
if
not
hasattr
(
self
,
'_predictor_factory'
):
self
.
_predictor_factory
=
PredictorFactory
(
self
.
model
,
self
.
vs_name_for_predictor
)
return
self
.
_predictor_factory
try
:
tower
=
self
.
tower_func
.
towers
[
tower_name
]
except
KeyError
:
input
=
PlaceholderInput
()
input
.
setup
(
self
.
inputs_desc
)
with
tf
.
variable_scope
(
tf
.
get_variable_scope
(),
reuse
=
True
):
SimplePredictBuilder
(
ns_name
=
tower_name
,
vs_name
=
self
.
_main_tower_vs_name
,
device
=
device
)
.
build
(
input
,
self
.
tower_func
)
tower
=
self
.
tower_func
.
towers
[
tower_name
]
input_tensors
=
tower
.
get_tensors
(
input_names
)
output_tensors
=
tower
.
get_tensors
(
output_names
)
return
OnlinePredictor
(
input_tensors
,
output_tensors
)
@
property
def
vs_name_for_predictor
(
self
):
def
_main_tower_vs_name
(
self
):
# The vs name a predictor should be built under.
# for internal use only. Should let graphbuilder return it.
return
""
...
...
tensorpack/train/distributed.py
View file @
9a711e72
...
...
@@ -95,5 +95,5 @@ class DistributedTrainerReplicated(Trainer):
self
.
_config
.
session_creator
=
get_distributed_session_creator
(
self
.
server
)
@
property
def
vs_name_for_predictor
(
self
):
def
_main_tower_vs_name
(
self
):
return
"tower0"
tensorpack/trainv2/base.py
View file @
9a711e72
...
...
@@ -268,6 +268,7 @@ class TowerTrainer(Trainer):
input
=
PlaceholderInput
()
input
.
setup
(
self
.
inputs_desc
)
with
tf
.
variable_scope
(
tf
.
get_variable_scope
(),
reuse
=
True
):
SimplePredictBuilder
(
ns_name
=
tower_name
,
vs_name
=
self
.
_main_tower_vs_name
,
device
=
device
)
.
build
(
input
,
self
.
tower_func
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment