Commit 53b6112d authored by Yuxin Wu's avatar Yuxin Wu

better grad summary

parent 75e8d9fe
......@@ -6,7 +6,7 @@
import tensorflow as tf
from abc import ABCMeta, abstractmethod
import re
from ..utils import logger
from ..utils import logger, MOVING_SUMMARY_VARS_KEY
__all__ = ['GradientProcessor', 'SummaryGradient', 'CheckGradient',
'ScaleGradient', 'MapGradient']
......@@ -34,8 +34,9 @@ class SummaryGradient(GradientProcessor):
def _process(self, grads):
    """Attach summaries to each gradient and pass the gradients through.

    For every (gradient, variable) pair this adds:
      * a histogram summary of the gradient values, and
      * the scalar RMS of the gradient, registered under
        MOVING_SUMMARY_VARS_KEY so it is reported as a moving average
        (instead of a raw per-step tf.scalar_summary).

    :param grads: list of (gradient, variable) pairs as produced by
        ``Optimizer.compute_gradients``.
    :returns: the same list of (gradient, variable) pairs, unmodified.
    """
    for grad, var in grads:
        tf.histogram_summary(var.op.name + '/grad', grad)
        # RMS of the gradient; the `name` kwarg belongs to tf.sqrt so the
        # collected tensor shows up as '<var>/gradRMS' in summaries.
        tf.add_to_collection(MOVING_SUMMARY_VARS_KEY,
                tf.sqrt(tf.reduce_mean(tf.square(grad)),
                        name=var.op.name + '/gradRMS'))
    return grads
......
......@@ -31,11 +31,12 @@ class SimpleTrainer(Trainer):
self.input_vars = input_vars
model.build_graph(input_vars, True)
cost_var = model.get_cost()
avg_maintain_op = summary_moving_average()
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost_var)
grads = self.config.optimizer.compute_gradients(cost_var)
grads = self.process_grads(grads)
avg_maintain_op = summary_moving_average()
self.train_op = tf.group(
self.config.optimizer.apply_gradients(grads, get_global_step_var()),
avg_maintain_op)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment