Project: seminar-breakout (Shashank Suhas)

Commit da98e447
[Trainerv2] Let InferenceRunner run with new Trainer
Authored Oct 18, 2017 by Yuxin Wu
Parent commit: e5ff50e7
Showing 4 changed files with 103 additions and 23 deletions:

- tensorpack/callbacks/inference_runner.py (+49, -21)
- tensorpack/graph_builder/predictor_factory.py (+47, -1)
- tensorpack/train/config.py (+2, -1)
- tensorpack/trainv2/base.py (+5, -0)
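In short, `InferenceRunner` and `DataParallelInferenceRunner` now detect which Trainer API they are running under: the old `Trainer` carries a `model` attribute, while the new v2 trainer instead publishes `inputs_desc` and a wrapped `get_cost_fn`. Below is a minimal, self-contained sketch of that dispatch; `OldTrainer`, `NewTrainer`, and `setup_inference` are hypothetical stand-ins, not tensorpack classes.

```python
# Hypothetical stand-ins illustrating the attribute-based dispatch;
# these are not tensorpack classes.
class OldTrainer:
    def __init__(self):
        self.model = object()                 # old API: trainer carries a model

class NewTrainer:
    def __init__(self):
        self.inputs_desc = []                 # new API: trainer publishes inputs_desc
        self.get_cost_fn = lambda *a: None    # ... and a wrapped tower function

def setup_inference(trainer):
    # Mirrors the hasattr() check added in inference_runner.py.
    if hasattr(trainer, 'model'):
        return 'old Trainer API'              # build via trainer.predictor_factory
    return 'new Trainer API'                  # build via SimplePredictBuilder

assert setup_inference(OldTrainer()) == 'old Trainer API'
assert setup_inference(NewTrainer()) == 'new Trainer API'
```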
tensorpack/callbacks/inference_runner.py

```diff
@@ -19,6 +19,8 @@ from ..dataflow.base import DataFlow
 from ..input_source import (
     InputSource, FeedInput, QueueInput)
+from ..graph_builder.predictor_factory import SimplePredictBuilder
+# from ..trainv2 import SingleCostTrainer
 from .base import Callback
 from .group import Callbacks
@@ -121,16 +123,30 @@ class InferenceRunner(InferenceRunnerBase):
         return InferencerToHook(inf, fetches)

     def _setup_graph(self):
-        assert self.trainer.model is not None
-        # Use predict_tower in train config. either gpuid or -1
-        tower_id = self.trainer._config.predict_tower[0]
-        device = '/gpu:{}'.format(tower_id) if tower_id >= 0 else '/cpu:0'
-
-        input_callbacks = self._input_source.setup(self.trainer.model.get_inputs_desc())
-        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
-            self._tower_handle = self.trainer.predictor_factory.build(
-                self._tower_name, device, self._input_source)
+        if hasattr(self.trainer, 'model'):
+            # old Trainer API
+            assert self.trainer.model is not None
+            # Use predict_tower in train config. either gpuid or -1
+            tower_id = self.trainer._config.predict_tower[0]
+            device = '/gpu:{}'.format(tower_id) if tower_id >= 0 else '/cpu:0'
+
+            input_callbacks = self._input_source.setup(self.trainer.model.get_inputs_desc())
+            with tf.variable_scope(tf.get_variable_scope(), reuse=True):
+                self._tower_handle = self.trainer.predictor_factory.build(
+                    self._tower_name, device, self._input_source)
+        else:
+            # new Trainer API
+            # only works for singlecost trainer
+            # assert isinstance(self.trainer, SingleCostTrainer), self.trainer
+            input_callbacks = self._input_source.setup(self.trainer.inputs_desc)
+            with tf.variable_scope(tf.get_variable_scope(), reuse=True):
+                SimplePredictBuilder(
+                    ns_name=self._tower_name,
+                    vs_name='', device=0).build(    # TODO fix vs_name and maybe device
+                        self._input_source, self.trainer.get_cost_fn)
+                self._tower_handle = self.trainer.get_cost_fn.towers[-1]

         self._hooks = [self._build_hook(inf) for inf in self.infs]
         # trigger_{step,epoch}, {before,after}_epoch is ignored.
@@ -180,20 +196,32 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
         self._gpus = gpus

     def _setup_graph(self):
-        assert self.trainer.model is not None
-        cbs = self._input_source.setup(self.trainer.model.get_inputs_desc())
-        # build each predict tower
         self._handles = []
-        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
-            for idx, t in enumerate(self._gpus):
-                tower_name = self._tower_names[idx]
-                device = '/gpu:{}'.format(t)
-                self._handles.append(
-                    self.trainer.predictor_factory.build(
-                        tower_name, device, self._input_source))
+        if hasattr(self.trainer, 'model'):
+            # old Trainer API
+            input_callbacks = self._input_source.setup(self.trainer.model.get_inputs_desc())
+            # build each predict tower
+            with tf.variable_scope(tf.get_variable_scope(), reuse=True):
+                for idx, t in enumerate(self._gpus):
+                    tower_name = self._tower_names[idx]
+                    device = '/gpu:{}'.format(t)
+                    self._handles.append(
+                        self.trainer.predictor_factory.build(
+                            tower_name, device, self._input_source))
+        else:
+            # new Trainer API
+            input_callbacks = self._input_source.setup(self.trainer.inputs_desc)
+            with tf.variable_scope(tf.get_variable_scope(), reuse=True):
+                for idx, t in enumerate(self._gpus):
+                    tower_name = self._tower_names[idx]
+                    SimplePredictBuilder(
+                        ns_name=tower_name,
+                        vs_name='', device=t).build(    # TODO fix vs_name and maybe device
+                            self._input_source, self.trainer.get_cost_fn)
+                    self._handles.append(self.trainer.get_cost_fn.towers[-1])

         # setup callbacks and hooks
-        self._input_callbacks = Callbacks(cbs)
+        self._input_callbacks = Callbacks(input_callbacks)
         # InputSource might have hooks which break us.
         # e.g. hooks from StagingInputWrapper will force the consumption
```
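In the new-API branch, `SimplePredictBuilder(...).build(...)` invokes the trainer's wrapped cost function once more under an inference `TowerContext`, and the wrapper records each tower it builds; that is why `self.trainer.get_cost_fn.towers[-1]` is the handle of the predict tower that was just created. Below is a toy sketch of that recording behavior; `RecordingTowerFunc` is a hypothetical stand-in for `TowerFuncWrapper`, which records richer tower handles rather than return values.

```python
# Toy stand-in for TowerFuncWrapper: a callable that remembers every
# "tower" it has built, so callers can grab the most recent one.
class RecordingTowerFunc:
    def __init__(self, fn):
        self._fn = fn
        self.towers = []               # one entry appended per call / tower built

    def __call__(self, *inputs):
        out = self._fn(*inputs)
        self.towers.append(out)        # record the tower just built
        return out

cost_fn = RecordingTowerFunc(lambda x: ('tower-output', x))
cost_fn(1)                             # the builder invokes the tower function ...
handle = cost_fn.towers[-1]            # ... then the runner grabs the newest tower
assert handle == ('tower-output', 1)
```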
tensorpack/graph_builder/predictor_factory.py

```diff
@@ -3,13 +3,59 @@
 # File: predictor_factory.py

 import tensorflow as tf
+from contextlib import contextmanager

 from ..utils import logger
 from ..tfutils.tower import TowerContext, TowerFuncWrapper
 from ..tfutils.collection import freeze_collection
 from ..utils.naming import TOWER_FREEZE_KEYS
 from ..input_source import PlaceholderInput
+from .training import GraphBuilder

-__all__ = []
+__all__ = ['SimplePredictBuilder']
+
+
+class SimplePredictBuilder(GraphBuilder):
+    """
+    Single-tower predictor.
+    """
+    def __init__(self, ns_name='', vs_name='', device=0):
+        """
+        Args:
+            ns_name (str):
+            vs_name (str):
+            device (int):
+        """
+        # TODO does vs_name work properly here when different from ns_name?
+        self._ns_name = ns_name
+        self._vs_name = vs_name
+
+        device = '/gpu:{}'.format(device) if device >= 0 else '/cpu:0'
+        self._device = device
+
+    @contextmanager
+    def _maybe_open_vs(self):
+        if len(self._vs_name):
+            with tf.variable_scope(self._vs_name):
+                yield
+        else:
+            yield
+
+    def build(self, input, tower_fn):
+        assert input.setup_done()
+        logger.info("Building predictor tower '{}' on device {} ...".format(
+            self._ns_name, self._device))
+
+        with tf.device(self._device), \
+                self._maybe_open_vs(), \
+                TowerContext(self._ns_name, is_training=False), \
+                freeze_collection(TOWER_FREEZE_KEYS + [tf.GraphKeys.UPDATE_OPS]):
+            # also freeze UPDATE_OPS in inference, because they should never be used
+            # TODO a better way to log and warn about collection change during build_graph.
+            inputs = input.get_input_tensors()
+            assert isinstance(inputs, (list, tuple)), inputs
+            return tower_fn(*inputs)
+
+
 class PredictorFactory(object):
```
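`_maybe_open_vs` is a small conditional-context idiom: open a `tf.variable_scope` only when `vs_name` is non-empty, otherwise do nothing. Here is a runnable sketch of the same idiom without TensorFlow, where `open_scope` is a hypothetical stand-in for `tf.variable_scope`:

```python
from contextlib import contextmanager

@contextmanager
def open_scope(name, log):
    # Stand-in for tf.variable_scope: records enter/exit for demonstration.
    log.append('enter ' + name)
    yield
    log.append('exit ' + name)

@contextmanager
def maybe_open_scope(name, log):
    if len(name):                      # only open a scope for a non-empty name
        with open_scope(name, log):
            yield
    else:
        yield                          # otherwise, a no-op context

log = []
with maybe_open_scope('', log):
    pass
assert log == []                       # empty name: nothing opened
with maybe_open_scope('pred', log):
    pass
assert log == ['enter pred', 'exit pred']
```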
tensorpack/train/config.py

```diff
@@ -27,7 +27,7 @@ class TrainConfig(object):
                  callbacks=None, extra_callbacks=None, monitors=None,
                  session_creator=None, session_config=None, session_init=None,
                  starting_epoch=1, steps_per_epoch=None, max_epoch=99999,
-                 nr_tower=1, tower=None, predict_tower=None,
+                 nr_tower=1, tower=None,
                  **kwargs):
         """
         Note:
@@ -127,6 +127,7 @@ class TrainConfig(object):
             assert self.nr_tower == 1, "Cannot set both nr_tower and tower in TrainConfig!"
             self.tower = tower

+        predict_tower = kwargs.pop('predict_tower', None)
         if predict_tower is None:
             predict_tower = [0]
         self.predict_tower = predict_tower
```
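The config change moves `predict_tower` out of the explicit signature into `**kwargs` while keeping the old default, so existing callers keep working. A toy sketch of the `kwargs.pop` idiom; `make_config` is a stand-in, not `TrainConfig.__init__`:

```python
def make_config(nr_tower=1, tower=None, **kwargs):
    # Accept the removed keyword through **kwargs, with the old default.
    predict_tower = kwargs.pop('predict_tower', None)
    if predict_tower is None:
        predict_tower = [0]            # same default the old signature had
    assert not kwargs, 'unknown arguments: {}'.format(kwargs)  # toy strictness
    return predict_tower

assert make_config() == [0]                          # default preserved
assert make_config(predict_tower=[1, 2]) == [1, 2]   # old callers still work
```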
tensorpack/trainv2/base.py

```diff
@@ -16,6 +16,7 @@ from ..callbacks.monitor import Monitors, TrainingMonitor
 from ..tfutils.model_utils import describe_trainable_vars
 from ..tfutils.sessinit import JustCurrentSession
 from ..tfutils.sesscreate import ReuseSessionCreator
+from ..tfutils.tower import TowerFuncWrapper
 from ..callbacks.steps import MaintainStepCounter
 from ..train.base import StopTraining, TrainLoop
@@ -240,9 +241,13 @@ class SingleCostTrainer(Trainer):
         These callbacks will be automatically added when you call `train()`.
         So you can usually ignore the return value.
         """
+        get_cost_fn = TowerFuncWrapper(get_cost_fn, inputs_desc)
+
         input_callbacks = self._setup_input(inputs_desc, input)
         train_callbacks = self._setup_graph(input, get_cost_fn, get_opt_fn)
         self._internal_callbacks = input_callbacks + train_callbacks
+        self.inputs_desc = inputs_desc
+        self.get_cost_fn = get_cost_fn
         return self._internal_callbacks

     @abstractmethod
```
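Taken together, `setup_graph` now wraps `get_cost_fn` in `TowerFuncWrapper` and publishes both `inputs_desc` and the wrapper as trainer attributes, which is exactly what the new-API branch of `InferenceRunner` consumes. A toy end-to-end sketch of that handshake, with all names as stand-ins for the tensorpack classes:

```python
class ToyTrainer:
    def setup_graph(self, inputs_desc, cost_fn):
        self.inputs_desc = inputs_desc   # published for callbacks
        self.get_cost_fn = cost_fn       # wrapped tower function, also published

def toy_inference_setup(trainer):
    # Mirrors the new-API branch: no `model` attribute, so read the
    # published inputs_desc and tower function instead.
    assert not hasattr(trainer, 'model')
    return trainer.inputs_desc, trainer.get_cost_fn

t = ToyTrainer()
t.setup_graph(['image', 'label'], lambda *inputs: 0)
desc, fn = toy_inference_setup(t)
assert desc == ['image', 'label']
```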