Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
e0b13533
Commit
e0b13533
authored
Oct 19, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Trainerv2] Add TowerTrainer on top of singlecost
parent
0f90d4c2
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
52 additions
and
15 deletions
+52
-15
tensorpack/callbacks/inference_runner.py
tensorpack/callbacks/inference_runner.py
+8
-9
tensorpack/tfutils/tower.py
tensorpack/tfutils/tower.py
+6
-2
tensorpack/train/base.py
tensorpack/train/base.py
+2
-0
tensorpack/trainv2/base.py
tensorpack/trainv2/base.py
+36
-4
No files found.
tensorpack/callbacks/inference_runner.py
View file @
e0b13533
...
@@ -20,7 +20,6 @@ from ..dataflow.base import DataFlow
...
@@ -20,7 +20,6 @@ from ..dataflow.base import DataFlow
from
..input_source
import
(
from
..input_source
import
(
InputSource
,
FeedInput
,
QueueInput
)
InputSource
,
FeedInput
,
QueueInput
)
from
..graph_builder.predictor_factory
import
SimplePredictBuilder
from
..graph_builder.predictor_factory
import
SimplePredictBuilder
# from ..trainv2 import SingleCostTrainer
from
.base
import
Callback
from
.base
import
Callback
from
.group
import
Callbacks
from
.group
import
Callbacks
...
@@ -125,7 +124,7 @@ class InferenceRunner(InferenceRunnerBase):
...
@@ -125,7 +124,7 @@ class InferenceRunner(InferenceRunnerBase):
return
InferencerToHook
(
inf
,
fetches
)
return
InferencerToHook
(
inf
,
fetches
)
def
_setup_graph
(
self
):
def
_setup_graph
(
self
):
if
hasattr
(
self
.
trainer
,
'model'
)
:
if
self
.
trainer
.
_API_VERSION
==
1
:
# old Trainer API
# old Trainer API
assert
self
.
trainer
.
model
is
not
None
assert
self
.
trainer
.
model
is
not
None
# Use predict_tower in train config. either gpuid or -1
# Use predict_tower in train config. either gpuid or -1
...
@@ -142,16 +141,16 @@ class InferenceRunner(InferenceRunnerBase):
...
@@ -142,16 +141,16 @@ class InferenceRunner(InferenceRunnerBase):
self
.
_tower_name
,
device
,
self
.
_input_source
)
self
.
_tower_name
,
device
,
self
.
_input_source
)
else
:
else
:
# new Trainer API
# new Trainer API
# only works for singlecost t
rainer
from
..trainv2
import
TowerT
rainer
# assert isinstance(self.trainer, SingleCost
Trainer), self.trainer
assert
isinstance
(
self
.
trainer
,
Tower
Trainer
),
self
.
trainer
input_callbacks
=
self
.
_input_source
.
setup
(
self
.
trainer
.
inputs_desc
)
input_callbacks
=
self
.
_input_source
.
setup
(
self
.
trainer
.
inputs_desc
)
with
tf
.
variable_scope
(
tf
.
get_variable_scope
(),
reuse
=
True
):
with
tf
.
variable_scope
(
tf
.
get_variable_scope
(),
reuse
=
True
):
SimplePredictBuilder
(
SimplePredictBuilder
(
ns_name
=
self
.
_tower_name
,
ns_name
=
self
.
_tower_name
,
vs_name
=
''
,
device
=
0
)
.
build
(
# TODO fix vs_name and maybe device
vs_name
=
''
,
device
=
0
)
.
build
(
# TODO fix vs_name and maybe device
self
.
_input_source
,
self
.
trainer
.
get_cost_fn
)
self
.
_input_source
,
self
.
trainer
.
tower_func
)
self
.
_tower_handle
=
self
.
trainer
.
get_cost_fn
.
towers
[
-
1
]
self
.
_tower_handle
=
self
.
trainer
.
tower_func
.
towers
[
-
1
]
self
.
_hooks
=
[
self
.
_build_hook
(
inf
)
for
inf
in
self
.
infs
]
self
.
_hooks
=
[
self
.
_build_hook
(
inf
)
for
inf
in
self
.
infs
]
# trigger_{step,epoch}, {before,after}_epoch is ignored.
# trigger_{step,epoch}, {before,after}_epoch is ignored.
...
@@ -202,7 +201,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
...
@@ -202,7 +201,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
def
_setup_graph
(
self
):
def
_setup_graph
(
self
):
self
.
_handles
=
[]
self
.
_handles
=
[]
if
hasattr
(
self
.
trainer
,
'model'
)
:
if
self
.
trainer
.
_API_VERSION
==
1
:
# old Trainer API
# old Trainer API
input_callbacks
=
self
.
_input_source
.
setup
(
self
.
trainer
.
model
.
get_inputs_desc
())
input_callbacks
=
self
.
_input_source
.
setup
(
self
.
trainer
.
model
.
get_inputs_desc
())
# build each predict tower
# build each predict tower
...
@@ -222,8 +221,8 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
...
@@ -222,8 +221,8 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
SimplePredictBuilder
(
SimplePredictBuilder
(
ns_name
=
tower_name
,
ns_name
=
tower_name
,
vs_name
=
''
,
device
=
t
)
.
build
(
# TODO fix vs_name and maybe device
vs_name
=
''
,
device
=
t
)
.
build
(
# TODO fix vs_name and maybe device
self
.
_input_source
,
self
.
trainer
.
get_cost_fn
)
self
.
_input_source
,
self
.
trainer
.
tower_func
)
self
.
_handles
.
append
(
self
.
trainer
.
get_cost_fn
.
towers
[
-
1
])
self
.
_handles
.
append
(
self
.
trainer
.
tower_func
.
towers
[
-
1
])
# setup callbacks and hooks
# setup callbacks and hooks
self
.
_input_callbacks
=
Callbacks
(
input_callbacks
)
self
.
_input_callbacks
=
Callbacks
(
input_callbacks
)
...
...
tensorpack/tfutils/tower.py
View file @
e0b13533
...
@@ -154,7 +154,7 @@ class TowerFuncWrapper(object):
...
@@ -154,7 +154,7 @@ class TowerFuncWrapper(object):
each time the function is called.
each time the function is called.
"""
"""
def
__init__
(
self
,
tower_fn
,
inputs_desc
=
None
):
def
__init__
(
self
,
tower_fn
,
inputs_desc
):
"""
"""
Args:
Args:
tower_func: a function which builds one tower in the graph.
tower_func: a function which builds one tower in the graph.
...
@@ -168,7 +168,7 @@ class TowerFuncWrapper(object):
...
@@ -168,7 +168,7 @@ class TowerFuncWrapper(object):
self
.
_towers
=
[]
self
.
_towers
=
[]
def
__new__
(
cls
,
tower_fn
,
inputs_desc
=
None
):
def
__new__
(
cls
,
tower_fn
,
inputs_desc
):
# to avoid double-wrapping a function
# to avoid double-wrapping a function
if
isinstance
(
tower_fn
,
TowerFuncWrapper
):
if
isinstance
(
tower_fn
,
TowerFuncWrapper
):
return
tower_fn
return
tower_fn
...
@@ -188,6 +188,10 @@ class TowerFuncWrapper(object):
...
@@ -188,6 +188,10 @@ class TowerFuncWrapper(object):
# TODO another wrapper around towerhandlelist
# TODO another wrapper around towerhandlelist
return
self
.
_towers
return
self
.
_towers
@
property
def
inputs_desc
(
self
):
return
self
.
_inputs_desc
class
TowerTensorHandle
(
object
):
class
TowerTensorHandle
(
object
):
"""
"""
...
...
tensorpack/train/base.py
View file @
e0b13533
...
@@ -101,6 +101,8 @@ class Trainer(object):
...
@@ -101,6 +101,8 @@ class Trainer(object):
monitors (Monitors): the monitors. Other callbacks can use it for logging.
monitors (Monitors): the monitors. Other callbacks can use it for logging.
"""
"""
_API_VERSION
=
1
is_chief
=
True
is_chief
=
True
"""
"""
Whether this process is the chief worker in distributed training.
Whether this process is the chief worker in distributed training.
...
...
tensorpack/trainv2/base.py
View file @
e0b13533
...
@@ -31,6 +31,8 @@ class Trainer(object):
...
@@ -31,6 +31,8 @@ class Trainer(object):
""" Base class for a trainer.
""" Base class for a trainer.
"""
"""
_API_VERSION
=
2
is_chief
=
True
is_chief
=
True
def
__init__
(
self
):
def
__init__
(
self
):
...
@@ -215,8 +217,39 @@ for name in ['global_step', 'local_step', 'steps_per_epoch',
...
@@ -215,8 +217,39 @@ for name in ['global_step', 'local_step', 'steps_per_epoch',
setattr
(
Trainer
,
name
,
_get_property
(
name
))
setattr
(
Trainer
,
name
,
_get_property
(
name
))
class
TowerTrainer
(
Trainer
):
"""
Base trainers for models that can be built by calling a tower function under a :class:`TowerContext`.
This is required by some features that replicates the model
automatically, e.g. creating a predictor.
"""
tower_func
=
None
"""
A :class:`TowerFuncWrapper` instance.
A callable which takes some input tensors and builds one replicate of the model.
"""
@
call_only_once
def
set_tower_func
(
self
,
tower_func
):
"""
Args:
tower_func (TowerFuncWrapper)
"""
self
.
tower_func
=
tower_func
@
property
def
inputs_desc
(
self
):
"""
Returns:
list[InputDesc]: metainfo about the inputs to the tower.
"""
return
self
.
tower_func
.
inputs_desc
@
six
.
add_metaclass
(
ABCMeta
)
@
six
.
add_metaclass
(
ABCMeta
)
class
SingleCostTrainer
(
Trainer
):
class
SingleCostTrainer
(
T
owerT
rainer
):
"""
"""
Base class for single-cost trainer.
Base class for single-cost trainer.
...
@@ -261,12 +294,11 @@ class SingleCostTrainer(Trainer):
...
@@ -261,12 +294,11 @@ class SingleCostTrainer(Trainer):
"""
"""
get_cost_fn
=
TowerFuncWrapper
(
get_cost_fn
,
inputs_desc
)
get_cost_fn
=
TowerFuncWrapper
(
get_cost_fn
,
inputs_desc
)
get_opt_fn
=
memoized
(
get_opt_fn
)
get_opt_fn
=
memoized
(
get_opt_fn
)
self
.
set_tower_func
(
get_cost_fn
)
input_callbacks
=
self
.
_setup_input
(
inputs_desc
,
input
)
input_callbacks
=
self
.
_setup_input
(
inputs_desc
,
input
)
train_callbacks
=
self
.
_setup_graph
(
input
,
get_cost_fn
,
get_opt_fn
)
train_callbacks
=
self
.
_setup_graph
(
input
,
get_cost_fn
,
get_opt_fn
)
self
.
_internal_callbacks
=
input_callbacks
+
train_callbacks
self
.
_internal_callbacks
=
input_callbacks
+
train_callbacks
self
.
inputs_desc
=
inputs_desc
self
.
get_cost_fn
=
get_cost_fn
return
self
.
_internal_callbacks
return
self
.
_internal_callbacks
@
abstractmethod
@
abstractmethod
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment