Commit c648e3ac authored by Yuxin Wu's avatar Yuxin Wu

Merge branch 'master' into fpn

parents b244fd0a 61bb05b5
...@@ -14,7 +14,7 @@ from ..tfutils.common import ( ...@@ -14,7 +14,7 @@ from ..tfutils.common import (
get_op_tensor_name, get_global_step_var) get_op_tensor_name, get_global_step_var)
from .base import Callback from .base import Callback
__all__ = ['TensorPrinter', 'ProgressBar'] __all__ = ['TensorPrinter', 'ProgressBar', 'SessionRunTimeout']
class TensorPrinter(Callback): class TensorPrinter(Callback):
...@@ -132,3 +132,21 @@ class MaintainStepCounter(Callback): ...@@ -132,3 +132,21 @@ class MaintainStepCounter(Callback):
def _after_run(self, _, __): def _after_run(self, _, __):
# Keep python-side global_step in agreement with TF-side # Keep python-side global_step in agreement with TF-side
self.trainer.loop._global_step += 1 self.trainer.loop._global_step += 1
class SessionRunTimeout(Callback):
"""
Add timeout option to each sess.run call.
"""
def __init__(self, timeout_in_ms):
"""
Args:
timeout_in_ms (int):
"""
self._timeout = int(timeout_in_ms)
opt = tf.RunOptions(timeout_in_ms=timeout_in_ms)
self._runargs = tf.train.SessionRunArgs(fetches=[], options=opt)
def _before_run(self, _):
return self._runargs
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment