Commit d8dd7ca9 authored by Yuxin Wu's avatar Yuxin Wu

Support global_step in AccumGradOptimizer

parent 1f844978
...@@ -180,9 +180,6 @@ class AccumGradOptimizer(ProxyOptimizer): ...@@ -180,9 +180,6 @@ class AccumGradOptimizer(ProxyOptimizer):
@HIDE_DOC @HIDE_DOC
def apply_gradients(self, grads_and_vars, global_step=None, name=None): def apply_gradients(self, grads_and_vars, global_step=None, name=None):
assert global_step is None, \
"AccumGradOptimizer doesn't support the option global_step! " \
"Please maintain it yourself."
grads_and_vars = FilterNoneGrad().process(grads_and_vars) grads_and_vars = FilterNoneGrad().process(grads_and_vars)
vs = [] vs = []
for g, v in grads_and_vars: for g, v in grads_and_vars:
...@@ -219,7 +216,16 @@ class AccumGradOptimizer(ProxyOptimizer): ...@@ -219,7 +216,16 @@ class AccumGradOptimizer(ProxyOptimizer):
with tf.control_dependencies([update_slot_op]): with tf.control_dependencies([update_slot_op]):
if name is None: if name is None:
name = 'cond_update_grad' name = 'cond_update_grad'
op = tf.cond(pred, update_grad, tf.no_op, name=name).op op = tf.cond(pred, update_grad, tf.no_op)
if global_step is not None:
# Tensorpack maintains global_step by other means,
# so this option is useless in tensorpack trainers.
# But we include the implementation here for completeness
global_step_increment = tf.assign_add(global_step, 1)
op = tf.group(op, global_step_increment, name=name)
else:
op = tf.identity(op, name=name).op
return op return op
...@@ -230,7 +236,7 @@ if __name__ == '__main__': ...@@ -230,7 +236,7 @@ if __name__ == '__main__':
cost = tf.reduce_sum(tf.abs(x), name='cost') cost = tf.reduce_sum(tf.abs(x), name='cost')
opt = tf.train.GradientDescentOptimizer(0.01) opt = tf.train.GradientDescentOptimizer(0.01)
opt = AccumGradOptimizer(opt, 5) opt = AccumGradOptimizer(opt, 5)
min_op = opt.minimize(cost) min_op = opt.minimize(cost, global_step=tf.train.get_or_create_global_step())
sess = tf.Session() sess = tf.Session()
sess.run(tf.global_variables_initializer()) sess.run(tf.global_variables_initializer())
...@@ -238,3 +244,4 @@ if __name__ == '__main__': ...@@ -238,3 +244,4 @@ if __name__ == '__main__':
for k in range(20): for k in range(20):
min_op.run() min_op.run()
print(x.eval()) print(x.eval())
print(tf.train.get_or_create_global_step().eval())
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment