Project: Shashank Suhas / seminar-breakout

# Commit ea0f1b90

Authored Oct 27, 2017 by Yuxin Wu
Parent: 0b8727b0

> split trainer to tower.py

Showing 5 changed files, with 197 additions and 167 deletions:

- docs/tutorial/trainer.md (+3, -2)
- docs/tutorial/training-interface.md (+7, -7)
- tensorpack/train/base.py (+19, -157)
- tensorpack/train/tower.py (+167, -0, new file)
- tensorpack/train/trainers.py (+1, -1)
## docs/tutorial/trainer.md (+3, -2)

```diff
 # Trainer

-Tensorpack trainers prepares and runs the training, which consists of the following steps:
+Tensorpack follows the "define-and-run" paradigm. A training has two steps:

-1. __Build graph__ for the model.
+1. Build graph for the model.
    Users can call whatever tensorflow functions to setup the graph.
    Users may or may not use tensorpack `InputSource`, `ModelDesc` to build the graph.
+   This step defines "what to run" in every training step.
+   It can happen either inside or outside the trainer.

 2. Train the model (the [Trainer.train() method](http://tensorpack.readthedocs.io/en/latest/modules/train.html#tensorpack.train.Trainer.train)):
...
```
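The two steps the revised text names map directly onto the trainer API touched by this commit. Below is a minimal sketch of the flow: `SimpleTrainer`, `QueueInput`, and `InputDesc` are real tensorpack names, but the dataflow, input shapes, and cost function are invented placeholders, and the `train()` arguments are deliberately elided (see the linked `Trainer.train()` docs).

```python
import tensorflow as tf
from tensorpack import InputDesc
from tensorpack.input_source import QueueInput
from tensorpack.train.trainers import SimpleTrainer

def build_cost(image, label):
    # Hypothetical tower function: takes input tensors, returns a cost tensor.
    logits = tf.layers.dense(tf.layers.flatten(image), 10)
    return tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits, labels=label), name='cost')

trainer = SimpleTrainer()
# Step 1: build the graph. This defines "what to run" in every training step.
trainer.setup_graph(
    inputs_desc=[InputDesc(tf.float32, [None, 28, 28], 'image'),
                 InputDesc(tf.int32, [None], 'label')],
    input=QueueInput(my_dataflow),   # my_dataflow: a placeholder DataFlow
    get_cost_fn=build_cost,
    get_opt_fn=lambda: tf.train.GradientDescentOptimizer(0.1))
# Step 2: run the training loop (callbacks/monitors/session args elided here).
trainer.train(...)
```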
## docs/tutorial/training-interface.md (+7, -7)

````diff
 # Training Interface

-Tensorpack trainers provide low-level API which requires a number of options to setup.
-There are high-level interfaces built on top of trainer to simplify the use,
+Tensorpack trainers have an interface for maximum flexibility.
+There are also interfaces built on top of trainers to simplify the use,
 when you don't want to customize too much.

 ### With ModelDesc and TrainConfig

 [SingleCost trainers](trainer.html#single-cost-trainers)
-expects `InputDesc`, `InputSource`, get_cost function, and optimizer.
-`ModelDesc` describes a model by packing three of them together into one object:
+expects 4 arguments to build the graph:
+`InputDesc`, `InputSource`, get_cost function, and optimizer.
+`ModelDesc` describes a model by packing 3 of them together into one object:

 ```python
 class MyModel(ModelDesc):
...
@@ -25,9 +25,9 @@ class MyModel(ModelDesc):
         return tf.train.GradientDescentOptimizer(0.1)
 ```

-`_get_inputs` should define the metainfo of all the inputs your graph may need.
-`_build_graph` should add tensors/operations to the graph, where the argument
-`inputs` is a list of tensors which will match `_get_inputs`.
+`_get_inputs` should define the metainfo of all the inputs your graph will take to build.
+`_build_graph` takes a list of `inputs` tensors which will match `_get_inputs`.

 You can use any symbolic functions in `_build_graph`, including TensorFlow core library
 functions and other symbolic libraries.
...
````
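The tutorial shows only fragments of `MyModel`. For reference, a complete class under this interface could look like the sketch below; the input shapes and the loss are invented for illustration, and it assumes the `self.cost` convention that tensorpack's single-cost trainers read from `_build_graph`.

```python
import tensorflow as tf
from tensorpack import ModelDesc, InputDesc

class MyModel(ModelDesc):
    def _get_inputs(self):
        # Metainfo of all the inputs the graph will take to build.
        return [InputDesc(tf.float32, [None, 28, 28], 'image'),
                InputDesc(tf.int32, [None], 'label')]

    def _build_graph(self, inputs):
        # `inputs` is a list of tensors matching _get_inputs.
        image, label = inputs
        logits = tf.layers.dense(tf.layers.flatten(image), 10)
        # Single-cost trainers read the cost from `self.cost` by convention.
        self.cost = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=logits, labels=label),
            name='cross_entropy_loss')

    def _get_optimizer(self):
        return tf.train.GradientDescentOptimizer(0.1)
```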
## tensorpack/train/base.py (+19, -157)

```diff
@@ -7,7 +7,6 @@ import weakref
 import time
 from six.moves import range
 import six
-from abc import abstractmethod, ABCMeta

 from ..callbacks import (
     Callback, Callbacks, Monitors, TrainingMonitor,
@@ -15,26 +14,29 @@ from ..callbacks import (
     ProgressBar, MergeAllSummaries,
     TFEventWriter, JSONWriter, ScalarPrinter, RunUpdateOps)
 from ..utils import logger
-from ..utils.argtools import call_only_once, memoized
+from ..utils.argtools import call_only_once
+from ..tfutils.tower import TowerFuncWrapper
 from ..tfutils.model_utils import describe_trainable_vars
 from ..tfutils.sessinit import JustCurrentSession
 from ..tfutils.sesscreate import ReuseSessionCreator, NewSessionCreator
-from ..tfutils.tower import TowerFuncWrapper, get_current_tower_context
-from ..tfutils.gradproc import FilterNoneGrad
 from ..callbacks.steps import MaintainStepCounter
-from ..graph_builder.predictor_factory import SimplePredictBuilder
-from ..input_source import PlaceholderInput
-from ..predict.base import OnlinePredictor

 import tensorpack.trainv1 as old_train   # noqa
 from ..trainv1.base import StopTraining, TrainLoop
 from ..trainv1.config import TrainConfig

-__all__ = ['TrainConfig', 'Trainer', 'SingleCostTrainer', 'TowerTrainer']
+__all__ = ['TrainConfig', 'Trainer']


 def DEFAULT_CALLBACKS():
+    """
+    Return the default callbacks. They are:
+    1. MovingAverageSummary()
+    2. ProgressBar()
+    3. MergeAllSummaries()
+    4. RunUpdateOps()
+    """
     return [
         MovingAverageSummary(),
         ProgressBar(),
@@ -43,6 +45,13 @@ def DEFAULT_CALLBACKS():

 def DEFAULT_MONITORS():
+    """
+    Return the default monitors. They are:
+    1. TFEventWriter()
+    2. JSONWriter()
+    3. ScalarPrinter()
+    """
     return [TFEventWriter(), JSONWriter(), ScalarPrinter()]
@@ -77,6 +86,7 @@ class Trainer(object):
         self._main_tower_vs_name = ""

         def gp(input_names, output_names, tower=0):
+            from .tower import TowerTrainer
             return TowerTrainer.get_predictor(self, input_names, output_names, device=tower)
         self.get_predictor = gp
@@ -314,151 +324,3 @@ def _get_property(name):
 for name in ['global_step', 'local_step', 'steps_per_epoch',
              'epoch_num', 'starting_epoch', 'max_epoch']:
     setattr(Trainer, name, _get_property(name))
-
-
-class TowerTrainer(Trainer):
-    ...
-
-
-@six.add_metaclass(ABCMeta)
-class SingleCostTrainer(TowerTrainer):
-    ...
```

The bodies of the removed `TowerTrainer` and `SingleCostTrainer` classes (roughly 150 lines) are identical to the code added in `tensorpack/train/tower.py` below, so they are elided here rather than shown twice.
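One detail worth calling out in the hunk above: `gp` imports `TowerTrainer` inside the function body because the new `tower.py` itself does `from .base import Trainer`, so a module-level import in `base.py` would be circular; deferring the import to call time breaks the cycle while keeping the old entry point alive. Existing call sites keep working unchanged, roughly like this (the input/output names are illustrative):

```python
# `trainer` is any Trainer whose tower_func has been set (i.e. a TowerTrainer).
pred = trainer.get_predictor(['image'], ['cost'])  # dispatches to TowerTrainer
outputs = pred(some_image_batch)                   # runs with is_training=False
```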
## tensorpack/train/tower.py (new file, +167)

```python
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: tower.py

import tensorflow as tf
import six
from abc import abstractmethod, ABCMeta

from ..utils.argtools import call_only_once, memoized
from ..graph_builder.predictor_factory import SimplePredictBuilder
from ..input_source import PlaceholderInput
from ..predict.base import OnlinePredictor
from ..tfutils.tower import TowerFuncWrapper, get_current_tower_context
from ..tfutils.gradproc import FilterNoneGrad

from .base import Trainer

__all__ = ['SingleCostTrainer', 'TowerTrainer']


class TowerTrainer(Trainer):
    """
    Base trainer for models that can be built by calling a tower function under a :class:`TowerContext`.

    This is required by some features that replicate the model
    automatically, e.g. creating a predictor.
    """

    tower_func = None
    """
    A :class:`TowerFuncWrapper` instance.
    A callable which takes some input tensors and builds one replicate of the model.
    """

    @call_only_once
    def set_tower_func(self, tower_func):
        """
        Args:
            tower_func (TowerFuncWrapper)
        """
        assert isinstance(tower_func, TowerFuncWrapper), tower_func
        self.tower_func = tower_func

    @property
    def inputs_desc(self):
        """
        Returns:
            list[InputDesc]: metainfo about the inputs to the tower.
        """
        return self.tower_func.inputs_desc

    def get_predictor(self, input_names, output_names, device=0):
        """
        Returns a callable predictor built under ``TowerContext(is_training=False)``.

        Args:
            input_names (list), output_names(list): list of names
            device (int): build the predictor on device '/gpu:{device}' or use -1 for '/cpu:0'.

        Returns:
            an :class:`OnlinePredictor`.
        """
        assert self.tower_func is not None, "Must set tower_func on the trainer to use get_predictor()!"
        tower_name = 'tower-pred-{}'.format(device) if device >= 0 else 'tower-pred-cpu'

        try:
            tower = self.tower_func.towers[tower_name]
        except KeyError:
            input = PlaceholderInput()
            input.setup(self.inputs_desc)

            with tf.variable_scope(tf.get_variable_scope(), reuse=True):
                SimplePredictBuilder(
                    ns_name=tower_name, vs_name=self._main_tower_vs_name,
                    device=device).build(input, self.tower_func)
            tower = self.tower_func.towers[tower_name]
        input_tensors = tower.get_tensors(input_names)
        output_tensors = tower.get_tensors(output_names)
        return OnlinePredictor(input_tensors, output_tensors)

    @property
    def _main_tower_vs_name(self):
        """
        The vs name for the "main" copy of the model,
        to be used to build predictors.
        """
        return ""


@six.add_metaclass(ABCMeta)
class SingleCostTrainer(TowerTrainer):
    """
    Base class for single-cost trainer.

    Single-cost trainer has a :meth:`setup_graph` method which takes
    (inputs_desc, input, get_cost_fn, get_opt_fn), and builds the training operations from them.

    To use a :class:`SingleCostTrainer` object, call `trainer.setup_graph(...); trainer.train(...)`.
    """

    @call_only_once
    def setup_graph(self, inputs_desc, input, get_cost_fn, get_opt_fn):
        """
        Responsible for building the main training graph for single-cost training.

        Args:
            inputs_desc ([InputDesc]):
            input (InputSource):
            get_cost_fn ([tf.Tensor] -> tf.Tensor): callable, takes some input tensors and returns a cost tensor.
            get_opt_fn (-> tf.train.Optimizer): callable which returns an
                optimizer. Will only be called once.

        Note:
            1. `get_cost_fn` will always be called under a :class:`TowerContext`,
               which will contain information about reuse,
               training/inference, scope name, etc.
            2. `get_cost_fn` might get called multiple times for data-parallel training or inference.
            3. To respect variable reuse, use `tf.get_variable` instead of
               `tf.Variable` in `get_cost_fn`.
        """
        get_cost_fn = TowerFuncWrapper(get_cost_fn, inputs_desc)
        get_opt_fn = memoized(get_opt_fn)
        self.set_tower_func(get_cost_fn)

        input_callbacks = self._setup_input(inputs_desc, input)
        train_callbacks = self._setup_graph(input, get_cost_fn, get_opt_fn)
        internal_callbacks = input_callbacks + train_callbacks
        for cb in internal_callbacks:
            self._register_callback(cb)

    # TODO register directly instead of return?
    @abstractmethod
    def _setup_graph(self, input, get_cost_fn, get_opt_fn):
        """
        Implement the logic to build the graph, with an :class:`InputSource`
        that's been setup already.

        Returns:
            [Callback]: list of callbacks needed
        """

    def _setup_input(self, inputs_desc, input):
        assert not input.setup_done()
        return input.setup(inputs_desc)

    def _make_get_grad_fn(self, input, get_cost_fn, get_opt_fn):
        """
        Returns:
            a get_grad_fn for GraphBuilder to use.
        """
        # internal use only
        assert input.setup_done()

        def get_grad_fn():
            ctx = get_current_tower_context()
            cost = get_cost_fn(*input.get_input_tensors())

            varlist = ctx.filter_vars_by_vs_name(tf.trainable_variables())
            opt = get_opt_fn()
            grads = opt.compute_gradients(
                cost, var_list=varlist,
                gate_gradients=False, colocate_gradients_with_ops=True)
            grads = FilterNoneGrad().process(grads)
            return grads

        return get_grad_fn
```
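To see how these pieces fit together, a stripped-down concrete subclass might implement `_setup_graph` along the following lines. This mirrors what the single-GPU `SimpleTrainer` in `trainers.py` does, but it is a simplified sketch, not the shipped implementation:

```python
import tensorflow as tf
from tensorpack.tfutils.tower import TowerContext
from tensorpack.train.tower import SingleCostTrainer

class MyMinimalTrainer(SingleCostTrainer):
    def _setup_graph(self, input, get_cost_fn, get_opt_fn):
        # Build one training tower; _make_get_grad_fn wires cost -> gradients.
        with TowerContext('', is_training=True):
            grads = self._make_get_grad_fn(input, get_cost_fn, get_opt_fn)()
            opt = get_opt_fn()
            self.train_op = opt.apply_gradients(grads, name='min_op')
        return []   # no extra callbacks beyond those from the InputSource
```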
## tensorpack/train/trainers.py (+1, -1)

```diff
@@ -21,7 +21,7 @@ from ..graph_builder.training import (
 from ..graph_builder.distributed import DistributedReplicatedBuilder
 from ..graph_builder.utils import override_to_local_variable
-from .base import SingleCostTrainer
+from .tower import SingleCostTrainer

 __all__ = ['SimpleTrainer',
            'QueueInputTrainer',
...
```