Commit 61bb05b5 authored by Yuxin Wu's avatar Yuxin Wu

Add SessionRunTimeout callback

parent 04a64849
......@@ -14,7 +14,7 @@ from ..tfutils.common import (
get_op_tensor_name, get_global_step_var)
from .base import Callback
__all__ = ['TensorPrinter', 'ProgressBar']
__all__ = ['TensorPrinter', 'ProgressBar', 'SessionRunTimeout']
class TensorPrinter(Callback):
......@@ -132,3 +132,21 @@ class MaintainStepCounter(Callback):
def _after_run(self, _, __):
# Keep python-side global_step in agreement with TF-side
self.trainer.loop._global_step += 1
class SessionRunTimeout(Callback):
"""
Add timeout option to each sess.run call.
"""
def __init__(self, timeout_in_ms):
"""
Args:
timeout_in_ms (int):
"""
self._timeout = int(timeout_in_ms)
opt = tf.RunOptions(timeout_in_ms=timeout_in_ms)
self._runargs = tf.train.SessionRunArgs(fetches=[], options=opt)
def _before_run(self, _):
return self._runargs
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment