Commit ba4e3178, authored Oct 21, 2017 by Yuxin Wu

[Trainerv2] Swap trainer directory. change two examples.

parent 9268bc8c
Showing 17 changed files with 668 additions and 202 deletions:

  CHANGES.md                           +8    -0
  docs/tutorial/trainer.md             +6    -4
  examples/cifar-convnet.py            +6    -6
  examples/mnist-convnet.py            +2    -1
  examples/tox.ini                     +1    -1
  tensorpack/__init__.py               +2    -2
  tensorpack/train/base.py             +251  -185
  tensorpack/train/interface.py        +2    -2
  tensorpack/train/trainers.py         +12   -0
  tensorpack/trainv1/__init__.py       +1    -1
  tensorpack/trainv1/base.py           +358  -0
  tensorpack/trainv1/config.py         +0    -0
  tensorpack/trainv1/distributed.py    +0    -0
  tensorpack/trainv1/interface.py      +11   -0
  tensorpack/trainv1/multigpu.py       +0    -0
  tensorpack/trainv1/simple.py         +0    -0
  tensorpack/trainv1/utility.py        +8    -0
CHANGES.md  (view file @ ba4e3178)

@@ -8,6 +8,14 @@ so you won't need to look at here very often.
 Here are a list of things that were changed, starting from an early version.
 TensorFlow itself also changed APIs before 1.0 and those are not listed here.

+[2017/10/21]
+tensorpack is gradually switching to a new Trainer API.
+Compatibility is kept in most ways but not guaranteed.
+To switch to new API, the easiest way is to:
+1. `export TENSORPACK_TRAIN_API=v2` (will be default in the future).
+2. Replace `SomeTrainer(config, ...).train()` with `launch_train_with_config(config, SomeTrainer(...))`.
+
 [2017/10/18]
 `TrainConfig(predict_tower)` was deprecated. You can set the inference device directly when creating the
 `InferenceRunner` callback.

 [2017/10/12](https://github.com/ppwwyyxx/tensorpack/commit/7e963996f615b85f7459455596b4ee9bbd0bce8e).
 ...
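The two steps above translate, for a typical single-GPU script, roughly as follows (a minimal sketch; `config` stands for an already-built `TrainConfig` and is an assumption here, not part of the changelog entry):

    # old (v1) call pattern:
    #   SimpleTrainer(config).train()

    # new (v2) call pattern -- opt in before importing tensorpack:
    import os
    os.environ['TENSORPACK_TRAIN_API'] = 'v2'
    from tensorpack import *

    launch_train_with_config(config, SimpleTrainer())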
docs/tutorial/trainer.md  (view file @ ba4e3178)

@@ -22,9 +22,11 @@ In other words, an "epoch" in tensorpack is the __default period to run callback
 ### Common Trainers

-Most neural network training tasks are single-cost optimization.
-Tensorpack provides some trainer implementations for such tasks.
-These trainers will build the graph based on the given `ModelDesc`, and minimizes `ModelDesc.cost`.
+<!--
+Most neural network training tasks are single-cost optimization.
+Tensorpack provides some trainer implementations for such tasks.
+These trainers will build the graph based on the given `ModelDesc`, and minimizes `ModelDesc.cost`.
+-->

 <!--
 -To use trainers, pass a `TrainConfig` to configure them:
 ...
@@ -49,7 +51,7 @@ These trainers will build the graph based on the given `ModelDesc`, and minimize
 -in the [Input Pipeline](input-source.html) tutorial.
 -You can set the InputSource instead, to customize this behavior.
 -->

-Trainers are being redesigned, so the recommended API will likely be changed soon.
+Trainers are being redesigned, this page will be updated soon.

 Existing multi-GPU trainers include the logic of data-parallel training.
 You can enable them by just one line, and all the necessary logic to achieve the best performance was baked into the trainers already.
 ...
examples/cifar-convnet.py  (view file @ ba4e3178)

@@ -2,12 +2,13 @@
 # -*- coding: UTF-8 -*-
 # File: cifar-convnet.py
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
-from tensorpack import *
 import tensorflow as tf
 import argparse
 import numpy as np
 import os

+os.environ['TENSORPACK_TRAIN_API'] = 'v2'   # will become default soon
+from tensorpack import *
 import tensorpack.tfutils.symbolic_functions as symbf
 from tensorpack.tfutils.summary import *
 from tensorpack.dataflow import dataset
 ...
@@ -151,8 +152,7 @@ if __name__ == '__main__':
     if args.load:
         config.session_init = SaverRestore(args.load)

-    config.nr_tower = max(len(args.gpu.split(',')), 1)
-    if config.nr_tower <= 1:
-        QueueInputTrainer(config).train()
-    else:
-        SyncMultiGPUTrainerParameterServer(config).train()
+    nr_gpu = len(args.gpu.split(','))
+    trainer = QueueInputTrainer() if nr_gpu <= 1 \
+        else SyncMultiGPUTrainerParameterServer(list(range(nr_gpu)))
+    launch_train_with_config(config, trainer)
examples/mnist-convnet.py  (view file @ ba4e3178)

@@ -12,6 +12,7 @@ MNIST ConvNet example.
 about 0.6% validation error after 30 epochs.
 """

+os.environ['TENSORPACK_TRAIN_API'] = 'v2'   # will become default soon
 # Just import everything into current namespace
 from tensorpack import *
 from tensorpack.tfutils import summary
 ...
@@ -142,4 +143,4 @@ if __name__ == '__main__':
         config.session_init = SaverRestore(args.load)

     # SimpleTrainer is slow, this is just a demo.
     # You can use QueueInputTrainer instead
-    SimpleTrainer(config).train()
+    launch_train_with_config(config, SimpleTrainer())
examples/tox.ini  (view file @ ba4e3178)

 [flake8]
 max-line-length = 120
-ignore = F403,F401,F405,F841,E401
+ignore = F403,F401,F405,F841,E401,E402
 exclude = private,
           FasterRCNN/utils
tensorpack/__init__.py  (view file @ ba4e3178)

@@ -18,9 +18,9 @@ if _HAS_TF:
     # In development. Default to v1
     if _os.environ.get('TENSORPACK_TRAIN_API', 'v1') == 'v2':
-        from tensorpack.trainv2 import *
+        from tensorpack.train import *
     else:
-        from tensorpack.train import *
+        from tensorpack.trainv1 import *
     from tensorpack.graph_builder import InputDesc, ModelDesc, ModelDescBase
     from tensorpack.input_source import *
     from tensorpack.predict import *
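Because this check runs when `tensorpack` is first imported, the environment variable has to be set before the import -- which is why both example scripts in this commit place the `os.environ` line above `from tensorpack import *`. A minimal sketch of the intended usage:

    import os
    os.environ['TENSORPACK_TRAIN_API'] = 'v2'   # must be set before importing tensorpack

    from tensorpack import *   # now re-exports the new trainers from tensorpack.train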
tensorpack/train/base.py  (view file @ ba4e3178, +251 -185)

This file now holds the new-API trainer base classes that previously lived in
tensorpack/trainv2/base.py; the old v1 implementation (StopTraining, TrainLoop and the
config-based Trainer) moves to tensorpack/trainv1/base.py, shown further below.
The file at this commit (regions collapsed in the diff view are marked with "..."):

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: base.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

import tensorflow as tf
import weakref
import time
from six.moves import range
import six
from abc import abstractmethod, ABCMeta

from ..utils import logger
from ..utils.argtools import call_only_once, memoized
from ..callbacks import Callback, Callbacks
from ..callbacks.monitor import Monitors, TrainingMonitor
from ..tfutils.model_utils import describe_trainable_vars
from ..tfutils.sessinit import JustCurrentSession
from ..tfutils.sesscreate import ReuseSessionCreator
from ..tfutils.tower import TowerFuncWrapper, get_current_tower_context
from ..tfutils.gradproc import FilterNoneGrad
from ..callbacks.steps import MaintainStepCounter
from ..graph_builder.predictor_factory import SimplePredictBuilder
from ..input_source import PlaceholderInput
from ..predict.base import OnlinePredictor

import tensorpack.trainv1 as old_train    # noqa
from ..trainv1.base import StopTraining, TrainLoop
from ..trainv1.config import TrainConfig

__all__ = ['TrainConfig', 'Trainer', 'SingleCostTrainer', 'TowerTrainer']


class Trainer(object):
    """ Base class for a trainer.
    """

    _API_VERSION = 2

    is_chief = True
    """
    Whether this process is the chief worker in distributed training.
    Only chief worker will run some callbacks.
    """

    def __init__(self, config=None):
        """
        config is only for compatibility reasons in case you're
        using custom trainers with old-style API.
        You should never use config.
        """
        self._callbacks = []
        self.loop = TrainLoop()
        self._monitors = []  # Clarify the type. Don't change from list to monitors.

        # Hacks!
        if config is not None:
            logger.warn("You're initializing new trainer with old trainer API!")
            logger.warn("This could happen if you wrote a custom trainer before.")
            logger.warn("It may work now through some hacks, but please switch to the new API!")
            self._config = config
            self.inputs_desc = config.model.get_inputs_desc()
            self.tower_func = TowerFuncWrapper(
                lambda *inputs: config.model.build_graph(inputs), self.inputs_desc)
            self._main_tower_vs_name = ""

            def gp(input_names, output_names, tower=0):
                return TowerTrainer.get_predictor(self, input_names, output_names, device=tower)
            self.get_predictor = gp

            old_train = self.train

            def train():
                return old_train(
                    config.callbacks, config.monitors,
                    config.session_creator, config.session_init,
                    config.steps_per_epoch, config.starting_epoch, config.max_epoch)
            self.train = train

    def _register_callback(self, cb):
        """
        Register a callback to the trainer.
        It can only be called before :meth:`Trainer.train` gets called.
        """
        ...
        else:
            self._callbacks.append(cb)

    def _register_monitor(self, mon):
        """
        Register a monitor to the trainer.
        It can only be called before :meth:`Trainer.train` gets called.
        """
        ...
        if not self.is_chief and mon.chief_only:
            logger.warn("Monitor {} is chief-only, skipped.".format(str(mon)))
        else:
            self._register_callback(mon)

    def run_step(self):
        """
        ...
        """
        ...
        self.hooked_sess.run(self.train_op)

    @call_only_once
    def setup_callbacks(self, callbacks, monitors):
        """
        Setup callbacks and monitors. Must be called after the main graph is built.
        """
        describe_trainable_vars()   # TODO weird

        self._register_callback(MaintainStepCounter())
        for cb in callbacks:
            self._register_callback(cb)
        for m in monitors:
            self._register_monitor(m)
        self.monitors = Monitors(monitors)
        self._register_callback(self.monitors)  # monitors is also a callback

        # some final operations that might modify the graph
        logger.info("Setup callbacks graph ...")
        self._callbacks = Callbacks(self._callbacks)
        self._callbacks.setup_graph(weakref.proxy(self))

    @call_only_once
    def initialize(self, session_creator, session_init):
        """
        Initialize self.sess and self.hooked_sess.
        Must be called after callbacks are setup.
        """
        session_init._setup_graph()

        logger.info("Creating the session ...")

        hooks = self._callbacks.get_hooks()
        self.sess = session_creator.create_session()
        self.hooked_sess = tf.train.MonitoredSession(
            session_creator=ReuseSessionCreator(self.sess), hooks=hooks)

        if self.is_chief:
            logger.info("Initializing the session ...")
            session_init._run_init(self.sess)
        else:
            if not isinstance(self._config.session_init, JustCurrentSession):
                logger.warn("This is not a chief worker, 'session_init' was ignored!")
        ...
        self.sess.graph.finalize()
        logger.info("Graph Finalized.")

    @call_only_once
    def main_loop(self, steps_per_epoch, starting_epoch=1, max_epoch=99999):
        """
        Run the main training loop.
        """
        with self.sess.as_default():
            self.loop.config(steps_per_epoch, starting_epoch, max_epoch)
            self.loop.update_global_step()
            try:
                self._callbacks.before_train()
                # refresh global step (might have changed by callbacks) TODO ugly
                self.loop.update_global_step()
                for self.loop._epoch_num in range(
                        self.loop.starting_epoch, self.loop.max_epoch + 1):
                    ...
            ...
                self._callbacks.after_train()
                self.hooked_sess.close()

    def train(self,
              callbacks, monitors,
              session_creator, session_init,
              steps_per_epoch, starting_epoch, max_epoch):
        """
        Implemented by:

        .. code-block:: python

            self.setup_callbacks(callbacks, monitors)
            self.initialize(session_creator, session_init)
            self.main_loop(steps_per_epoch, starting_epoch, max_epoch)

        You can call those methods by yourself to have better control on details if needed.
        """
        self.setup_callbacks(callbacks, monitors)
        self.initialize(session_creator, session_init)
        self.main_loop(steps_per_epoch, starting_epoch, max_epoch)

    # create the old trainer when called with TrainConfig
    def __new__(cls, *args, **kwargs):
        if (len(args) > 0 and isinstance(args[0], TrainConfig)) \
                or 'config' in kwargs:
            name = cls.__name__
            try:
                old_trainer = getattr(old_train, name)
            except AttributeError:
                # custom trainer. has to live with it
                return super(Trainer, cls).__new__(cls)
            else:
                logger.warn("You're calling new trainers with old trainer API!")
                logger.warn("Now it returns the old trainer for you, please switch to use new trainers correctly!")
                logger.warn("'SomeTrainer(config, ...).train()' should be equivalent to "
                            "'launch_train_with_config(config, SomeTrainer(...))' in the new API.")
                return old_trainer(*args, **kwargs)
        else:
            return super(Trainer, cls).__new__(cls)


def _get_property(name):
    """
    Delegate property to self.loop
    """
    ret = property(lambda self: getattr(self.loop, name))
    if six.PY3:     # __doc__ is readonly in Py2
        try:
            ret.__doc__ = getattr(TrainLoop, name).__doc__
        except AttributeError:
            pass
    return ret


for name in ['global_step', 'local_step', 'steps_per_epoch',
             'epoch_num', 'starting_epoch', 'max_epoch']:
    setattr(Trainer, name, _get_property(name))


class TowerTrainer(Trainer):
    """
    Base trainers for models that can be built by calling a tower function under a :class:`TowerContext`.

    This is required by some features that replicates the model
    automatically, e.g. creating a predictor.
    """

    tower_func = None
    """
    A :class:`TowerFuncWrapper` instance.
    A callable which takes some input tensors and builds one replicate of the model.
    """

    @call_only_once
    def set_tower_func(self, tower_func):
        """
        Args:
            tower_func (TowerFuncWrapper)
        """
        assert isinstance(tower_func, TowerFuncWrapper), tower_func
        self.tower_func = tower_func

    @property
    def inputs_desc(self):
        """
        Returns:
            list[InputDesc]: metainfo about the inputs to the tower.
        """
        return self.tower_func.inputs_desc

    def get_predictor(self, input_names, output_names, device=0):
        """
        Returns a callable predictor built under ``TowerContext(is_training=False)``.

        Args:
            input_names (list), output_names(list): list of names
            device (int): build the predictor on device '/gpu:{device}' or use -1 for '/cpu:0'.

        Returns:
            an :class:`OnlinePredictor`.
        """
        assert self.tower_func is not None, "Must set tower_func on the trainer to use get_predictor()!"
        tower_name = 'tower-pred-{}'.format(device) if device >= 0 else 'tower-pred-cpu'
        ...

    @property
    def _main_tower_vs_name(self):
        """
        The vs name for the "main" copy of the model,
        to be used to build predictors.
        """
        return ""


@six.add_metaclass(ABCMeta)
class SingleCostTrainer(TowerTrainer):
    """
    Base class for single-cost trainer.

    Single-cost trainer has a :meth:`setup_graph` method which takes
    (inputs_desc, input, get_cost_fn, get_opt_fn), and build the training operations from them.

    To use a SingleCostTrainer object, call `trainer.setup_graph(...); trainer.train(...)`.
    """

    def train(self,
              callbacks, monitors,
              session_creator, session_init,
              steps_per_epoch, starting_epoch, max_epoch):
        """
        Same as :meth:`Trainer.train()`, except that the callbacks this
        trainer needs are automatically added.
        """
        callbacks = callbacks + self._internal_callbacks
        super(SingleCostTrainer, self).train(
            callbacks, monitors,
            session_creator, session_init,
            steps_per_epoch, starting_epoch, max_epoch)

    @call_only_once
    def setup_graph(self, inputs_desc, input, get_cost_fn, get_opt_fn):
        """
        Responsible for building the main training graph for single-cost training.

        Args:
            inputs_desc ([InputDesc]):
            input (InputSource):
            get_cost_fn ([tf.Tensor] -> tf.Tensor): callable, takes some input tenosrs and return a cost tensor.
                Might get called multiple times for data-parallel training or inference.
            get_opt_fn (-> tf.train.Optimizer): callable which returns an
                optimizer. Will only be called once.

        Returns:
            [Callback]: a (possibly empty) list of callbacks needed for training.
            These callbacks will be automatically added when you call `train()`.
            So you can usually ignore the return value.
        """
        get_cost_fn = TowerFuncWrapper(get_cost_fn, inputs_desc)
        get_opt_fn = memoized(get_opt_fn)
        self.set_tower_func(get_cost_fn)

        input_callbacks = self._setup_input(inputs_desc, input)
        train_callbacks = self._setup_graph(input, get_cost_fn, get_opt_fn)
        self._internal_callbacks = input_callbacks + train_callbacks
        return self._internal_callbacks

    @abstractmethod
    def _setup_graph(self, input, get_cost_fn, get_opt_fn):
        pass

    def _setup_input(self, inputs_desc, input):
        assert not input.setup_done()
        return input.setup(inputs_desc)

    def _make_get_grad_fn(self, input, get_cost_fn, get_opt_fn):
        """
        Returns:
            a get_grad_fn for GraphBuilder to use.
        """
        # internal use only
        assert input.setup_done()

        def get_grad_fn():
            ctx = get_current_tower_context()
            cost = get_cost_fn(*input.get_input_tensors())

            varlist = ctx.filter_vars_by_vs_name(tf.trainable_variables())
            opt = get_opt_fn()
            grads = opt.compute_gradients(
                cost, var_list=varlist,
                gate_gradients=False, colocate_gradients_with_ops=True)
            grads = FilterNoneGrad().process(grads)
            return grads

        return get_grad_fn
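As the `Trainer.train()` docstring above spells out, training is just `setup_callbacks` + `initialize` + `main_loop`, and the three stages can also be driven manually. A sketch of that, where `trainer` is assumed to be an already-constructed new-style trainer and the remaining names are stand-ins for objects a `TrainConfig` would normally supply:

    # hypothetical objects, normally taken from a TrainConfig:
    trainer.setup_callbacks(callbacks, monitors)         # after the training graph is built
    trainer.initialize(session_creator, session_init)    # creates self.sess / self.hooked_sess
    trainer.main_loop(steps_per_epoch, starting_epoch=1, max_epoch=100)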
tensorpack/trainv2/interface.py → tensorpack/train/interface.py  (view file @ ba4e3178)

@@ -7,11 +7,11 @@ import tensorflow as tf
 from ..input_source import (
     InputSource, FeedInput, QueueInput, StagingInputWrapper, DummyConstantInput)
-from ..train.config import TrainConfig
+from ..trainv1.config import TrainConfig
 from .base import SingleCostTrainer
 from .trainers import SimpleTrainer, DistributedTrainerReplicated

-__all__ = ['launch_train_with_config', 'TrainConfig', 'apply_default_prefetch']
+__all__ = ['launch_train_with_config', 'apply_default_prefetch']


 def apply_default_prefetch(input_source_or_dataflow, trainer, towers):
 ...
tensorpack/trainv2/trainers.py → tensorpack/train/trainers.py  (view file @ ba4e3178)

@@ -24,6 +24,7 @@ from .base import SingleCostTrainer
 __all__ = ['SimpleTrainer',
            'QueueInputTrainer',
+           'SyncMultiGPUTrainer',
            'SyncMultiGPUTrainerReplicated',
            'SyncMultiGPUTrainerParameterServer',
            'AsyncMultiGPUTrainer',
 ...
@@ -68,6 +69,17 @@ class SyncMultiGPUTrainerParameterServer(SingleCostTrainer):
         return []


+def SyncMultiGPUTrainer(towers):
+    """
+    Return a default multi-GPU trainer, if you don't care about the details.
+    It may not be the most efficient one for your task.
+
+    Args:
+        towers (list[int]): list of GPU ids.
+    """
+    return SyncMultiGPUTrainerParameterServer(towers, ps_device='gpu')
+
+
 class AsyncMultiGPUTrainer(SingleCostTrainer):

     __doc__ = AsyncMultiGPUBuilder.__doc__
 ...
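Like the other new-style trainers, the helper added above is meant to be passed to `launch_train_with_config`. A minimal sketch, assuming an existing `TrainConfig` named `config` and two visible GPUs:

    trainer = SyncMultiGPUTrainer([0, 1])   # same as SyncMultiGPUTrainerParameterServer([0, 1], ps_device='gpu')
    launch_train_with_config(config, trainer)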
tensorpack/trainv2/__init__.py → tensorpack/trainv1/__init__.py  (view file @ ba4e3178)

@@ -19,7 +19,7 @@ def global_import(name):

 _CURR_DIR = os.path.dirname(__file__)
-_SKIP = []
+_SKIP = ['utility']
 for _, module_name, _ in iter_modules(
         [_CURR_DIR]):
     srcpath = os.path.join(_CURR_DIR, module_name + '.py')
 ...
tensorpack/trainv2/base.py → tensorpack/trainv1/base.py  (view file @ ba4e3178, +358 -0)

The renamed file now holds the old v1 trainer implementation that used to live in
tensorpack/train/base.py; the new-API content that used to be here moved to
tensorpack/train/base.py above. The file at this commit (regions collapsed in the
diff view are marked with "..."):

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: base.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

import time
import weakref
import six
from six.moves import range
import tensorflow as tf

from .config import TrainConfig
from ..utils import logger
from ..utils.develop import log_deprecated
from ..callbacks import Callback, Callbacks
from ..callbacks.monitor import Monitors, TrainingMonitor
from ..tfutils import get_global_step_value
from ..tfutils.model_utils import describe_trainable_vars
from ..tfutils.sesscreate import ReuseSessionCreator
from ..tfutils.sessinit import JustCurrentSession
from ..tfutils.tower import TowerFuncWrapper
from ..graph_builder.predictor_factory import SimplePredictBuilder
from ..input_source import PlaceholderInput
from ..predict.base import OnlinePredictor
from ..callbacks.steps import MaintainStepCounter

__all__ = ['Trainer', 'StopTraining']


class StopTraining(BaseException):
    """
    An exception thrown to stop training.
    """
    pass


class TrainLoop(object):
    """
    Manage the double for loop.
    """

    def __init__(self):
        self._epoch_num = 0
        self._global_step = 0
        self._local_step = -1

    def config(self, steps_per_epoch, starting_epoch, max_epoch):
        """
        Configure the loop given the settings.
        """
        self.starting_epoch = starting_epoch
        self.max_epoch = max_epoch
        self.steps_per_epoch = steps_per_epoch
        self._epoch_num = starting_epoch - 1

    def update_global_step(self):
        """
        Update the Python-side global_step from TF.
        This must be called under initialized default session.
        """
        self._global_step = get_global_step_value()

    @property
    def epoch_num(self):
        """
        The number of the currently ongoing epoch.

        An epoch is defined to cover the moment before calling `before_epoch` until after calling `trigger_epoch`.
        i.e., in the `trigger_epoch` of epoch 3, `self.epoch_num` is 3.
        If you need use `self.epoch_num` in your callback, you'll need to know this.
        """
        return self._epoch_num

    @property
    def global_step(self):
        """
        The tensorflow global_step, i.e. how many times ``hooked_sess.run`` has been called.

        Note:
            1. global_step is incremented **after** each ``hooked_sess.run`` returns from TF runtime.
            2. If you make zero or more than one calls to ``hooked_sess.run`` in one
               :meth:`run_step`, local_step and global_step may increment at different speed.
        """
        return self._global_step

    @property
    def local_step(self):
        """
        The number of steps that have finished in the current epoch.
        """
        return self._local_step


class Trainer(object):
    """ Base class for a trainer.

    Attributes:
        config (TrainConfig): the config used in this trainer.
        model (ModelDesc): alias for ``config.model``.
        sess (tf.Session): the current session in use.
        hooked_sess (tf.train.MonitoredSession): the session with hooks.
        monitors (Monitors): the monitors. Other callbacks can use it for logging.
    """

    _API_VERSION = 1

    is_chief = True
    """
    Whether this process is the chief worker in distributed training.
    Only chief worker will run some callbacks.
    """

    def __init__(self, config):
        """
        Args:
            config (TrainConfig): the train config.
        """
        assert isinstance(config, TrainConfig), type(config)
        self._config = config
        self.model = config.model
        if self.model is not None:
            def f(*inputs):
                self.model.build_graph(inputs)
            """
            Only to mimic new trainer interafce on inference.
            """
            self.inputs_desc = self.model.get_inputs_desc()
            self.tower_func = TowerFuncWrapper(f, self.inputs_desc)

        self._callbacks = []
        self._monitors = []
        self.loop = TrainLoop()
        self.loop.config(config.steps_per_epoch, config.starting_epoch, config.max_epoch)

        self._setup()   # subclass will setup the graph and InputSource

    def register_callback(self, cb):
        """
        Register a callback to the trainer.
        It can only be called before :meth:`Trainer.train` gets called.
        """
        ...
        else:
            self._callbacks.append(cb)

    def register_monitor(self, mon):
        """
        Register a monitor to the trainer.
        It can only be called before :meth:`Trainer.train` gets called.
        """
        ...
        if not self.is_chief and mon.chief_only:
            logger.warn("Monitor {} is chief-only, skipped.".format(str(mon)))
        else:
            self._monitors.append(mon)
            self.register_callback(mon)

    @property
    def monitors(self):
        assert isinstance(self._monitors, Monitors), "Monitors haven't been setup!"
        return self._monitors

    def train(self):
        """ Start training """
        self.setup()
        self.main_loop()

    def run_step(self):
        """
        ...
        """
        ...
        self.hooked_sess.run(self.train_op)

    def setup(self):
        """
        Setup the trainer and be ready for the main loop.
        """
        self.register_callback(MaintainStepCounter())
        for cb in self._config.callbacks:
            self.register_callback(cb)
        for m in self._config.monitors:
            self.register_monitor(m)
        self._monitors = Monitors(self._monitors)
        self.register_callback(self._monitors)

        describe_trainable_vars()

        # some final operations that might modify the graph
        logger.info("Setup callbacks graph ...")
        self._callbacks = Callbacks(self._callbacks)
        self._callbacks.setup_graph(weakref.proxy(self))

        self._config.session_init._setup_graph()

        logger.info("Creating the session ...")
        self._create_session()

        if self.is_chief:
            logger.info("Initializing the session ...")
            self._config.session_init._run_init(self.sess)
        else:
            if not isinstance(self._config.session_init, JustCurrentSession):
                logger.warn("This is not a chief worker, 'session_init' was ignored!")
        ...
        self.sess.graph.finalize()
        logger.info("Graph Finalized.")

    def _create_session(self):
        """
        Setup self.sess (the raw tf.Session)
        and self.hooked_sess (the session with hooks and coordinator)
        """
        hooks = self._callbacks.get_hooks()
        self.sess = self._config.session_creator.create_session()
        self.hooked_sess = tf.train.MonitoredSession(
            session_creator=ReuseSessionCreator(self.sess), hooks=hooks)

    def _setup(self):
        """
        Build the entire graph for training.
        Responsible for setup InputSource as well (including registering InputSource callbacks)

        Since this method will get called in constructor only,
        you can simply leave it empty and build your graph outside the trainer.
        """
        pass

    def main_loop(self):
        """
        Run the main training loop.
        """
        with self.sess.as_default():
            self.loop.update_global_step()
            try:
                self._callbacks.before_train()
                # refresh global step (might have changed by callbacks) TODO ugly
                # what if gs is changed later?
                self.loop.update_global_step()
                for self.loop._epoch_num in range(
                        self.loop.starting_epoch, self.loop.max_epoch + 1):
                    ...
            ...
                self._callbacks.after_train()
                self.hooked_sess.close()

    def get_predictor(self, input_names, output_names, tower=0):
        """
        Returns a callable predictor built under ``TowerContext(is_training=False)``.

        Args:
            input_names (list), output_names(list): list of names
            tower (int): build the predictor on device '/gpu:{tower}' or use -1 for '/cpu:0'.

        Returns:
            an :class:`OnlinePredictor`.
        """
        device = tower
        assert self.tower_func is not None, "Must set tower_func on the trainer to use get_predictor()!"
        tower_name = 'tower-pred-{}'.format(device) if device >= 0 else 'tower-pred-cpu'
        ...

    @property
    def _main_tower_vs_name(self):
        # The vs name a predictor should be built under.
        # for internal use only. Should let graphbuilder return it.
        return ""

    @property
    def config(self):
        log_deprecated(
            "Trainer.config",
            "It is supposed to be private! Most of its attributes can be accessed by other means.",
            "2017-12-31")
        return self._config

    # create new trainer when not called with TrainConfig
    def __new__(cls, *args, **kwargs):
        if (len(args) > 0 and isinstance(args[0], TrainConfig)) \
                or 'config' in kwargs:
            return super(Trainer, cls).__new__(cls)
        else:
            import tensorpack.train as new_train
            name = cls.__name__
            new_trainer = getattr(new_train, name)
            logger.warn("You're calling old trainers with new trainer API!")
            logger.warn("Now it returns the new trainer for you, please `export TENSORPACK_TRAIN_API=v2`"
                        " to import new trainers automatically.")
            logger.warn("You can also ignore this warning and wait for new API to become the default.")
            return new_trainer(*args, **kwargs)


def _get_property(name):
    """
    Delegate property to self.loop
    """
    ret = property(lambda self: getattr(self.loop, name))
    if six.PY3:     # __doc__ is readonly in Py2
        try:
            ret.__doc__ = getattr(TrainLoop, name).__doc__
        except AttributeError:
            pass
    return ret


for name in ['global_step', 'local_step', 'steps_per_epoch',
             'epoch_num', 'starting_epoch', 'max_epoch']:
    setattr(Trainer, name, _get_property(name))
tensorpack/train/config.py → tensorpack/trainv1/config.py   (view file @ ba4e3178, file moved)

tensorpack/train/distributed.py → tensorpack/trainv1/distributed.py   (view file @ ba4e3178, file moved)
tensorpack/trainv1/interface.py   (new file, 0 → 100644, view file @ ba4e3178)

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: interface.py

__all__ = ['launch_train_with_config']


def launch_train_with_config(config, trainer):
    from ..train.interface import launch_train_with_config as old_launch
    old_launch(config, trainer)
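This shim keeps `launch_train_with_config` importable even when the v1 package is selected, by forwarding to the new implementation in tensorpack/train/interface.py. A sketch of the situation it covers (assumptions: `TENSORPACK_TRAIN_API` is left unset, so `from tensorpack import *` pulls in `tensorpack.trainv1`, and `config` is an existing `TrainConfig`):

    from tensorpack import *                            # TENSORPACK_TRAIN_API unset -> v1 namespace
    launch_train_with_config(config, SimpleTrainer())   # forwarded to tensorpack.train.interface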
tensorpack/train/multigpu.py → tensorpack/trainv1/multigpu.py   (view file @ ba4e3178, file moved)

tensorpack/train/simple.py → tensorpack/trainv1/simple.py   (view file @ ba4e3178, file moved)
tensorpack/trainv1/utility.py   (new file, 0 → 100644, view file @ ba4e3178)

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: utility.py

# for backwards-compatibility
from ..graph_builder.utils import (    # noqa
    OverrideToLocalVariable,
    override_to_local_variable, LeastLoadedDeviceSetter)