Shashank Suhas / seminar-breakout / Commits / d6a89f0b

Commit d6a89f0b, authored Feb 15, 2017 by Yuxin Wu
parent f60989d3

an optimizer which accumulates gradients. fix #141

Showing 1 changed file with 39 additions and 9 deletions:

tensorpack/tfutils/optimizer.py (+39, -9)
@@ -16,7 +16,8 @@ class ProxyOptimizer(tf.train.Optimizer):
     """
     A transparent proxy which delegates all methods of :class:`tf.train.Optimizer`
     """
-    def __init__(self, opt):
+    def __init__(self, opt, name='ProxyOptimizer'):
+        super(ProxyOptimizer, self).__init__(False, name)
         self._opt = opt

     def compute_gradients(self, *args, **kwargs):
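
The new name argument matters because tf.train.Optimizer.__init__(use_locking, name) was never called before, so subclasses could not use base-class helpers such as _zeros_slot (which AccumGradOptimizer relies on below). The delegating methods presumably just forward to the wrapped sub-optimizer; only the signature of compute_gradients is visible in this diff, so the body in this sketch is an assumption:

    # Sketch (assumption): how a ProxyOptimizer method likely delegates.
    def compute_gradients(self, *args, **kwargs):
        # Forward the call unchanged to the wrapped sub-optimizer.
        return self._opt.compute_gradients(*args, **kwargs)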
@@ -119,15 +120,26 @@ class VariableAssignmentOptimizer(PostProcessOptimizer):
 class AccumGradOptimizer(ProxyOptimizer):
     """
     An optimizer which accumulates gradients across :math:`k` :meth:`minimize` calls,
     and applies them together in every :math:`k`-th :meth:`minimize` call.
     This is equivalent to using a :math:`k` times larger batch size plus a
     :math:`k` times larger learning rate, but uses much less memory.
     """
     def __init__(self, opt, niter):
-        super(AccumGradOptimizer, self).__init__(opt)
-        self._niter = niter
-        self._name = "AccumGrad"
-        self._counter = None
+        """
+        Args:
+            opt (tf.train.Optimizer): the underlying sub-optimizer.
+            niter (int): number of iterations to accumulate gradients.
+        """
+        super(AccumGradOptimizer, self).__init__(opt, 'AccumGrad')
+        self._niter = int(niter)
+
+    def _create_accum_slots(self, var_list):
+        slots = []
+        for v in var_list:
+            # TODO an option to not colocate the accumulators with variables (to save more memory)
+            s = self._zeros_slot(v, "accum", self._name)
+            slots.append(s)
+        return slots
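
The docstring's equivalence claim is easy to verify for plain SGD, where a step is linear in the gradient: applying the sum of k micro-batch gradients once at the base learning rate equals one step on the k-times-larger batch (whose gradient is the mean of the micro-batch gradients) at a k-times-larger learning rate. A minimal numeric sketch, using NumPy purely for illustration:

    import numpy as np

    lr, k = 0.01, 5
    micro_grads = [np.random.randn(3) for _ in range(k)]  # k micro-batch gradients

    # AccumGradOptimizer: apply the accumulated sum once at the base lr.
    update_accum = lr * np.sum(micro_grads, axis=0)

    # Large batch: mean gradient, applied once at a k-times larger lr.
    update_big = (k * lr) * np.mean(micro_grads, axis=0)

    assert np.allclose(update_accum, update_big)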
@@ -143,32 +155,50 @@ class AccumGradOptimizer(ProxyOptimizer):
                 "AccumGradOptimizer only works for dense update! " \
                 "Types of v and g are {} and {}".format(type(v), type(g))
             vs.append(v)

         with tf.control_dependencies(None):
             slots = self._create_accum_slots(vs)
             slots_and_vars = [(s, gv[1]) for s, gv in zip(slots, grads_and_vars)]

+            # Create the counter on the same device as the first variable.
             with tf.variable_scope(self._name), \
-                    tf.colocate_with(vs[0]):
+                    vs[0].graph.colocate_with(vs[0]):
                 counter = tf.Variable(
                     0, name="counter", trainable=False, dtype=tf.int32)

         ops = []
         for s, gv in zip(slots, grads_and_vars):
             g, v = gv
-            ops.append(s.assign_add(s, g))
+            ops.append(s.assign_add(g))
         update_counter = tf.assign_add(counter, 1, name='update_counter')
         update_slot_op = tf.group(update_counter, *ops, name='update_slot')

         def update_grad():
             update_op = self._opt.apply_gradients(slots_and_vars)
             with tf.control_dependencies([update_op]):
-                clear_ops = [tf.assign(s, 0.0) for s in slots]
+                clear_ops = [tf.assign(s, tf.zeros_like(s)) for s in slots]
             return tf.group(*clear_ops, name='update_grad')

         pred = tf.equal(tf.mod(counter, self._niter), 0)
         with tf.control_dependencies([update_slot_op]):
             if name is None:
                 name = 'cond_update_grad'
-            op = tf.cond(pred, update_grad, lambda: tf.no_op(), name=name)
+            op = tf.cond(pred, update_grad, tf.no_op, name=name).op
         return op
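
The three removed lines in this hunk were small API bugs. tf.Variable.assign_add takes only the increment, so s.assign_add(s, g) had the wrong arity; tf.assign needs a value matching the variable's shape and dtype, so the slots are now cleared with tf.zeros_like(s) instead of the scalar 0.0; and tf.cond accepts tf.no_op directly as the false branch, with .op extracting the Operation to run. A minimal sketch of the corrected assign semantics (the variable names here are illustrative):

    v = tf.Variable([0.0, 0.0])
    g = tf.constant([1.0, 2.0])
    acc = v.assign_add(g)                   # correct: v <- v + g, delta only
    clear = tf.assign(v, tf.zeros_like(v))  # reset with matching shape/dtype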
 if __name__ == '__main__':
     # run it with "python -m tensorpack.tfutils.optimizer"
     x = tf.get_variable('x', shape=[6])
     cost = tf.reduce_sum(tf.abs(x), name='cost')
     opt = tf.train.GradientDescentOptimizer(0.01)
     # opt = AccumGradOptimizer(opt, 5)
     min_op = opt.minimize(cost)
     sess = tf.Session()
     sess.run(tf.global_variables_initializer())
     with sess.as_default():
         for k in range(20):
             min_op.run()
             print(x.eval())
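
To exercise the new optimizer in this demo, one would enable the commented-out wrapper; per the docstring, with niter=5 the accumulated gradients are applied on every 5th minimize call. A sketch of the intended usage:

    opt = tf.train.GradientDescentOptimizer(0.01)
    opt = AccumGradOptimizer(opt, 5)  # accumulate 5 steps, then apply once
    min_op = opt.minimize(cost)       # only every 5th run() updates x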