Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
d50341b8
Commit
d50341b8
authored
Jan 05, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
refactor around trainer and add some docs
parent
651a5aea
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
73 additions
and
26 deletions
+73
-26
tensorpack/train/base.py
tensorpack/train/base.py
+3
-0
tensorpack/train/config.py
tensorpack/train/config.py
+1
-1
tensorpack/train/feedfree.py
tensorpack/train/feedfree.py
+8
-7
tensorpack/train/input_data.py
tensorpack/train/input_data.py
+27
-8
tensorpack/train/multigpu.py
tensorpack/train/multigpu.py
+25
-6
tensorpack/train/trainer.py
tensorpack/train/trainer.py
+9
-4
No files found.
tensorpack/train/base.py
View file @
d50341b8
...
...
@@ -77,6 +77,9 @@ class Trainer(object):
Can be overridden by subclasses to exploit more
parallelism among predictors.
"""
if
len
(
self
.
config
.
predict_tower
)
>
1
:
logger
.
warn
(
"[Speed] Have set multiple predict_tower, but only have naive `get_predict_funcs` implementation"
)
return
[
self
.
get_predict_func
(
input_names
,
output_names
)
for
k
in
range
(
n
)]
def
trigger_epoch
(
self
):
...
...
tensorpack/train/config.py
View file @
d50341b8
...
...
@@ -43,7 +43,7 @@ class TrainConfig(object):
max_epoch (int): maximum number of epoch to run training.
nr_tower (int): number of training towers.
tower (list of int): list of training towers in relative id.
predict_tower (list of int): list of prediction towers in their relative gpu id.
predict_tower (list of int): list of prediction towers in their relative gpu id.
Use -1 for cpu.
"""
# TODO type checker decorator
...
...
tensorpack/train/feedfree.py
View file @
d50341b8
...
...
@@ -92,7 +92,7 @@ class SimpleFeedfreeTrainer(
self
.
_input_method
=
config
.
data
assert
isinstance
(
self
.
_input_method
,
FeedfreeInput
),
self
.
_input_method
super
(
SimpleFeedfreeTrainer
,
self
)
.
__init__
(
config
)
self
.
_setup_predictor_factory
(
config
.
predict_tower
)
self
.
_setup_predictor_factory
()
assert
len
(
self
.
config
.
tower
)
==
1
,
\
"SimpleFeedfreeTrainer doesn't support multigpu!"
...
...
@@ -111,22 +111,23 @@ class SimpleFeedfreeTrainer(
class
QueueInputTrainer
(
SimpleFeedfreeTrainer
):
"""
A trainer which automatically wraps ``config.dataflow``
A trainer which automatically wraps ``config.dataflow`` by a
:class:`QueueInput`.
"""
def
__init__
(
self
,
config
,
input_queue
=
None
,
predict_tower
=
None
):
"""
Single tower Trainer, takes input from a queue
:param config: a `TrainConfig` instance. config.dataflow must exist
:param input_queue: a `tf.QueueBase` instance
:param predict_tower: list of gpu relative idx to run prediction. default to be [0].
Use -1 for cpu
.
Args:
config(TrainConfig): a `TrainConfig` instance. config.dataflow must exist.
input_queue(tf.QueueBase): an input queue. Defaults to the
:class:`QueueInput` default
.
"""
config
.
data
=
QueueInput
(
config
.
dataflow
,
input_queue
)
if
predict_tower
is
not
None
:
logger
.
warn
(
"[Deprecated] Argument `predict_tower` is deprecated for trainer. "
"Use TrainConfig
.predict_tower
instead!"
)
"Use TrainConfig
(predict_tower=...)
instead!"
)
config
.
predict_tower
=
predict_tower
assert
len
(
config
.
tower
)
==
1
,
\
"QueueInputTrainer doesn't support multigpu! Use Sync/AsyncMultiGPUTrainer instead."
...
...
tensorpack/train/input_data.py
View file @
d50341b8
...
...
@@ -19,12 +19,17 @@ __all__ = ['InputData', 'QueueInput', 'FeedfreeInput', 'TensorInput',
@
six
.
add_metaclass
(
ABCMeta
)
class
InputData
(
object
):
""" Base class for the abstract InputData. """
pass
class
FeedInput
(
InputData
):
""" Input by iterating over a DataFlow and feed datapoints. """
def
__init__
(
self
,
ds
):
"""
Args:
ds (DataFlow): the input DataFlow.
"""
assert
isinstance
(
ds
,
DataFlow
),
ds
self
.
ds
=
ds
...
...
@@ -44,8 +49,14 @@ class FeedInput(InputData):
class
FeedfreeInput
(
InputData
):
""" Abstract base for input without feed,
e.g. by queue or other operations. """
def
get_input_tensors
(
self
):
"""
Returns:
list: A list of tensors corresponding to the inputs of the model.
"""
return
self
.
_get_input_tensors
()
@
abstractmethod
...
...
@@ -100,12 +111,14 @@ class EnqueueThread(threading.Thread):
class
QueueInput
(
FeedfreeInput
):
""" Input by enqueueing datapoints from a DataFlow to a TF queue, and dequeue
tensors to the graph. """
def
__init__
(
self
,
ds
,
queue
=
None
):
"""
:param ds: a `DataFlow` instance
:param queue: a `tf.QueueBase` instance to be used to buffer datapoints
.
Defaults to a FIFO queue of size 50.
Args:
ds(DataFlow): the input DataFlow
.
queue (tf.QueueBase):
Defaults to a FIFO queue of size 50.
"""
assert
isinstance
(
ds
,
DataFlow
),
ds
self
.
queue
=
queue
...
...
@@ -142,11 +155,10 @@ class QueueInput(FeedfreeInput):
return
ret
class
DummyConstantInput
(
Queu
eInput
):
"""
o
nly for debugging performance issues """
class
DummyConstantInput
(
Feedfre
eInput
):
"""
Input some constant variables. O
nly for debugging performance issues """
def
__init__
(
self
,
ds
,
shapes
):
super
(
DummyConstantInput
,
self
)
.
__init__
(
ds
)
def
__init__
(
self
,
shapes
):
self
.
shapes
=
shapes
logger
.
warn
(
"Using dummy input for debug!"
)
...
...
@@ -163,8 +175,15 @@ class DummyConstantInput(QueueInput):
class
TensorInput
(
FeedfreeInput
):
""" Input from a list of tensors, e.g. a TF data reading pipeline. """
def
__init__
(
self
,
get_tensor_fn
,
size
=
None
):
"""
Args:
get_tensor_fn: a function which returns a list of input tensors
when called.
size(int): size of this input. Use None to leave it undefined.
"""
self
.
get_tensor_fn
=
get_tensor_fn
self
.
_size
=
size
...
...
tensorpack/train/multigpu.py
View file @
d50341b8
...
...
@@ -21,7 +21,7 @@ from .trainer import MultiPredictorTowerTrainer
from
.feedfree
import
SingleCostFeedfreeTrainer
from
.input_data
import
QueueInput
__all__
=
[
'
AsyncMultiGPUTrainer'
,
'S
yncMultiGPUTrainer'
]
__all__
=
[
'
SyncMultiGPUTrainer'
,
'As
yncMultiGPUTrainer'
]
class
MultiGPUTrainer
(
Trainer
):
...
...
@@ -51,8 +51,16 @@ class MultiGPUTrainer(Trainer):
class
SyncMultiGPUTrainer
(
MultiGPUTrainer
,
SingleCostFeedfreeTrainer
,
MultiPredictorTowerTrainer
):
"""
A multi-tower multi-GPU trainer which synchronizes the gradients computed
from each tower and averages them.
"""
def
__init__
(
self
,
config
,
input_queue
=
None
,
predict_tower
=
None
):
"""
Args:
config, input_queue: same as in :class:`QueueInputTrainer`.
"""
if
config
.
dataflow
is
not
None
:
self
.
_input_method
=
QueueInput
(
config
.
dataflow
,
input_queue
)
else
:
...
...
@@ -65,7 +73,7 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
config
.
predict_tower
=
predict_tower
super
(
SyncMultiGPUTrainer
,
self
)
.
__init__
(
config
)
self
.
_setup_predictor_factory
(
config
.
predict_tower
)
self
.
_setup_predictor_factory
()
assert
len
(
config
.
tower
)
>=
1
,
"MultiGPUTrainer must be used with at least one GPU."
assert
tf
.
test
.
is_gpu_available
()
...
...
@@ -117,11 +125,22 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
class
AsyncMultiGPUTrainer
(
MultiGPUTrainer
,
SingleCostFeedfreeTrainer
,
MultiPredictorTowerTrainer
):
"""
A multi-tower multi-GPU trainer where each tower independently and
asynchronously updates the model without locking.
"""
def
__init__
(
self
,
config
,
input_queue
=
None
,
averag
e_gradient
=
True
,
scal
e_gradient
=
True
,
predict_tower
=
None
):
"""
Args:
config, input_queue: same as in :class:`QueueInputTrainer`.
scale_gradient (bool): if True, will scale each gradient by
``1.0/nr_tower``, to make Async and Sync Trainer have the same
effective learning rate.
"""
if
config
.
dataflow
is
not
None
:
self
.
_input_method
=
QueueInput
(
config
.
dataflow
,
input_queue
)
else
:
...
...
@@ -134,8 +153,8 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
"Use TrainConfig.predict_tower instead!"
)
config
.
predict_tower
=
predict_tower
self
.
_setup_predictor_factory
(
config
.
predict_tower
)
self
.
_
average_gradient
=
averag
e_gradient
self
.
_setup_predictor_factory
()
self
.
_
scale_gradient
=
scal
e_gradient
assert
tf
.
test
.
is_gpu_available
()
def
_setup
(
self
):
...
...
@@ -143,7 +162,7 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
grad_list
=
MultiGPUTrainer
.
_multi_tower_grads
(
self
.
config
.
tower
,
lambda
:
self
.
_get_cost_and_grad
()[
1
])
gradprocs
=
self
.
model
.
get_gradient_processor
()
if
self
.
_
averag
e_gradient
and
self
.
config
.
nr_tower
>
1
:
if
self
.
_
scal
e_gradient
and
self
.
config
.
nr_tower
>
1
:
# pretend to average the grads, in order to make async and
# sync have consistent effective learning rate
gradprocs
.
insert
(
0
,
ScaleGradient
((
'.*'
,
1.0
/
self
.
config
.
nr_tower
),
log
=
False
))
...
...
tensorpack/train/trainer.py
View file @
d50341b8
...
...
@@ -54,9 +54,14 @@ class PredictorFactory(object):
class
SimpleTrainer
(
Trainer
):
""" A naive demo trainer """
""" A naive demo trainer which iterates over a DataFlow and feed into the
graph. It's not efficient compared to QueueInputTrainer or others."""
def
__init__
(
self
,
config
):
"""
Args:
config (TrainConfig): the training config.
"""
super
(
SimpleTrainer
,
self
)
.
__init__
(
config
)
self
.
_predictor_factory
=
PredictorFactory
(
self
.
sess
,
self
.
model
,
[
0
])
if
config
.
dataflow
is
None
:
...
...
@@ -66,6 +71,7 @@ class SimpleTrainer(Trainer):
self
.
_input_method
=
FeedInput
(
config
.
dataflow
)
def
run_step
(
self
):
""" Feed data into the graph and run the updates. """
feed
=
self
.
_input_method
.
next_feed
()
self
.
sess
.
run
([
self
.
train_op
],
feed_dict
=
feed
)
# faster since train_op return None
...
...
@@ -99,11 +105,10 @@ class SimpleTrainer(Trainer):
class
MultiPredictorTowerTrainer
(
Trainer
):
""" A trainer with possibly multiple prediction tower """
def
_setup_predictor_factory
(
self
,
predict_tower
):
def
_setup_predictor_factory
(
self
):
# by default, use the first training gpu for prediction
predict_tower
=
predict_tower
or
[
0
]
self
.
_predictor_factory
=
PredictorFactory
(
self
.
sess
,
self
.
model
,
predict_tower
)
self
.
sess
,
self
.
model
,
self
.
config
.
predict_tower
)
def
get_predict_func
(
self
,
input_names
,
output_names
,
tower
=
0
):
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment