Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
ea0f1b90
Commit
ea0f1b90
authored
Oct 27, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
split trainer to tower.py
parent
0b8727b0
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
197 additions
and
167 deletions
+197
-167
docs/tutorial/trainer.md
docs/tutorial/trainer.md
+3
-2
docs/tutorial/training-interface.md
docs/tutorial/training-interface.md
+7
-7
tensorpack/train/base.py
tensorpack/train/base.py
+19
-157
tensorpack/train/tower.py
tensorpack/train/tower.py
+167
-0
tensorpack/train/trainers.py
tensorpack/train/trainers.py
+1
-1
No files found.
docs/tutorial/trainer.md
View file @
ea0f1b90
# Trainer
# Trainer
Tensorpack
trainers prepares and runs the training, which consists of the following
steps:
Tensorpack
follows the "define-and-run" paradigm. A training has two
steps:
1.
__Build graph__
for the model.
1.
Build graph
for the model.
Users can call whatever tensorflow functions to setup the graph.
Users can call whatever tensorflow functions to setup the graph.
Users may or may not use tensorpack
`InputSource`
,
`ModelDesc`
to build the graph.
Users may or may not use tensorpack
`InputSource`
,
`ModelDesc`
to build the graph.
This step defines "what to run" in every training step.
This step defines "what to run" in every training step.
It can happen either inside or outside the trainer.
2.
Train the model (the
[
Trainer.train() method
](
http://tensorpack.readthedocs.io/en/latest/modules/train.html#tensorpack.train.Trainer.train
)
):
2.
Train the model (the
[
Trainer.train() method
](
http://tensorpack.readthedocs.io/en/latest/modules/train.html#tensorpack.train.Trainer.train
)
):
...
...
docs/tutorial/training-interface.md
View file @
ea0f1b90
# Training Interface
# Training Interface
Tensorpack trainers
provide low-level API which requires a number of options to setup
.
Tensorpack trainers
have an interface for maximum flexibility
.
There are
high-level interfaces built on top of trainer
to simplify the use,
There are
also interfaces built on top of trainers
to simplify the use,
when you don't want to customize too much.
when you don't want to customize too much.
### With ModelDesc and TrainConfig
### With ModelDesc and TrainConfig
[
SingleCost trainers
](
trainer.html#single-cost-trainers
)
[
SingleCost trainers
](
trainer.html#single-cost-trainers
)
expects
`InputDesc`
,
`InputSource`
, get_cost function, and optimizer.
expects
4 arguments to build the graph:
`InputDesc`
,
`InputSource`
, get_cost function, and optimizer.
`ModelDesc`
describes a model by packing
three
of them together into one object:
`ModelDesc`
describes a model by packing
3
of them together into one object:
```
python
```
python
class
MyModel
(
ModelDesc
):
class
MyModel
(
ModelDesc
):
...
@@ -25,9 +25,9 @@ class MyModel(ModelDesc):
...
@@ -25,9 +25,9 @@ class MyModel(ModelDesc):
return
tf
.
train
.
GradientDescentOptimizer
(
0.1
)
return
tf
.
train
.
GradientDescentOptimizer
(
0.1
)
```
```
`_get_inputs`
should define the metainfo of all the inputs your graph
may nee
d.
`_get_inputs`
should define the metainfo of all the inputs your graph
will take to buil
d.
`_build_graph`
should add tensors/operations to the graph, where
the argument
`inputs`
is a list of
tensors which will match
`_get_inputs`
.
`_build_graph`
takes a list of
`inputs`
tensors which will match
`_get_inputs`
.
You can use any symbolic functions in
`_build_graph`
, including TensorFlow core library
You can use any symbolic functions in
`_build_graph`
, including TensorFlow core library
functions and other symbolic libraries.
functions and other symbolic libraries.
...
...
tensorpack/train/base.py
View file @
ea0f1b90
...
@@ -7,7 +7,6 @@ import weakref
...
@@ -7,7 +7,6 @@ import weakref
import
time
import
time
from
six.moves
import
range
from
six.moves
import
range
import
six
import
six
from
abc
import
abstractmethod
,
ABCMeta
from
..callbacks
import
(
from
..callbacks
import
(
Callback
,
Callbacks
,
Monitors
,
TrainingMonitor
,
Callback
,
Callbacks
,
Monitors
,
TrainingMonitor
,
...
@@ -15,26 +14,29 @@ from ..callbacks import (
...
@@ -15,26 +14,29 @@ from ..callbacks import (
ProgressBar
,
MergeAllSummaries
,
ProgressBar
,
MergeAllSummaries
,
TFEventWriter
,
JSONWriter
,
ScalarPrinter
,
RunUpdateOps
)
TFEventWriter
,
JSONWriter
,
ScalarPrinter
,
RunUpdateOps
)
from
..utils
import
logger
from
..utils
import
logger
from
..utils.argtools
import
call_only_once
,
memoized
from
..utils.argtools
import
call_only_once
from
..tfutils.tower
import
TowerFuncWrapper
from
..tfutils.model_utils
import
describe_trainable_vars
from
..tfutils.model_utils
import
describe_trainable_vars
from
..tfutils.sessinit
import
JustCurrentSession
from
..tfutils.sessinit
import
JustCurrentSession
from
..tfutils.sesscreate
import
ReuseSessionCreator
,
NewSessionCreator
from
..tfutils.sesscreate
import
ReuseSessionCreator
,
NewSessionCreator
from
..tfutils.tower
import
TowerFuncWrapper
,
get_current_tower_context
from
..tfutils.gradproc
import
FilterNoneGrad
from
..callbacks.steps
import
MaintainStepCounter
from
..callbacks.steps
import
MaintainStepCounter
from
..graph_builder.predictor_factory
import
SimplePredictBuilder
from
..input_source
import
PlaceholderInput
from
..predict.base
import
OnlinePredictor
import
tensorpack.trainv1
as
old_train
# noqa
import
tensorpack.trainv1
as
old_train
# noqa
from
..trainv1.base
import
StopTraining
,
TrainLoop
from
..trainv1.base
import
StopTraining
,
TrainLoop
from
..trainv1.config
import
TrainConfig
from
..trainv1.config
import
TrainConfig
__all__
=
[
'TrainConfig'
,
'Trainer'
,
'SingleCostTrainer'
,
'TowerTrainer'
]
__all__
=
[
'TrainConfig'
,
'Trainer'
]
def
DEFAULT_CALLBACKS
():
def
DEFAULT_CALLBACKS
():
"""
Return the default callbacks. They are:
1. MovingAverageSummary()
2. ProgressBar()
3. MergeAllSummaries()
4. RunUpdateOps()
"""
return
[
return
[
MovingAverageSummary
(),
MovingAverageSummary
(),
ProgressBar
(),
ProgressBar
(),
...
@@ -43,6 +45,13 @@ def DEFAULT_CALLBACKS():
...
@@ -43,6 +45,13 @@ def DEFAULT_CALLBACKS():
def
DEFAULT_MONITORS
():
def
DEFAULT_MONITORS
():
"""
Return the default monitors. They are:
1. TFEventWriter()
2. JSONWriter()
3. ScalarPrinter()
"""
return
[
TFEventWriter
(),
JSONWriter
(),
ScalarPrinter
()]
return
[
TFEventWriter
(),
JSONWriter
(),
ScalarPrinter
()]
...
@@ -77,6 +86,7 @@ class Trainer(object):
...
@@ -77,6 +86,7 @@ class Trainer(object):
self
.
_main_tower_vs_name
=
""
self
.
_main_tower_vs_name
=
""
def
gp
(
input_names
,
output_names
,
tower
=
0
):
def
gp
(
input_names
,
output_names
,
tower
=
0
):
from
.tower
import
TowerTrainer
return
TowerTrainer
.
get_predictor
(
self
,
input_names
,
output_names
,
device
=
tower
)
return
TowerTrainer
.
get_predictor
(
self
,
input_names
,
output_names
,
device
=
tower
)
self
.
get_predictor
=
gp
self
.
get_predictor
=
gp
...
@@ -314,151 +324,3 @@ def _get_property(name):
...
@@ -314,151 +324,3 @@ def _get_property(name):
for
name
in
[
'global_step'
,
'local_step'
,
'steps_per_epoch'
,
for
name
in
[
'global_step'
,
'local_step'
,
'steps_per_epoch'
,
'epoch_num'
,
'starting_epoch'
,
'max_epoch'
]:
'epoch_num'
,
'starting_epoch'
,
'max_epoch'
]:
setattr
(
Trainer
,
name
,
_get_property
(
name
))
setattr
(
Trainer
,
name
,
_get_property
(
name
))
class
TowerTrainer
(
Trainer
):
"""
Base trainers for models that can be built by calling a tower function under a :class:`TowerContext`.
This is required by some features that replicates the model
automatically, e.g. creating a predictor.
"""
tower_func
=
None
"""
A :class:`TowerFuncWrapper` instance.
A callable which takes some input tensors and builds one replicate of the model.
"""
@
call_only_once
def
set_tower_func
(
self
,
tower_func
):
"""
Args:
tower_func (TowerFuncWrapper)
"""
assert
isinstance
(
tower_func
,
TowerFuncWrapper
),
tower_func
self
.
tower_func
=
tower_func
@
property
def
inputs_desc
(
self
):
"""
Returns:
list[InputDesc]: metainfo about the inputs to the tower.
"""
return
self
.
tower_func
.
inputs_desc
def
get_predictor
(
self
,
input_names
,
output_names
,
device
=
0
):
"""
Returns a callable predictor built under ``TowerContext(is_training=False)``.
Args:
input_names (list), output_names(list): list of names
device (int): build the predictor on device '/gpu:{device}' or use -1 for '/cpu:0'.
Returns:
an :class:`OnlinePredictor`.
"""
assert
self
.
tower_func
is
not
None
,
"Must set tower_func on the trainer to use get_predictor()!"
tower_name
=
'tower-pred-{}'
.
format
(
device
)
if
device
>=
0
else
'tower-pred-cpu'
try
:
tower
=
self
.
tower_func
.
towers
[
tower_name
]
except
KeyError
:
input
=
PlaceholderInput
()
input
.
setup
(
self
.
inputs_desc
)
with
tf
.
variable_scope
(
tf
.
get_variable_scope
(),
reuse
=
True
):
SimplePredictBuilder
(
ns_name
=
tower_name
,
vs_name
=
self
.
_main_tower_vs_name
,
device
=
device
)
.
build
(
input
,
self
.
tower_func
)
tower
=
self
.
tower_func
.
towers
[
tower_name
]
input_tensors
=
tower
.
get_tensors
(
input_names
)
output_tensors
=
tower
.
get_tensors
(
output_names
)
return
OnlinePredictor
(
input_tensors
,
output_tensors
)
@
property
def
_main_tower_vs_name
(
self
):
"""
The vs name for the "main" copy of the model,
to be used to build predictors.
"""
return
""
@
six
.
add_metaclass
(
ABCMeta
)
class
SingleCostTrainer
(
TowerTrainer
):
"""
Base class for single-cost trainer.
Single-cost trainer has a :meth:`setup_graph` method which takes
(inputs_desc, input, get_cost_fn, get_opt_fn), and build the training operations from them.
To use a :class:`SingleCostTrainer` object, call `trainer.setup_graph(...); trainer.train(...)`.
"""
@
call_only_once
def
setup_graph
(
self
,
inputs_desc
,
input
,
get_cost_fn
,
get_opt_fn
):
"""
Responsible for building the main training graph for single-cost training.
Args:
inputs_desc ([InputDesc]):
input (InputSource):
get_cost_fn ([tf.Tensor] -> tf.Tensor): callable, takes some input tenosrs and return a cost tensor.
get_opt_fn (-> tf.train.Optimizer): callable which returns an
optimizer. Will only be called once.
Note:
1. `get_cost_fn` will always be called under a :class:`TowerContext`.
which will contain information abouut reuse,
training/inference, scope name, etc.
2. `get_cost_fn` might get called multiple times for data-parallel training or inference.
3. To respect variable reuse, use `tf.get_variable` instead of
`tf.Variable` in `get_cost_fn`.
"""
get_cost_fn
=
TowerFuncWrapper
(
get_cost_fn
,
inputs_desc
)
get_opt_fn
=
memoized
(
get_opt_fn
)
self
.
set_tower_func
(
get_cost_fn
)
input_callbacks
=
self
.
_setup_input
(
inputs_desc
,
input
)
train_callbacks
=
self
.
_setup_graph
(
input
,
get_cost_fn
,
get_opt_fn
)
internal_callbacks
=
input_callbacks
+
train_callbacks
for
cb
in
internal_callbacks
:
self
.
_register_callback
(
cb
)
# TODO register directly instead of return?
@
abstractmethod
def
_setup_graph
(
self
,
input
,
get_cost_fn
,
get_opt_fn
):
"""
Implement the logic to build the graph, with an :class:`InputSource`
that's been setup already.
Returns:
[Callback]: list of callbacks needed
"""
def
_setup_input
(
self
,
inputs_desc
,
input
):
assert
not
input
.
setup_done
()
return
input
.
setup
(
inputs_desc
)
def
_make_get_grad_fn
(
self
,
input
,
get_cost_fn
,
get_opt_fn
):
"""
Returns:
a get_grad_fn for GraphBuilder to use.
"""
# internal use only
assert
input
.
setup_done
()
def
get_grad_fn
():
ctx
=
get_current_tower_context
()
cost
=
get_cost_fn
(
*
input
.
get_input_tensors
())
varlist
=
ctx
.
filter_vars_by_vs_name
(
tf
.
trainable_variables
())
opt
=
get_opt_fn
()
grads
=
opt
.
compute_gradients
(
cost
,
var_list
=
varlist
,
gate_gradients
=
False
,
colocate_gradients_with_ops
=
True
)
grads
=
FilterNoneGrad
()
.
process
(
grads
)
return
grads
return
get_grad_fn
tensorpack/train/tower.py
0 → 100644
View file @
ea0f1b90
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: tower.py
import
tensorflow
as
tf
import
six
from
abc
import
abstractmethod
,
ABCMeta
from
..utils.argtools
import
call_only_once
,
memoized
from
..graph_builder.predictor_factory
import
SimplePredictBuilder
from
..input_source
import
PlaceholderInput
from
..predict.base
import
OnlinePredictor
from
..tfutils.tower
import
TowerFuncWrapper
,
get_current_tower_context
from
..tfutils.gradproc
import
FilterNoneGrad
from
.base
import
Trainer
__all__
=
[
'SingleCostTrainer'
,
'TowerTrainer'
]
class
TowerTrainer
(
Trainer
):
"""
Base trainers for models that can be built by calling a tower function under a :class:`TowerContext`.
This is required by some features that replicates the model
automatically, e.g. creating a predictor.
"""
tower_func
=
None
"""
A :class:`TowerFuncWrapper` instance.
A callable which takes some input tensors and builds one replicate of the model.
"""
@
call_only_once
def
set_tower_func
(
self
,
tower_func
):
"""
Args:
tower_func (TowerFuncWrapper)
"""
assert
isinstance
(
tower_func
,
TowerFuncWrapper
),
tower_func
self
.
tower_func
=
tower_func
@
property
def
inputs_desc
(
self
):
"""
Returns:
list[InputDesc]: metainfo about the inputs to the tower.
"""
return
self
.
tower_func
.
inputs_desc
def
get_predictor
(
self
,
input_names
,
output_names
,
device
=
0
):
"""
Returns a callable predictor built under ``TowerContext(is_training=False)``.
Args:
input_names (list), output_names(list): list of names
device (int): build the predictor on device '/gpu:{device}' or use -1 for '/cpu:0'.
Returns:
an :class:`OnlinePredictor`.
"""
assert
self
.
tower_func
is
not
None
,
"Must set tower_func on the trainer to use get_predictor()!"
tower_name
=
'tower-pred-{}'
.
format
(
device
)
if
device
>=
0
else
'tower-pred-cpu'
try
:
tower
=
self
.
tower_func
.
towers
[
tower_name
]
except
KeyError
:
input
=
PlaceholderInput
()
input
.
setup
(
self
.
inputs_desc
)
with
tf
.
variable_scope
(
tf
.
get_variable_scope
(),
reuse
=
True
):
SimplePredictBuilder
(
ns_name
=
tower_name
,
vs_name
=
self
.
_main_tower_vs_name
,
device
=
device
)
.
build
(
input
,
self
.
tower_func
)
tower
=
self
.
tower_func
.
towers
[
tower_name
]
input_tensors
=
tower
.
get_tensors
(
input_names
)
output_tensors
=
tower
.
get_tensors
(
output_names
)
return
OnlinePredictor
(
input_tensors
,
output_tensors
)
@
property
def
_main_tower_vs_name
(
self
):
"""
The vs name for the "main" copy of the model,
to be used to build predictors.
"""
return
""
@
six
.
add_metaclass
(
ABCMeta
)
class
SingleCostTrainer
(
TowerTrainer
):
"""
Base class for single-cost trainer.
Single-cost trainer has a :meth:`setup_graph` method which takes
(inputs_desc, input, get_cost_fn, get_opt_fn), and build the training operations from them.
To use a :class:`SingleCostTrainer` object, call `trainer.setup_graph(...); trainer.train(...)`.
"""
@
call_only_once
def
setup_graph
(
self
,
inputs_desc
,
input
,
get_cost_fn
,
get_opt_fn
):
"""
Responsible for building the main training graph for single-cost training.
Args:
inputs_desc ([InputDesc]):
input (InputSource):
get_cost_fn ([tf.Tensor] -> tf.Tensor): callable, takes some input tenosrs and return a cost tensor.
get_opt_fn (-> tf.train.Optimizer): callable which returns an
optimizer. Will only be called once.
Note:
1. `get_cost_fn` will always be called under a :class:`TowerContext`.
which will contain information abouut reuse,
training/inference, scope name, etc.
2. `get_cost_fn` might get called multiple times for data-parallel training or inference.
3. To respect variable reuse, use `tf.get_variable` instead of
`tf.Variable` in `get_cost_fn`.
"""
get_cost_fn
=
TowerFuncWrapper
(
get_cost_fn
,
inputs_desc
)
get_opt_fn
=
memoized
(
get_opt_fn
)
self
.
set_tower_func
(
get_cost_fn
)
input_callbacks
=
self
.
_setup_input
(
inputs_desc
,
input
)
train_callbacks
=
self
.
_setup_graph
(
input
,
get_cost_fn
,
get_opt_fn
)
internal_callbacks
=
input_callbacks
+
train_callbacks
for
cb
in
internal_callbacks
:
self
.
_register_callback
(
cb
)
# TODO register directly instead of return?
@
abstractmethod
def
_setup_graph
(
self
,
input
,
get_cost_fn
,
get_opt_fn
):
"""
Implement the logic to build the graph, with an :class:`InputSource`
that's been setup already.
Returns:
[Callback]: list of callbacks needed
"""
def
_setup_input
(
self
,
inputs_desc
,
input
):
assert
not
input
.
setup_done
()
return
input
.
setup
(
inputs_desc
)
def
_make_get_grad_fn
(
self
,
input
,
get_cost_fn
,
get_opt_fn
):
"""
Returns:
a get_grad_fn for GraphBuilder to use.
"""
# internal use only
assert
input
.
setup_done
()
def
get_grad_fn
():
ctx
=
get_current_tower_context
()
cost
=
get_cost_fn
(
*
input
.
get_input_tensors
())
varlist
=
ctx
.
filter_vars_by_vs_name
(
tf
.
trainable_variables
())
opt
=
get_opt_fn
()
grads
=
opt
.
compute_gradients
(
cost
,
var_list
=
varlist
,
gate_gradients
=
False
,
colocate_gradients_with_ops
=
True
)
grads
=
FilterNoneGrad
()
.
process
(
grads
)
return
grads
return
get_grad_fn
tensorpack/train/trainers.py
View file @
ea0f1b90
...
@@ -21,7 +21,7 @@ from ..graph_builder.training import (
...
@@ -21,7 +21,7 @@ from ..graph_builder.training import (
from
..graph_builder.distributed
import
DistributedReplicatedBuilder
from
..graph_builder.distributed
import
DistributedReplicatedBuilder
from
..graph_builder.utils
import
override_to_local_variable
from
..graph_builder.utils
import
override_to_local_variable
from
.
base
import
SingleCostTrainer
from
.
tower
import
SingleCostTrainer
__all__
=
[
'SimpleTrainer'
,
__all__
=
[
'SimpleTrainer'
,
'QueueInputTrainer'
,
'QueueInputTrainer'
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment