Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
93a177bf
Commit
93a177bf
authored
Oct 20, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Trainerv2] Add get_predictor support
parent
e0b13533
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
69 additions
and
20 deletions
+69
-20
tensorpack/callbacks/inference_runner.py
tensorpack/callbacks/inference_runner.py
+4
-0
tensorpack/graph_builder/predictor_factory.py
tensorpack/graph_builder/predictor_factory.py
+8
-0
tensorpack/predict/base.py
tensorpack/predict/base.py
+3
-2
tensorpack/tfutils/tower.py
tensorpack/tfutils/tower.py
+18
-15
tensorpack/trainv2/base.py
tensorpack/trainv2/base.py
+36
-3
No files found.
tensorpack/callbacks/inference_runner.py
View file @
93a177bf
...
...
@@ -143,6 +143,7 @@ class InferenceRunner(InferenceRunnerBase):
# new Trainer API
from
..trainv2
import
TowerTrainer
assert
isinstance
(
self
.
trainer
,
TowerTrainer
),
self
.
trainer
assert
self
.
trainer
.
tower_func
is
not
None
,
"You must set tower_func of the trainer to use InferenceRunner!"
input_callbacks
=
self
.
_input_source
.
setup
(
self
.
trainer
.
inputs_desc
)
with
tf
.
variable_scope
(
tf
.
get_variable_scope
(),
reuse
=
True
):
...
...
@@ -214,6 +215,9 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
tower_name
,
device
,
self
.
_input_source
))
else
:
# new Trainer API
from
..trainv2
import
TowerTrainer
assert
isinstance
(
self
.
trainer
,
TowerTrainer
),
self
.
trainer
assert
self
.
trainer
.
tower_func
is
not
None
,
"You must set tower_func of the trainer to use InferenceRunner!"
input_callbacks
=
self
.
_input_source
.
setup
(
self
.
trainer
.
inputs_desc
)
with
tf
.
variable_scope
(
tf
.
get_variable_scope
(),
reuse
=
True
):
for
idx
,
t
in
enumerate
(
self
.
_gpus
):
...
...
tensorpack/graph_builder/predictor_factory.py
View file @
93a177bf
...
...
@@ -42,6 +42,14 @@ class SimplePredictBuilder(GraphBuilder):
yield
def
build
(
self
,
input
,
tower_fn
):
"""
Args:
input (InputSource): must have been setup
tower_fn ( [tf.Tensors] ->): callable that takes input tensors.
Returns:
The return value of tower_fn called under the proper context.
"""
assert
input
.
setup_done
()
logger
.
info
(
"Building predictor tower '{}' on device {} ..."
.
format
(
self
.
_ns_name
,
self
.
_device
))
...
...
tensorpack/predict/base.py
View file @
93a177bf
...
...
@@ -115,9 +115,10 @@ class OnlinePredictor(PredictorBase):
fetches
=
output_tensors
,
feed_list
=
input_tensors
)
else
:
log_once
(
"TF>=1.2 is recommended for better performance of predictor!"
,
'warn'
)
self
.
_callable
=
None
else
:
log_once
(
"TF>=1.2 is recommended for better performance of predictor!"
,
'warn'
)
def
_do_call_old
(
self
,
dp
):
feed
=
dict
(
zip
(
self
.
input_tensors
,
dp
))
...
...
tensorpack/tfutils/tower.py
View file @
93a177bf
...
...
@@ -166,7 +166,7 @@ class TowerFuncWrapper(object):
self
.
_tower_fn
=
tower_fn
self
.
_inputs_desc
=
inputs_desc
self
.
_
tower
s
=
[]
self
.
_
handle
s
=
[]
def
__new__
(
cls
,
tower_fn
,
inputs_desc
):
# to avoid double-wrapping a function
...
...
@@ -180,19 +180,33 @@ class TowerFuncWrapper(object):
assert
ctx
is
not
None
,
"Function must be called under TowerContext!"
output
=
self
.
_tower_fn
(
*
args
)
handle
=
TowerTensorHandle
(
ctx
,
args
,
output
,
self
.
_inputs_desc
)
self
.
_
tower
s
.
append
(
handle
)
self
.
_
handle
s
.
append
(
handle
)
return
output
@
property
def
towers
(
self
):
# TODO another wrapper around towerhandlelist
return
self
.
_towers
return
TowerTensorHandles
(
self
.
_handles
)
@
property
def
inputs_desc
(
self
):
return
self
.
_inputs_desc
class
TowerTensorHandles
(
object
):
"""
Wrap a list of :class:`TowerTensorHandle`,
to support access to them by index or names.
"""
def
__init__
(
self
,
handles
):
self
.
_handles
=
handles
self
.
_name_to_handle
=
{
k
.
ns_name
:
k
for
k
in
handles
}
def
__getitem__
(
self
,
name_or_index
):
if
isinstance
(
name_or_index
,
int
):
return
self
.
_handles
[
name_or_index
]
return
self
.
_name_to_handle
[
name_or_index
]
class
TowerTensorHandle
(
object
):
"""
When a function is called multiple times under each tower,
...
...
@@ -281,14 +295,3 @@ class TowerTensorHandle(object):
The output returned by the tower function.
"""
return
self
.
_output
# should move to somewhere else.
# def get_predictor(self, input_names, output_names):
# """
# Get a predictor with tensors inside this tower.
# """
# input_tensors = self.get_tensors(input_names)
# output_tensors = self.get_tensors(output_names)
# # TODO sort out the import order
# from ..predict.base import OnlinePredictor # noqa
# return OnlinePredictor(input_tensors, output_tensors)
tensorpack/trainv2/base.py
View file @
93a177bf
...
...
@@ -11,7 +11,6 @@ from abc import abstractmethod, ABCMeta
from
..utils
import
logger
from
..utils.argtools
import
call_only_once
,
memoized
from
..input_source
import
FeedfreeInput
from
..callbacks
import
Callback
,
Callbacks
from
..callbacks.monitor
import
Monitors
,
TrainingMonitor
from
..tfutils.model_utils
import
describe_trainable_vars
...
...
@@ -21,10 +20,14 @@ from ..tfutils.tower import TowerFuncWrapper, get_current_tower_context
from
..tfutils.gradproc
import
FilterNoneGrad
from
..callbacks.steps
import
MaintainStepCounter
from
..graph_builder.predictor_factory
import
SimplePredictBuilder
from
..input_source
import
FeedfreeInput
,
PlaceholderInput
from
..predict.base
import
OnlinePredictor
import
tensorpack.train
as
old_train
# noqa
from
..train.base
import
StopTraining
,
TrainLoop
__all__
=
[
'Trainer'
,
'SingleCostTrainer'
]
__all__
=
[
'Trainer'
,
'SingleCostTrainer'
,
'TowerTrainer'
]
class
Trainer
(
object
):
...
...
@@ -190,7 +193,8 @@ class Trainer(object):
# create the old trainer when called with TrainConfig
def
__new__
(
cls
,
*
args
,
**
kwargs
):
if
isinstance
(
args
[
0
],
old_train
.
TrainConfig
)
or
'config'
in
kwargs
:
if
(
len
(
args
)
>
0
and
isinstance
(
args
[
0
],
old_train
.
TrainConfig
))
\
or
'config'
in
kwargs
:
name
=
cls
.
__name__
old_trainer
=
getattr
(
old_train
,
name
)
return
old_trainer
(
*
args
,
**
kwargs
)
...
...
@@ -237,6 +241,7 @@ class TowerTrainer(Trainer):
Args:
tower_func (TowerFuncWrapper)
"""
assert
isinstance
(
tower_func
,
TowerFuncWrapper
),
tower_func
self
.
tower_func
=
tower_func
@
property
...
...
@@ -247,6 +252,34 @@ class TowerTrainer(Trainer):
"""
return
self
.
tower_func
.
inputs_desc
def
get_predictor
(
self
,
input_names
,
output_names
,
device
=
0
):
"""
Returns a callable predictor built under ``TowerContext(is_training=False)``.
Args:
input_names (list), output_names(list): list of names
device (int): build the predictor on device '/gpu:{device}' or use -1 for '/cpu:0'.
Returns:
an :class:`OnlinePredictor`.
"""
assert
self
.
tower_func
is
not
None
,
"Must set tower_func on the trainer to use get_predictor()!"
tower_name
=
'tower-pred-{}'
.
format
(
device
)
if
device
>=
0
else
'tower-pred-cpu'
try
:
tower
=
self
.
tower_func
.
towers
[
tower_name
]
except
KeyError
:
input
=
PlaceholderInput
()
input
.
setup
(
self
.
inputs_desc
)
SimplePredictBuilder
(
ns_name
=
tower_name
,
vs_name
=
''
,
device
=
device
)
.
build
(
input
,
self
.
tower_func
)
tower
=
self
.
tower_func
.
towers
[
tower_name
]
input_tensors
=
tower
.
get_tensors
(
input_names
)
output_tensors
=
tower
.
get_tensors
(
output_names
)
return
OnlinePredictor
(
input_tensors
,
output_tensors
)
@
six
.
add_metaclass
(
ABCMeta
)
class
SingleCostTrainer
(
TowerTrainer
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment