Commit dbc0b36e authored by Yuxin Wu's avatar Yuxin Wu

CheckNumerics callbacks

parent 07e28eea
......@@ -219,18 +219,27 @@ class DumpTensorAsImage(Callback):
cv2.imwrite(fname, res.astype('uint8'))
class CheckNumerics(Callback):
class CheckNumerics(RunOp):
"""
When triggered, check variables in the graph for NaN and Inf.
Raise exceptions if such an error is found.
Check variables in the graph for NaN and Inf.
Raise an exception if such an error is found.
"""
def _setup_graph(self):
_chief_only = True
def __init__(self, run_as_trigger=True, run_step=False):
"""
Args: same as in :class:`RunOp`.
"""
super().__init__(
self._get_op,
run_as_trigger=run_as_trigger,
run_step=run_step)
def _get_op(self):
vars = tf.trainable_variables()
ops = [tf.check_numerics(v, "CheckNumerics['{}']".format(v.op.name)).op for v in vars]
self._check_op = tf.group(*ops)
def _trigger(self):
self._check_op.run()
check_op = tf.group(*ops, name="CheckAllNumerics")
return check_op
try:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment