Commit d50341b8
authored Jan 05, 2017 by Yuxin Wu

    refactor around trainer and add some docs

parent 651a5aea

Showing 6 changed files with 73 additions and 26 deletions (+73 / -26):

    tensorpack/train/base.py         +3   -0
    tensorpack/train/config.py       +1   -1
    tensorpack/train/feedfree.py     +8   -7
    tensorpack/train/input_data.py   +27  -8
    tensorpack/train/multigpu.py     +25  -6
    tensorpack/train/trainer.py      +9   -4
tensorpack/train/base.py  (+3 / -0)

@@ -77,6 +77,9 @@ class Trainer(object):
             Can be overwritten by subclasses to exploit more
             parallelism among predictors.
         """
+        if len(self.config.predict_tower) > 1:
+            logger.warn(
+                "[Speed] Have set multiple predict_tower, but only have naive `get_predict_funcs` implementation")
         return [self.get_predict_func(input_names, output_names) for k in range(n)]

     def trigger_epoch(self):
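The hunk above only warns; the naive default still builds every requested predictor the same way. As a rough illustration of the parallelism the warning hints at, here is a plain-Python sketch (not tensorpack code; `build_predictor` is a hypothetical placeholder) of how a subclass could spread `n` predictors over the configured towers instead:

# Hypothetical helper: round-robin n predictors over the configured towers.
# `build_predictor(tower)` stands in for whatever actually constructs a
# predictor on a given tower; it is not a tensorpack API.
def spread_predictors(predict_tower, n, build_predictor):
    return [build_predictor(predict_tower[k % len(predict_tower)])
            for k in range(n)]

# With predict_tower=[0, 1] and n=3, towers 0, 1, 0 are used:
print(spread_predictors([0, 1], 3, lambda t: 'predictor@tower%d' % t))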
tensorpack/train/config.py  (+1 / -1)

@@ -43,7 +43,7 @@ class TrainConfig(object):
             max_epoch (int): maximum number of epoch to run training.
             nr_tower (int): number of training towers.
             tower (list of int): list of training towers in relative id.
-            predict_tower (list of int): list of prediction towers in their relative gpu id.
+            predict_tower (list of int): list of prediction towers in their relative gpu id. Use -1 for cpu.
         """
         # TODO type checker decorator
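For clarity, the convention documented here (relative gpu id, -1 for cpu) maps directly onto TensorFlow device strings. A minimal sketch of that mapping, written independently of tensorpack's own device handling:

# Sketch of the "relative gpu id, use -1 for cpu" convention; this mirrors the
# docstring above rather than quoting tensorpack's internal code.
def tower_to_device(tower_id):
    return '/cpu:0' if tower_id == -1 else '/gpu:{}'.format(tower_id)

predict_tower = [0, -1]   # one predictor on the first training GPU, one on CPU
print([tower_to_device(t) for t in predict_tower])   # ['/gpu:0', '/cpu:0']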
tensorpack/train/feedfree.py  (+8 / -7)

@@ -92,7 +92,7 @@ class SimpleFeedfreeTrainer(
         self._input_method = config.data
         assert isinstance(self._input_method, FeedfreeInput), self._input_method
         super(SimpleFeedfreeTrainer, self).__init__(config)
-        self._setup_predictor_factory(config.predict_tower)
+        self._setup_predictor_factory()
         assert len(self.config.tower) == 1, \
             "SimpleFeedfreeTrainer doesn't support multigpu!"

@@ -111,22 +111,23 @@ class SimpleFeedfreeTrainer(
 class QueueInputTrainer(SimpleFeedfreeTrainer):
     """
-    A trainer which automatically wraps ``config.dataflow``
+    A trainer which automatically wraps ``config.dataflow`` by a
+    :class:`QueueInput`.
     """

     def __init__(self, config, input_queue=None, predict_tower=None):
         """
         Single tower Trainer, takes input from a queue

-        :param config: a `TrainConfig` instance. config.dataflow must exist
-        :param input_queue: a `tf.QueueBase` instance
-        :param predict_tower: list of gpu relative idx to run prediction. default to be [0].
-            Use -1 for cpu.
+        Args:
+            config(TrainConfig): a `TrainConfig` instance. config.dataflow must exist.
+            input_queue(tf.QueueBase): an input queue. Defaults to the
+                :class:`QueueInput` default.
         """
         config.data = QueueInput(config.dataflow, input_queue)
         if predict_tower is not None:
             logger.warn("[Deprecated] Argument `predict_tower` is deprecated for trainer. "
-                        "Use TrainConfig.predict_tower instead!")
+                        "Use TrainConfig(predict_tower=...) instead!")
             config.predict_tower = predict_tower
         assert len(config.tower) == 1, \
             "QueueInputTrainer doesn't support multigpu! Use Sync/AsyncMultiGPUTrainer instead."
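The deprecation above means the prediction towers should now be declared on the config rather than passed to the trainer. A hedged migration sketch, assuming the rest of the `TrainConfig` (model, dataflow, optimizer, callbacks) is produced by an existing `make_config()` helper of your own, and that `train()` is the usual trainer entry point:

from tensorpack.train.feedfree import QueueInputTrainer

def launch(make_config):
    # Old style (still works, but now logs the "[Deprecated]" warning above):
    #     QueueInputTrainer(make_config(), predict_tower=[0]).train()
    # New style: declare prediction towers on the TrainConfig instead.
    config = make_config()
    config.predict_tower = [0]      # relative gpu id; -1 would mean CPU
    QueueInputTrainer(config).train()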
tensorpack/train/input_data.py  (+27 / -8)

@@ -19,12 +19,17 @@ __all__ = ['InputData', 'QueueInput', 'FeedfreeInput', 'TensorInput',
 @six.add_metaclass(ABCMeta)
 class InputData(object):
+    """ Base class for the abstract InputData. """
     pass


 class FeedInput(InputData):
+    """ Input by iterating over a DataFlow and feed datapoints. """

     def __init__(self, ds):
+        """
+        Args:
+            ds (DataFlow): the input DataFlow.
+        """
         assert isinstance(ds, DataFlow), ds
         self.ds = ds

@@ -44,8 +49,14 @@ class FeedInput(InputData):
 class FeedfreeInput(InputData):
+    """ Abstract base for input without feed,
+    e.g. by queue or other operations. """

     def get_input_tensors(self):
+        """
+        Returns:
+            list: A list of tensors corresponding to the inputs of the model.
+        """
         return self._get_input_tensors()

     @abstractmethod

@@ -100,12 +111,14 @@ class EnqueueThread(threading.Thread):
 class QueueInput(FeedfreeInput):
+    """ Input by enqueueing datapoints from a DataFlow to a TF queue, and dequeue
+    tensors to the graph. """

     def __init__(self, ds, queue=None):
         """
-        :param ds: a `DataFlow` instance.
-        :param queue: a `tf.QueueBase` instance to be used to buffer datapoints.
-            Defaults to a FIFO queue of size 50.
+        Args:
+            ds(DataFlow): the input DataFlow.
+            queue (tf.QueueBase):
+                Defaults to a FIFO queue of size 50.
         """
         assert isinstance(ds, DataFlow), ds
         self.queue = queue

@@ -142,11 +155,10 @@ class QueueInput(FeedfreeInput):
         return ret


-class DummyConstantInput(QueueInput):
-    """ only for debugging performance issues """
+class DummyConstantInput(FeedfreeInput):
+    """ Input some constant variables. Only for debugging performance issues """

-    def __init__(self, ds, shapes):
-        super(DummyConstantInput, self).__init__(ds)
+    def __init__(self, shapes):
         self.shapes = shapes
         logger.warn("Using dummy input for debug!")

@@ -163,8 +175,15 @@ class DummyConstantInput(QueueInput):
 class TensorInput(FeedfreeInput):
+    """ Input from a list of tensors, e.g. a TF data reading pipeline. """

     def __init__(self, get_tensor_fn, size=None):
+        """
+        Args:
+            get_tensor_fn: a function which returns a list of input tensors
+                when called.
+            size(int): size of this input. Use None to leave it undefined.
+        """
         self.get_tensor_fn = get_tensor_fn
         self._size = size
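Among the newly documented classes, `TensorInput` takes a `get_tensor_fn` that must return a list of input tensors when called. A hedged sketch of one way to use it with a plain TF 1.x reading pipeline; the `slice_input_producer`/`batch` pipeline and the reading of `size` as "batches per epoch" are assumptions of this example, not something the diff prescribes:

import numpy as np
import tensorflow as tf
from tensorpack.train.input_data import TensorInput

images = np.random.rand(100, 28, 28).astype('float32')
labels = np.random.randint(0, 10, size=(100,)).astype('int32')

def get_tensor_fn():
    # Returns a list of input tensors, as required by TensorInput.
    img, lab = tf.train.slice_input_producer([images, labels], shuffle=True)
    return tf.train.batch([img, lab], batch_size=32)

tensor_input = TensorInput(get_tensor_fn, size=100 // 32)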
tensorpack/train/multigpu.py  (+25 / -6)

@@ -21,7 +21,7 @@ from .trainer import MultiPredictorTowerTrainer
 from .feedfree import SingleCostFeedfreeTrainer
 from .input_data import QueueInput

-__all__ = ['AsyncMultiGPUTrainer', 'SyncMultiGPUTrainer']
+__all__ = ['SyncMultiGPUTrainer', 'AsyncMultiGPUTrainer']


 class MultiGPUTrainer(Trainer):

@@ -51,8 +51,16 @@ class MultiGPUTrainer(Trainer):
 class SyncMultiGPUTrainer(MultiGPUTrainer,
                           SingleCostFeedfreeTrainer,
                           MultiPredictorTowerTrainer):
+    """
+    A multi-tower multi-GPU trainer which synchronoizes the gradients computed
+    from each tower and averages them.
+    """

     def __init__(self, config, input_queue=None, predict_tower=None):
+        """
+        Args:
+            config, input_queue: same as in :class:`QueueInputTrainer`.
+        """
         if config.dataflow is not None:
             self._input_method = QueueInput(config.dataflow, input_queue)
         else:

@@ -65,7 +73,7 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
             config.predict_tower = predict_tower
         super(SyncMultiGPUTrainer, self).__init__(config)
-        self._setup_predictor_factory(config.predict_tower)
+        self._setup_predictor_factory()
         assert len(config.tower) >= 1, "MultiGPUTrainer must be used with at least one GPU."
         assert tf.test.is_gpu_available()

@@ -117,11 +125,22 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
 class AsyncMultiGPUTrainer(MultiGPUTrainer,
                            SingleCostFeedfreeTrainer,
                            MultiPredictorTowerTrainer):
+    """
+    A multi-tower multi-GPU trainer where each tower independently
+    asynchronously updates the model without locking.
+    """

     def __init__(self, config,
                  input_queue=None,
-                 average_gradient=True,
+                 scale_gradient=True,
                  predict_tower=None):
+        """
+        Args:
+            config, input_queue: same as in :class:`QueueInputTrainer`.
+            scale_gradient (bool): if True, will scale each gradient by
+                ``1.0/nr_tower``, to make Async and Sync Trainer have the same
+                effective learning rate.
+        """
         if config.dataflow is not None:
             self._input_method = QueueInput(config.dataflow, input_queue)
         else:

@@ -134,8 +153,8 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
                         "Use TrainConfig.predict_tower instead!")
             config.predict_tower = predict_tower
-        self._setup_predictor_factory(config.predict_tower)
-        self._average_gradient = average_gradient
+        self._setup_predictor_factory()
+        self._scale_gradient = scale_gradient
         assert tf.test.is_gpu_available()

     def _setup(self):

@@ -143,7 +162,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
         grad_list = MultiGPUTrainer._multi_tower_grads(
             self.config.tower, lambda: self._get_cost_and_grad()[1])
         gradprocs = self.model.get_gradient_processor()
-        if self._average_gradient and self.config.nr_tower > 1:
+        if self._scale_gradient and self.config.nr_tower > 1:
             # pretend to average the grads, in order to make async and
             # sync have consistent effective learning rate
             gradprocs.insert(0, ScaleGradient(('.*', 1.0 / self.config.nr_tower), log=False))
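The renamed `scale_gradient` flag is explained by the `ScaleGradient(('.*', 1.0 / self.config.nr_tower), log=False)` processor in the last hunk: the async trainer applies each tower's gradient separately, so scaling by `1.0/nr_tower` keeps its effective learning rate in line with the sync trainer's averaged update. A toy calculation in plain Python (not tensorpack code) to make the point concrete:

nr_tower = 4
lr = 0.1
tower_grads = [1.0, 2.0, 3.0, 4.0]   # gradients of one parameter, one per tower

sync_update = lr * sum(tower_grads) / nr_tower                 # averaged once
async_unscaled = sum(lr * g for g in tower_grads)              # applied 4 times
async_scaled = sum(lr * g / nr_tower for g in tower_grads)     # scale_gradient=True

# sync_update and async_scaled match; async_unscaled is nr_tower times larger,
# i.e. an nr_tower-times larger effective learning rate.
print(sync_update, async_unscaled, async_scaled)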
tensorpack/train/trainer.py  (+9 / -4)

@@ -54,9 +54,14 @@ class PredictorFactory(object):
 class SimpleTrainer(Trainer):
-    """ A naive demo trainer """
+    """ A naive demo trainer which iterates over a DataFlow and feed into the
+    graph. It's not efficient compared to QueueInputTrainer or others."""

     def __init__(self, config):
+        """
+        Args:
+            config (TrainConfig): the training config.
+        """
         super(SimpleTrainer, self).__init__(config)
         self._predictor_factory = PredictorFactory(self.sess, self.model, [0])
         if config.dataflow is None:

@@ -66,6 +71,7 @@ class SimpleTrainer(Trainer):
             self._input_method = FeedInput(config.dataflow)

     def run_step(self):
+        """ Feed data into the graph and run the updates. """
         feed = self._input_method.next_feed()
         self.sess.run([self.train_op], feed_dict=feed)    # faster since train_op return None

@@ -99,11 +105,10 @@ class SimpleTrainer(Trainer):
 class MultiPredictorTowerTrainer(Trainer):
     """ A trainer with possibly multiple prediction tower """

-    def _setup_predictor_factory(self, predict_tower):
+    def _setup_predictor_factory(self):
         # by default, use the first training gpu for prediction
-        predict_tower = predict_tower or [0]
         self._predictor_factory = PredictorFactory(
-            self.sess, self.model, predict_tower)
+            self.sess, self.model, self.config.predict_tower)

     def get_predict_func(self, input_names, output_names, tower=0):
         """
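With `_setup_predictor_factory()` now reading `self.config.predict_tower`, per-tower predictors are still obtained through `get_predict_func`. A hedged sketch: `trainer` is any already-constructed trainer that mixes in `MultiPredictorTowerTrainer`, the tensor names are placeholders, and the call convention of the returned predictor is an assumption rather than something shown in this diff:

def make_prob_predictor(trainer):
    # `tower` indexes into config.predict_tower (the factory no longer takes
    # its own tower list after this commit).
    return trainer.get_predict_func(input_names=['input'],
                                    output_names=['prob'],
                                    tower=0)

# pred = make_prob_predictor(trainer)
# pred(batch) is assumed to return the values of the requested output tensors.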