Shashank Suhas / seminar-breakout

Commit be50085f, authored Nov 28, 2017 by Yuxin Wu
(parent: 705732f2)

    Move implementation from trainv1 to train

Showing 9 changed files with 265 additions and 270 deletions:
  docs/tutorial/trainer.md              +11    -2
  docs/tutorial/training-interface.md    +6    -7
  tensorpack/train/base.py              +63   -33
  tensorpack/train/config.py           +176    -0
  tensorpack/train/interface.py          +1    -1
  tensorpack/train/tower.py              +4    -6
  tensorpack/trainv1/base.py             +2   -66
  tensorpack/trainv1/config.py           +1  -151
  tensorpack/trainv1/interface.py        +1    -4
docs/tutorial/trainer.md:

@@ -13,7 +13,7 @@ You'll only need to __select__ what trainer to use.
 ### Tower Trainer

 Following the terminology in TensorFlow,
-a "tower" function is something
+a __tower function__ is a callable
 that takes input tensors and adds __one replicate__ of the model to the graph.
 Most types of neural-network training could fall into this category.
 All non-base trainers in tensorpack is a subclass of
 [TowerTrainer](../modules/train.html#tensorpack.train.TowerTrainer).
@@ -22,6 +22,15 @@ The concept of tower is used mainly to support:
 1. Data-parallel multi-GPU training, where a replicate is built on each GPU.
 2. Automatically building the graph for inference, where a replicate is built under inference mode.
+
+You'll specify a tower function when you use `TowerTrainer`.
+The function needs to follow some conventions:
+
+1. It will always be called under a :class:`TowerContext`.
+   which will contain information about reuse, training/inference, scope name, etc.
+2. It might get called multiple times for data-parallel training or inference.
+3. To respect variable reuse, use `tf.get_variable` instead of `tf.Variable` in the function.
+
 ### MultiGPU Trainers
@@ -41,5 +50,5 @@ Note some common problems when using these trainers:
 Splitting a tensor to GPUs makes no sense at all, only to put unnecessary shape constraints on the data.
 By letting each GPU train on its own input tensors, they can train on inputs of different shapes simultaneously.
-2. Your model code (the tower function) will get called multipile times.
+2. The tower function (your model code) will get called multipile times.
 You'll need to be very careful when modifying global states in those functions, e.g. adding ops to TF collections.
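The conventions added to this tutorial are easiest to see in code. Below is a minimal sketch of a tower function, assuming the TF 1.x API of this era; the one-layer classifier itself is hypothetical and not part of the commit:

import tensorflow as tf

def tower_func(image, label):
    # Runs under a TowerContext and may be called several times: one
    # replicate per training GPU, plus a replicate built in inference mode.
    # tf.layers creates variables via tf.get_variable, so variable reuse
    # across replicates behaves as the context dictates.
    logits = tf.layers.dense(tf.layers.flatten(image), 10, name='fc')
    cost = tf.losses.sparse_softmax_cross_entropy(labels=label, logits=logits)
    return cost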
docs/tutorial/training-interface.md:

@@ -8,7 +8,7 @@ when you don't want to customize too much.
 ### With ModelDesc and TrainConfig

 This is an interface that's most familiar to old tensorpack users,
-and is now mainly useful for single-cost tasks.
+and is used for single-cost tasks only.
 A lot of examples are written in this interface.
 [SingleCost trainers](../modules/train.html#tensorpack.train.SingleCostTrainer)
@@ -35,10 +35,9 @@ class MyModel(ModelDesc):
 You can use any symbolic functions in `_build_graph`, including TensorFlow core library
 functions and other symbolic libraries.
-But you need to follow the requirement of
-[get_cost_fn](../modules/train.html#tensorpack.train.SingleCostTrainer.setup_graph),
-because this function will be used as part of `get_cost_fn`.
-At last you need to set `self.cost`.
+`_build_graph` will be the tower function,
+so you need to follow [some rules](trainer.md#tower-trainer).
+You also need to set `self.cost` in this function.

 After defining such a model, use it with `TrainConfig` and `launch_train_with_config`:
@@ -60,11 +59,11 @@ See the docs of
 [TrainConfig](../modules/train.html#tensorpack.train.TrainConfig)
 and
 [launch_train_with_config](../modules/train.html#tensorpack.train.launch_train_with_config)
-for usage and detailed functionalities.
+for detailed functionalities.

 ### Raw Trainer Interface
-You can also access methods of trainer directly, to get a finer control:
+To get a lower-level control, you can also access methods of trainer directly:

 __Build__ the graph: For general trainer, build the graph by yourself.
 For single-cost trainer, build the graph by
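To make the documented interface concrete, here is a hedged usage sketch of the API this page describes; `MyModel` and the dataflow `df` are placeholders, and the trainer choice is arbitrary:

from tensorpack import TrainConfig, SimpleTrainer, launch_train_with_config

config = TrainConfig(
    model=MyModel(),        # a ModelDesc whose _build_graph sets self.cost
    dataflow=df,            # any DataFlow yielding the model's declared inputs
    steps_per_epoch=100,    # required if df.size() is not implemented
    max_epoch=10,
)
launch_train_with_config(config, SimpleTrainer())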
tensorpack/train/base.py:

@@ -9,55 +9,84 @@ from six.moves import range
 import six

 from ..callbacks import (
-    Callback, Callbacks, Monitors, TrainingMonitor,
-    MovingAverageSummary, ProgressBar, MergeAllSummaries,
-    TFEventWriter, JSONWriter, ScalarPrinter, RunUpdateOps)
+    Callback, Callbacks, Monitors, TrainingMonitor)
 from ..utils import logger
 from ..utils.argtools import call_only_once
+from ..tfutils import get_global_step_value
 from ..tfutils.tower import TowerFuncWrapper
 from ..tfutils.model_utils import describe_trainable_vars
 from ..tfutils.sessinit import JustCurrentSession
 from ..tfutils.sesscreate import ReuseSessionCreator, NewSessionCreator
 from ..callbacks.steps import MaintainStepCounter

-import tensorpack.trainv1 as old_train   # noqa
-from ..trainv1.base import StopTraining, TrainLoop
-from ..trainv1.config import TrainConfig
+from .config import TrainConfig, DEFAULT_MONITORS, DEFAULT_CALLBACKS

-__all__ = ['StopTraining', 'TrainConfig', 'Trainer']
+__all__ = ['StopTraining', 'TrainConfig', 'Trainer',
+           'DEFAULT_MONITORS', 'DEFAULT_CALLBACKS']


-def DEFAULT_CALLBACKS():
-    """
-    Return the default callbacks,
-    which will be used in :class:`TrainConfig` and :meth:`Trainer.train_with_defaults`.
-    They are:
-    1. MovingAverageSummary()
-    2. ProgressBar()
-    3. MergeAllSummaries()
-    4. RunUpdateOps()
-    """
-    return [
-        MovingAverageSummary(),
-        ProgressBar(),
-        MergeAllSummaries(),
-        RunUpdateOps()]


-def DEFAULT_MONITORS():
-    """
-    Return the default monitors,
-    which will be used in :class:`TrainConfig` and :meth:`Trainer.train_with_defaults`.
-    They are:
-    1. TFEventWriter()
-    2. JSONWriter()
-    3. ScalarPrinter()
-    """
-    return [TFEventWriter(), JSONWriter(), ScalarPrinter()]


+class StopTraining(BaseException):
+    """
+    An exception thrown to stop training.
+    """
+    pass


+class TrainLoop(object):
+    """
+    Manage the double for loop.
+    """
+
+    def __init__(self):
+        self._epoch_num = 0
+        self._global_step = 0
+        self._local_step = -1
+
+    def config(self, steps_per_epoch, starting_epoch, max_epoch):
+        """
+        Configure the loop given the settings.
+        """
+        self.starting_epoch = starting_epoch
+        self.max_epoch = max_epoch
+        self.steps_per_epoch = steps_per_epoch
+        self._epoch_num = starting_epoch - 1
+
+    def update_global_step(self):
+        """
+        Update the Python-side global_step from TF.
+        This must be called under initialized default session.
+        """
+        self._global_step = get_global_step_value()
+
+    @property
+    def epoch_num(self):
+        """
+        The number of the currently ongoing epoch.
+        An epoch is defined to cover the moment before calling `before_epoch` until after calling `trigger_epoch`.
+        i.e., in the `trigger_epoch` of epoch 3, `self.epoch_num` is 3.
+        If you need use `self.epoch_num` in your callback, you'll need to know this.
+        """
+        return self._epoch_num
+
+    @property
+    def global_step(self):
+        """
+        The tensorflow global_step, i.e. how many times ``hooked_sess.run`` has been called.
+        Note:
+        1. global_step is incremented **after** each ``hooked_sess.run`` returns from TF runtime.
+        2. If you make zero or more than one calls to ``hooked_sess.run`` in one
+           :meth:`run_step`, local_step and global_step may increment at different speed.
+        """
+        return self._global_step
+
+    @property
+    def local_step(self):
+        """
+        The number of steps that have finished in the current epoch.
+        """
+        return self._local_step


 class Trainer(object):
@@ -277,7 +306,8 @@ class Trainer(object):
                 or 'config' in kwargs:
             name = cls.__name__
             try:
-                old_trainer = getattr(old_train, name)
+                import tensorpack.trainv1 as old_train_mod   # noqa
+                old_trainer = getattr(old_train_mod, name)
             except AttributeError:
                 # custom trainer. has to live with it
                 return super(Trainer, cls).__new__(cls)
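For context, a rough sketch of how a trainer's double loop is expected to drive the `TrainLoop` counters moved into this file. This mirrors, but is not, the actual `Trainer.main_loop`; the private-attribute updates shown are what the real trainer maintains internally:

loop = TrainLoop()
loop.config(steps_per_epoch=100, starting_epoch=1, max_epoch=5)

for epoch in range(loop.starting_epoch, loop.max_epoch + 1):
    loop._epoch_num = epoch                  # maintained by the real Trainer
    for step in range(loop.steps_per_epoch):
        loop._local_step = step
        pass  # run_step() would call hooked_sess.run(...) here
    # re-sync the Python-side counter from TF; requires an initialized
    # default session with a global_step variable
    loop.update_global_step()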
tensorpack/train/config.py (new file, mode 100644, +176 lines):

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: config.py

from ..callbacks import (
    MovingAverageSummary, ProgressBar, MergeAllSummaries,
    TFEventWriter, JSONWriter, ScalarPrinter, RunUpdateOps)
from ..dataflow.base import DataFlow
from ..graph_builder.model_desc import ModelDescBase
from ..utils import logger
from ..tfutils import (JustCurrentSession, SessionInit)
from ..tfutils.sesscreate import NewSessionCreator
from ..input_source import InputSource
from ..utils.develop import log_deprecated

__all__ = ['TrainConfig', 'DEFAULT_CALLBACKS', 'DEFAULT_MONITORS']


def DEFAULT_CALLBACKS():
    """
    Return the default callbacks,
    which will be used in :class:`TrainConfig` and :meth:`Trainer.train_with_defaults`.
    They are:
    1. MovingAverageSummary()
    2. ProgressBar()
    3. MergeAllSummaries()
    4. RunUpdateOps()
    """
    return [
        MovingAverageSummary(),
        ProgressBar(),
        MergeAllSummaries(),
        RunUpdateOps()]


def DEFAULT_MONITORS():
    """
    Return the default monitors,
    which will be used in :class:`TrainConfig` and :meth:`Trainer.train_with_defaults`.
    They are:
    1. TFEventWriter()
    2. JSONWriter()
    3. ScalarPrinter()
    """
    return [TFEventWriter(), JSONWriter(), ScalarPrinter()]


class TrainConfig(object):
    """
    A collection of options to be used for trainers.
    """

    def __init__(self,
                 dataflow=None, data=None,
                 model=None,
                 callbacks=None, extra_callbacks=None, monitors=None,
                 session_creator=None, session_config=None, session_init=None,
                 starting_epoch=1, steps_per_epoch=None, max_epoch=99999,
                 nr_tower=1, tower=None,
                 **kwargs):
        """
        Args:
            dataflow (DataFlow):
            data (InputSource):
            model (ModelDescBase):
            callbacks (list): a list of :class:`Callback` to perform during training.
            extra_callbacks (list): the same as ``callbacks``. This argument
                is only used to provide the defaults in addition to ``callbacks``. The defaults are
                ``MovingAverageSummary()``, ``ProgressBar()``,
                ``MergeAllSummaries()``, ``RunUpdateOps()``. The list of
                callbacks that will be used in the end is ``callbacks + extra_callbacks``.
            monitors (list): a list of :class:`TrainingMonitor`.
                Defaults to ``TFEventWriter()``, ``JSONWriter()``, ``ScalarPrinter()``.
            session_creator (tf.train.SessionCreator): Defaults to :class:`sesscreate.NewSessionCreator()`
                with the config returned by :func:`tfutils.get_default_sess_config()`.
            session_config (tf.ConfigProto): when session_creator is None, use this to create the session.
            session_init (SessionInit): how to initialize variables of a session. Defaults to do nothing.
            starting_epoch (int): The index of the first epoch.
            steps_per_epoch (int): the number of steps (defined by :meth:`Trainer.run_step`) to run in each epoch.
                Defaults to the input data size.
            max_epoch (int): maximum number of epoch to run training.
            nr_tower (int): number of training towers, used by multigpu trainers.
            tower ([int]): list of training towers in relative GPU id.
        """

        # TODO type checker decorator
        def assert_type(v, tp):
            assert isinstance(v, tp), v.__class__

        # process data & model
        assert data is None or dataflow is None, "dataflow and data cannot be both presented in TrainConfig!"
        if dataflow is not None:
            assert_type(dataflow, DataFlow)
        if data is not None:
            assert_type(data, InputSource)
        self.dataflow = dataflow
        self.data = data

        if model is not None:
            assert_type(model, ModelDescBase)
        self.model = model

        if callbacks is None:
            callbacks = []
        assert_type(callbacks, list)
        self._callbacks = callbacks + \
            (extra_callbacks or DEFAULT_CALLBACKS())
        self.monitors = monitors or DEFAULT_MONITORS()

        if session_init is None:
            session_init = JustCurrentSession()
        self.session_init = session_init
        assert_type(self.session_init, SessionInit)

        if session_creator is None:
            if session_config is not None:
                self.session_creator = NewSessionCreator(config=session_config)
            else:
                self.session_creator = NewSessionCreator(config=None)
        else:
            self.session_creator = session_creator
            assert session_config is None, "Cannot set both session_creator and session_config!"

        if steps_per_epoch is None:
            try:
                if dataflow is not None:
                    steps_per_epoch = dataflow.size()
                elif data is not None:
                    steps_per_epoch = data.size()
                else:
                    raise NotImplementedError()
            except NotImplementedError:
                logger.error("You must set `TrainConfig(steps_per_epoch)` if data.size() is not available.")
                raise
        else:
            steps_per_epoch = int(steps_per_epoch)
        self.steps_per_epoch = steps_per_epoch

        self.starting_epoch = int(starting_epoch)
        self.max_epoch = int(max_epoch)
        assert self.steps_per_epoch > 0 and self.max_epoch > 0

        nr_tower = max(nr_tower, 1)
        self.nr_tower = nr_tower
        if tower is not None:
            assert self.nr_tower == 1, "Cannot set both nr_tower and tower in TrainConfig!"
            self.tower = tower

        predict_tower = kwargs.pop('predict_tower', None)
        if predict_tower is not None:
            log_deprecated("TrainConfig(predict_tower=)",
                           "InferenceRunner now accepts a 'device' argument.", "2017-12-31")
        self.predict_tower = predict_tower
        if isinstance(self.predict_tower, int):
            self.predict_tower = [self.predict_tower]

        assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))

    @property
    def nr_tower(self):
        return len(self.tower)

    @nr_tower.setter
    def nr_tower(self, value):
        self.tower = list(range(value))

    @property
    def callbacks(self):        # disable setter
        return self._callbacks
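One subtlety in the class above: `nr_tower` is a property derived from `tower`, and its setter rewrites `tower`. A small hedged sketch of that interplay; `MyModel` and `df` are placeholders:

config = TrainConfig(model=MyModel(), dataflow=df,
                     steps_per_epoch=100, nr_tower=2)
print(config.tower)      # [0, 1] -- the nr_tower setter expands to a GPU list
print(config.nr_tower)   # 2     -- the getter returns len(self.tower)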
tensorpack/train/interface.py:

@@ -7,7 +7,7 @@ import tensorflow as tf
 from ..input_source import (
     InputSource, FeedInput, QueueInput, StagingInput, DummyConstantInput)
-from ..trainv1.config import TrainConfig
+from .config import TrainConfig
 from .tower import SingleCostTrainer
 from .trainers import SimpleTrainer, DistributedTrainerReplicated
tensorpack/train/tower.py:

@@ -121,12 +121,10 @@ class SingleCostTrainer(TowerTrainer):
             optimizer. Will only be called once.

         Note:
-            1. `get_cost_fn` will always be called under a :class:`TowerContext`.
-               which will contain information about reuse,
-               training/inference, scope name, etc.
-            2. `get_cost_fn` might get called multiple times for data-parallel training or inference.
-            3. To respect variable reuse, use `tf.get_variable` instead of
-               `tf.Variable` in `get_cost_fn`.
+            `get_cost_fn` will be the tower function.
+            It must follows the
+            `rules of tower function.
+            <http://tensorpack.readthedocs.io/en/latest/tutorial/trainer.html#tower-trainer>`_.
         """
         get_cost_fn = TowerFuncWrapper(get_cost_fn, inputs_desc)
         get_opt_fn = memoized(get_opt_fn)
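The docstring change above points single-cost users at the tower-function rules. As a hedged sketch of how `get_cost_fn` reaches `setup_graph` in this era of the API (`inputs_desc`, the dataflow `df`, and the exact keyword arguments of `train_with_defaults` are assumptions, not taken from this commit):

import tensorflow as tf
from tensorpack import QueueInput, SimpleTrainer

def get_cost_fn(image, label):
    # the tower function: must obey the rules linked above
    logits = tf.layers.dense(tf.layers.flatten(image), 10)
    return tf.losses.sparse_softmax_cross_entropy(labels=label, logits=logits)

def get_opt_fn():
    return tf.train.AdamOptimizer(1e-3)

trainer = SimpleTrainer()
trainer.setup_graph(inputs_desc, QueueInput(df), get_cost_fn, get_opt_fn)
trainer.train_with_defaults(steps_per_epoch=100, max_epoch=10)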
tensorpack/trainv1/base.py:

@@ -14,7 +14,6 @@ from ..utils import logger
 from ..utils.develop import log_deprecated
 from ..callbacks import Callback, Callbacks
 from ..callbacks.monitor import Monitors, TrainingMonitor
-from ..tfutils import get_global_step_value
 from ..tfutils.model_utils import describe_trainable_vars
 from ..tfutils.sesscreate import ReuseSessionCreator
 from ..tfutils.sessinit import JustCurrentSession
@@ -25,72 +24,9 @@ from ..graph_builder.predict import SimplePredictBuilder
 from ..predict.base import OnlinePredictor
 from ..callbacks.steps import MaintainStepCounter

-__all__ = ['Trainer', 'StopTraining']
+from ..train.base import StopTraining, TrainLoop

-class StopTraining(BaseException):
-    """
-    An exception thrown to stop training.
-    """
-    pass


-class TrainLoop(object):
-    """
-    Manage the double for loop.
-    """
-    ...  (the remainder of TrainLoop -- __init__, config, update_global_step,
-          and the epoch_num, global_step, local_step properties -- is removed
-          here; it is identical to the class added in tensorpack/train/base.py
-          above)

+__all__ = ['Trainer', 'StopTraining']

 class Trainer(object):
tensorpack/trainv1/config.py:

@@ -2,156 +2,6 @@
 # File: config.py
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>

-from ..callbacks import (
-    MovingAverageSummary, ProgressBar, MergeAllSummaries,
-    TFEventWriter, JSONWriter, ScalarPrinter, RunUpdateOps)
-from ..dataflow.base import DataFlow
-from ..graph_builder.model_desc import ModelDescBase
-from ..utils import logger
-from ..tfutils import (JustCurrentSession, SessionInit)
-from ..tfutils.sesscreate import NewSessionCreator
-from ..input_source import InputSource
-from ..utils.develop import log_deprecated

 __all__ = ['TrainConfig']

+from ..train.config import TrainConfig

-def DEFAULT_CALLBACKS():
-    return [
-        MovingAverageSummary(),
-        ProgressBar(),
-        MergeAllSummaries(),
-        RunUpdateOps()]


-def DEFAULT_MONITORS():
-    return [TFEventWriter(), JSONWriter(), ScalarPrinter()]


-class TrainConfig(object):
-    """
-    A collection of options to be used for trainers.
-    """
-    ...  (the remainder of the class -- __init__ and the nr_tower and callbacks
-          properties -- is removed here; it is identical to the implementation
-          in the new tensorpack/train/config.py shown above)
tensorpack/trainv1/interface.py:

@@ -5,7 +5,4 @@
 __all__ = ['launch_train_with_config']

-def launch_train_with_config(config, trainer):
-    from ..train.interface import launch_train_with_config as old_launch
-    old_launch(config, trainer)
+from ..train.interface import launch_train_with_config
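After this commit the trainv1 modules are thin re-export shims over the new train package, so old import paths keep resolving to the same objects. A quick sanity check one could run (a hedged sketch, not part of the commit):

from tensorpack.train.config import TrainConfig as new_cfg
from tensorpack.trainv1.config import TrainConfig as old_cfg
from tensorpack.train.interface import launch_train_with_config as new_launch
from tensorpack.trainv1.interface import launch_train_with_config as old_launch

assert new_cfg is old_cfg         # same class object, re-exported
assert new_launch is old_launch   # same function object, re-exported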