Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
17126868
Commit
17126868
authored
May 06, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
remove some deprecation in trainers
parent
ca0f0bd0
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
22 additions
and
50 deletions
+22
-50
tensorpack/train/feedfree.py
tensorpack/train/feedfree.py
+6
-12
tensorpack/train/input_data.py
tensorpack/train/input_data.py
+1
-1
tensorpack/train/multigpu.py
tensorpack/train/multigpu.py
+15
-37
No files found.
tensorpack/train/feedfree.py
View file @
17126868
...
@@ -5,7 +5,6 @@
...
@@ -5,7 +5,6 @@
import
tensorflow
as
tf
import
tensorflow
as
tf
from
..utils.develop
import
log_deprecated
from
..tfutils.tower
import
TowerContext
,
get_current_tower_context
from
..tfutils.tower
import
TowerContext
,
get_current_tower_context
from
.input_data
import
QueueInput
,
FeedfreeInput
from
.input_data
import
QueueInput
,
FeedfreeInput
...
@@ -21,7 +20,7 @@ class FeedfreeTrainerBase(Trainer):
...
@@ -21,7 +20,7 @@ class FeedfreeTrainerBase(Trainer):
"""
"""
def
build_train_tower
(
self
):
def
build_train_tower
(
self
):
"""
"""
Get input tensors from `self.input_method` and build the graph.
Get input tensors from `self.input_method` and build the
forward
graph.
"""
"""
def
f
():
def
f
():
self
.
_input_tensors
=
self
.
_input_method
.
get_input_tensors
()
self
.
_input_tensors
=
self
.
_input_method
.
get_input_tensors
()
...
@@ -64,7 +63,7 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
...
@@ -64,7 +63,7 @@ class SingleCostFeedfreeTrainer(FeedfreeTrainerBase):
def
_get_cost_and_grad
(
self
):
def
_get_cost_and_grad
(
self
):
""" get the cost and gradient"""
""" get the cost and gradient"""
self
.
build_train_tower
()
self
.
build_train_tower
()
cost
=
self
.
model
.
get_cost
()
cost
=
self
.
model
.
get_cost
()
# assume single cost
opt
=
self
.
config
.
optimizer
opt
=
self
.
config
.
optimizer
# GATE_NONE faster?
# GATE_NONE faster?
grads
=
opt
.
compute_gradients
(
grads
=
opt
.
compute_gradients
(
...
@@ -90,7 +89,8 @@ class SimpleFeedfreeTrainer(SingleCostFeedfreeTrainer):
...
@@ -90,7 +89,8 @@ class SimpleFeedfreeTrainer(SingleCostFeedfreeTrainer):
assert
isinstance
(
self
.
_input_method
,
FeedfreeInput
),
self
.
_input_method
assert
isinstance
(
self
.
_input_method
,
FeedfreeInput
),
self
.
_input_method
super
(
SimpleFeedfreeTrainer
,
self
)
.
__init__
(
config
)
super
(
SimpleFeedfreeTrainer
,
self
)
.
__init__
(
config
)
assert
len
(
self
.
config
.
tower
)
==
1
,
\
assert
len
(
self
.
config
.
tower
)
==
1
,
\
"SimpleFeedfreeTrainer doesn't support multigpu!"
"Got nr_tower={}, but doesn't support multigpu!"
\
" Use Sync/AsyncMultiGPUTrainer instead."
.
format
(
len
(
self
.
config
.
tower
))
def
_setup
(
self
):
def
_setup
(
self
):
super
(
SimpleFeedfreeTrainer
,
self
)
.
_setup
()
super
(
SimpleFeedfreeTrainer
,
self
)
.
_setup
()
...
@@ -101,7 +101,7 @@ class SimpleFeedfreeTrainer(SingleCostFeedfreeTrainer):
...
@@ -101,7 +101,7 @@ class SimpleFeedfreeTrainer(SingleCostFeedfreeTrainer):
# self.train_op = tf.group(*self._input_tensors)
# self.train_op = tf.group(*self._input_tensors)
def
QueueInputTrainer
(
config
,
input_queue
=
None
,
predict_tower
=
None
):
def
QueueInputTrainer
(
config
,
input_queue
=
None
):
"""
"""
A wrapper trainer which automatically wraps ``config.dataflow`` by a
A wrapper trainer which automatically wraps ``config.dataflow`` by a
:class:`QueueInput`.
:class:`QueueInput`.
...
@@ -117,14 +117,8 @@ def QueueInputTrainer(config, input_queue=None, predict_tower=None):
...
@@ -117,14 +117,8 @@ def QueueInputTrainer(config, input_queue=None, predict_tower=None):
else
:
else
:
assert
isinstance
(
config
.
data
,
QueueInput
),
config
.
data
assert
isinstance
(
config
.
data
,
QueueInput
),
config
.
data
# debug
# from tensorpack.train.input_data import StagingInputWrapper, DummyConstantInput
# from tensorpack.train.input_data import StagingInputWrapper, DummyConstantInput
# config.data = StagingInputWrapper(config.data, ['/gpu:0'])
# config.data = StagingInputWrapper(config.data, ['/gpu:0'])
# config.data = DummyConstantInput([[128,224,224,3], [128]])
# config.data = DummyConstantInput([[128,224,224,3], [128]])
if
predict_tower
is
not
None
:
log_deprecated
(
"Argument `predict_tower` in trainer"
,
"Use TrainConfig(predict_tower=...) instead!"
)
config
.
predict_tower
=
predict_tower
assert
len
(
config
.
tower
)
==
1
,
\
"Got nr_tower={}, but QueueInputTrainer doesn't support multigpu!"
\
" Use Sync/AsyncMultiGPUTrainer instead."
.
format
(
len
(
config
.
tower
))
return
SimpleFeedfreeTrainer
(
config
)
return
SimpleFeedfreeTrainer
(
config
)
tensorpack/train/input_data.py
View file @
17126868
...
@@ -293,7 +293,7 @@ class TensorInput(FeedfreeInput):
...
@@ -293,7 +293,7 @@ class TensorInput(FeedfreeInput):
"""
"""
Args:
Args:
get_tensor_fn: a function which returns a list of input tensors
get_tensor_fn: a function which returns a list of input tensors
when called.
when called.
It will be called under a TowerContext.
size(int): size of this input. Use None to leave it undefined.
size(int): size of this input. Use None to leave it undefined.
"""
"""
self
.
get_tensor_fn
=
get_tensor_fn
self
.
get_tensor_fn
=
get_tensor_fn
...
...
tensorpack/train/multigpu.py
View file @
17126868
...
@@ -63,20 +63,15 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
...
@@ -63,20 +63,15 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
from each tower and averages them.
from each tower and averages them.
"""
"""
def
__init__
(
self
,
config
,
input_queue
=
None
,
def
__init__
(
self
,
config
):
average_cost
=
False
):
"""
"""
Args:
Args:
config, input_queue: same as in :class:`QueueInputTrainer`.
config: same as in :class:`QueueInputTrainer`.
average_cost (bool): average the cost (instead of gradients) from
each tower and did backprop only once. This option should make no
difference mathematically, but may affect speed.
"""
"""
if
config
.
dataflow
is
not
None
:
if
config
.
dataflow
is
not
None
:
# use queueinput by default. May need to avoid this in the future (when more input type is available)
# use queueinput by default. May need to avoid this in the future (when more input type is available)
self
.
_input_method
=
QueueInput
(
config
.
dataflow
,
input_queue
)
self
.
_input_method
=
QueueInput
(
config
.
dataflow
)
else
:
else
:
assert
input_queue
is
None
,
input_queue
self
.
_input_method
=
config
.
data
self
.
_input_method
=
config
.
data
assert
len
(
config
.
tower
)
>=
1
,
"MultiGPUTrainer must be used with at least one tower."
assert
len
(
config
.
tower
)
>=
1
,
"MultiGPUTrainer must be used with at least one tower."
...
@@ -89,7 +84,6 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
...
@@ -89,7 +84,6 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
self
.
_input_method
=
StagingInputWrapper
(
self
.
_input_method
,
devices
)
self
.
_input_method
=
StagingInputWrapper
(
self
.
_input_method
,
devices
)
super
(
SyncMultiGPUTrainer
,
self
)
.
__init__
(
config
)
super
(
SyncMultiGPUTrainer
,
self
)
.
__init__
(
config
)
self
.
average_cost
=
average_cost
@
staticmethod
@
staticmethod
def
_average_grads
(
tower_grads
):
def
_average_grads
(
tower_grads
):
...
@@ -117,7 +111,7 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
...
@@ -117,7 +111,7 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
def
_setup
(
self
):
def
_setup
(
self
):
super
(
SyncMultiGPUTrainer
,
self
)
.
_setup
()
super
(
SyncMultiGPUTrainer
,
self
)
.
_setup
()
if
not
self
.
average_cost
:
grad_list
=
MultiGPUTrainer
.
multi_tower_grads
(
grad_list
=
MultiGPUTrainer
.
multi_tower_grads
(
self
.
config
.
tower
,
lambda
:
self
.
_get_cost_and_grad
()[
1
])
self
.
config
.
tower
,
lambda
:
self
.
_get_cost_and_grad
()[
1
])
...
@@ -128,21 +122,7 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
...
@@ -128,21 +122,7 @@ class SyncMultiGPUTrainer(MultiGPUTrainer,
grads
=
SyncMultiGPUTrainer
.
_average_grads
(
grad_list
)
grads
=
SyncMultiGPUTrainer
.
_average_grads
(
grad_list
)
# grads = grad_list[0]
# grads = grad_list[0]
else
:
def
get_cost
():
self
.
build_train_tower
()
return
self
.
model
.
get_cost
()
cost_list
=
MultiGPUTrainer
.
multi_tower_costs
(
self
.
config
.
tower
,
get_cost
)
cost
=
tf
.
multiply
(
tf
.
add_n
(
cost_list
),
1.0
/
len
(
cost_list
),
name
=
'averaged_cost'
)
opt
=
self
.
config
.
optimizer
grads
=
opt
.
compute_gradients
(
cost
,
gate_gradients
=
tf
.
train
.
Optimizer
.
GATE_NONE
,
colocate_gradients_with_ops
=
True
)
self
.
train_op
=
self
.
config
.
optimizer
.
apply_gradients
(
grads
,
name
=
'min_op'
)
self
.
train_op
=
self
.
config
.
optimizer
.
apply_gradients
(
grads
,
name
=
'min_op'
)
...
@@ -154,19 +134,17 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
...
@@ -154,19 +134,17 @@ class AsyncMultiGPUTrainer(MultiGPUTrainer,
"""
"""
def
__init__
(
self
,
config
,
def
__init__
(
self
,
config
,
input_queue
=
None
,
scale_gradient
=
True
):
scale_gradient
=
True
):
"""
"""
Args:
Args:
config
, input_queue
: same as in :class:`QueueInputTrainer`.
config: same as in :class:`QueueInputTrainer`.
scale_gradient (bool): if True, will scale each gradient by
scale_gradient (bool): if True, will scale each gradient by
``1.0/nr_tower``, to make Async and Sync Trainer have the same
``1.0/nr_tower``, to make Async and Sync Trainer have the same
effective learning rate.
effective learning rate.
"""
"""
if
config
.
dataflow
is
not
None
:
if
config
.
dataflow
is
not
None
:
self
.
_input_method
=
QueueInput
(
config
.
dataflow
,
input_queue
)
self
.
_input_method
=
QueueInput
(
config
.
dataflow
)
else
:
else
:
assert
input_queue
is
None
,
input_queue
self
.
_input_method
=
config
.
data
self
.
_input_method
=
config
.
data
super
(
AsyncMultiGPUTrainer
,
self
)
.
__init__
(
config
)
super
(
AsyncMultiGPUTrainer
,
self
)
.
__init__
(
config
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment