Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
e0b13533
You need to sign in or sign up before continuing.
Commit
e0b13533
authored
Oct 19, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
[Trainerv2] Add TowerTrainer on top of singlecost
parent
0f90d4c2
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
52 additions
and
15 deletions
+52
-15
tensorpack/callbacks/inference_runner.py
tensorpack/callbacks/inference_runner.py
+8
-9
tensorpack/tfutils/tower.py
tensorpack/tfutils/tower.py
+6
-2
tensorpack/train/base.py
tensorpack/train/base.py
+2
-0
tensorpack/trainv2/base.py
tensorpack/trainv2/base.py
+36
-4
No files found.
tensorpack/callbacks/inference_runner.py
View file @
e0b13533
...
@@ -20,7 +20,6 @@ from ..dataflow.base import DataFlow
...
@@ -20,7 +20,6 @@ from ..dataflow.base import DataFlow
from
..input_source
import
(
from
..input_source
import
(
InputSource
,
FeedInput
,
QueueInput
)
InputSource
,
FeedInput
,
QueueInput
)
from
..graph_builder.predictor_factory
import
SimplePredictBuilder
from
..graph_builder.predictor_factory
import
SimplePredictBuilder
# from ..trainv2 import SingleCostTrainer
from
.base
import
Callback
from
.base
import
Callback
from
.group
import
Callbacks
from
.group
import
Callbacks
...
@@ -125,7 +124,7 @@ class InferenceRunner(InferenceRunnerBase):
...
@@ -125,7 +124,7 @@ class InferenceRunner(InferenceRunnerBase):
return
InferencerToHook
(
inf
,
fetches
)
return
InferencerToHook
(
inf
,
fetches
)
def
_setup_graph
(
self
):
def
_setup_graph
(
self
):
if
hasattr
(
self
.
trainer
,
'model'
)
:
if
self
.
trainer
.
_API_VERSION
==
1
:
# old Trainer API
# old Trainer API
assert
self
.
trainer
.
model
is
not
None
assert
self
.
trainer
.
model
is
not
None
# Use predict_tower in train config. either gpuid or -1
# Use predict_tower in train config. either gpuid or -1
...
@@ -142,16 +141,16 @@ class InferenceRunner(InferenceRunnerBase):
...
@@ -142,16 +141,16 @@ class InferenceRunner(InferenceRunnerBase):
self
.
_tower_name
,
device
,
self
.
_input_source
)
self
.
_tower_name
,
device
,
self
.
_input_source
)
else
:
else
:
# new Trainer API
# new Trainer API
# only works for singlecost t
rainer
from
..trainv2
import
TowerT
rainer
# assert isinstance(self.trainer, SingleCost
Trainer), self.trainer
assert
isinstance
(
self
.
trainer
,
Tower
Trainer
),
self
.
trainer
input_callbacks
=
self
.
_input_source
.
setup
(
self
.
trainer
.
inputs_desc
)
input_callbacks
=
self
.
_input_source
.
setup
(
self
.
trainer
.
inputs_desc
)
with
tf
.
variable_scope
(
tf
.
get_variable_scope
(),
reuse
=
True
):
with
tf
.
variable_scope
(
tf
.
get_variable_scope
(),
reuse
=
True
):
SimplePredictBuilder
(
SimplePredictBuilder
(
ns_name
=
self
.
_tower_name
,
ns_name
=
self
.
_tower_name
,
vs_name
=
''
,
device
=
0
)
.
build
(
# TODO fix vs_name and maybe device
vs_name
=
''
,
device
=
0
)
.
build
(
# TODO fix vs_name and maybe device
self
.
_input_source
,
self
.
trainer
.
get_cost_fn
)
self
.
_input_source
,
self
.
trainer
.
tower_func
)
self
.
_tower_handle
=
self
.
trainer
.
get_cost_fn
.
towers
[
-
1
]
self
.
_tower_handle
=
self
.
trainer
.
tower_func
.
towers
[
-
1
]
self
.
_hooks
=
[
self
.
_build_hook
(
inf
)
for
inf
in
self
.
infs
]
self
.
_hooks
=
[
self
.
_build_hook
(
inf
)
for
inf
in
self
.
infs
]
# trigger_{step,epoch}, {before,after}_epoch is ignored.
# trigger_{step,epoch}, {before,after}_epoch is ignored.
...
@@ -202,7 +201,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
...
@@ -202,7 +201,7 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
def
_setup_graph
(
self
):
def
_setup_graph
(
self
):
self
.
_handles
=
[]
self
.
_handles
=
[]
if
hasattr
(
self
.
trainer
,
'model'
)
:
if
self
.
trainer
.
_API_VERSION
==
1
:
# old Trainer API
# old Trainer API
input_callbacks
=
self
.
_input_source
.
setup
(
self
.
trainer
.
model
.
get_inputs_desc
())
input_callbacks
=
self
.
_input_source
.
setup
(
self
.
trainer
.
model
.
get_inputs_desc
())
# build each predict tower
# build each predict tower
...
@@ -222,8 +221,8 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
...
@@ -222,8 +221,8 @@ class DataParallelInferenceRunner(InferenceRunnerBase):
SimplePredictBuilder
(
SimplePredictBuilder
(
ns_name
=
tower_name
,
ns_name
=
tower_name
,
vs_name
=
''
,
device
=
t
)
.
build
(
# TODO fix vs_name and maybe device
vs_name
=
''
,
device
=
t
)
.
build
(
# TODO fix vs_name and maybe device
self
.
_input_source
,
self
.
trainer
.
get_cost_fn
)
self
.
_input_source
,
self
.
trainer
.
tower_func
)
self
.
_handles
.
append
(
self
.
trainer
.
get_cost_fn
.
towers
[
-
1
])
self
.
_handles
.
append
(
self
.
trainer
.
tower_func
.
towers
[
-
1
])
# setup callbacks and hooks
# setup callbacks and hooks
self
.
_input_callbacks
=
Callbacks
(
input_callbacks
)
self
.
_input_callbacks
=
Callbacks
(
input_callbacks
)
...
...
tensorpack/tfutils/tower.py
View file @
e0b13533
...
@@ -154,7 +154,7 @@ class TowerFuncWrapper(object):
...
@@ -154,7 +154,7 @@ class TowerFuncWrapper(object):
each time the function is called.
each time the function is called.
"""
"""
def
__init__
(
self
,
tower_fn
,
inputs_desc
=
None
):
def
__init__
(
self
,
tower_fn
,
inputs_desc
):
"""
"""
Args:
Args:
tower_func: a function which builds one tower in the graph.
tower_func: a function which builds one tower in the graph.
...
@@ -168,7 +168,7 @@ class TowerFuncWrapper(object):
...
@@ -168,7 +168,7 @@ class TowerFuncWrapper(object):
self
.
_towers
=
[]
self
.
_towers
=
[]
def
__new__
(
cls
,
tower_fn
,
inputs_desc
=
None
):
def
__new__
(
cls
,
tower_fn
,
inputs_desc
):
# to avoid double-wrapping a function
# to avoid double-wrapping a function
if
isinstance
(
tower_fn
,
TowerFuncWrapper
):
if
isinstance
(
tower_fn
,
TowerFuncWrapper
):
return
tower_fn
return
tower_fn
...
@@ -188,6 +188,10 @@ class TowerFuncWrapper(object):
...
@@ -188,6 +188,10 @@ class TowerFuncWrapper(object):
# TODO another wrapper around towerhandlelist
# TODO another wrapper around towerhandlelist
return
self
.
_towers
return
self
.
_towers
@
property
def
inputs_desc
(
self
):
return
self
.
_inputs_desc
class
TowerTensorHandle
(
object
):
class
TowerTensorHandle
(
object
):
"""
"""
...
...
tensorpack/train/base.py
View file @
e0b13533
...
@@ -101,6 +101,8 @@ class Trainer(object):
...
@@ -101,6 +101,8 @@ class Trainer(object):
monitors (Monitors): the monitors. Other callbacks can use it for logging.
monitors (Monitors): the monitors. Other callbacks can use it for logging.
"""
"""
_API_VERSION
=
1
is_chief
=
True
is_chief
=
True
"""
"""
Whether this process is the chief worker in distributed training.
Whether this process is the chief worker in distributed training.
...
...
tensorpack/trainv2/base.py
View file @
e0b13533
...
@@ -31,6 +31,8 @@ class Trainer(object):
...
@@ -31,6 +31,8 @@ class Trainer(object):
""" Base class for a trainer.
""" Base class for a trainer.
"""
"""
_API_VERSION
=
2
is_chief
=
True
is_chief
=
True
def
__init__
(
self
):
def
__init__
(
self
):
...
@@ -215,8 +217,39 @@ for name in ['global_step', 'local_step', 'steps_per_epoch',
...
@@ -215,8 +217,39 @@ for name in ['global_step', 'local_step', 'steps_per_epoch',
setattr
(
Trainer
,
name
,
_get_property
(
name
))
setattr
(
Trainer
,
name
,
_get_property
(
name
))
class
TowerTrainer
(
Trainer
):
"""
Base trainers for models that can be built by calling a tower function under a :class:`TowerContext`.
This is required by some features that replicates the model
automatically, e.g. creating a predictor.
"""
tower_func
=
None
"""
A :class:`TowerFuncWrapper` instance.
A callable which takes some input tensors and builds one replicate of the model.
"""
@
call_only_once
def
set_tower_func
(
self
,
tower_func
):
"""
Args:
tower_func (TowerFuncWrapper)
"""
self
.
tower_func
=
tower_func
@
property
def
inputs_desc
(
self
):
"""
Returns:
list[InputDesc]: metainfo about the inputs to the tower.
"""
return
self
.
tower_func
.
inputs_desc
@
six
.
add_metaclass
(
ABCMeta
)
@
six
.
add_metaclass
(
ABCMeta
)
class
SingleCostTrainer
(
Trainer
):
class
SingleCostTrainer
(
T
owerT
rainer
):
"""
"""
Base class for single-cost trainer.
Base class for single-cost trainer.
...
@@ -261,12 +294,11 @@ class SingleCostTrainer(Trainer):
...
@@ -261,12 +294,11 @@ class SingleCostTrainer(Trainer):
"""
"""
get_cost_fn
=
TowerFuncWrapper
(
get_cost_fn
,
inputs_desc
)
get_cost_fn
=
TowerFuncWrapper
(
get_cost_fn
,
inputs_desc
)
get_opt_fn
=
memoized
(
get_opt_fn
)
get_opt_fn
=
memoized
(
get_opt_fn
)
self
.
set_tower_func
(
get_cost_fn
)
input_callbacks
=
self
.
_setup_input
(
inputs_desc
,
input
)
input_callbacks
=
self
.
_setup_input
(
inputs_desc
,
input
)
train_callbacks
=
self
.
_setup_graph
(
input
,
get_cost_fn
,
get_opt_fn
)
train_callbacks
=
self
.
_setup_graph
(
input
,
get_cost_fn
,
get_opt_fn
)
self
.
_internal_callbacks
=
input_callbacks
+
train_callbacks
self
.
_internal_callbacks
=
input_callbacks
+
train_callbacks
self
.
inputs_desc
=
inputs_desc
self
.
get_cost_fn
=
get_cost_fn
return
self
.
_internal_callbacks
return
self
.
_internal_callbacks
@
abstractmethod
@
abstractmethod
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment