Commit 6ca57de4 authored by Yuxin Wu's avatar Yuxin Wu

fix bug in PeriodicRunHooks

parent 4eb2be09
...@@ -63,10 +63,13 @@ class PeriodicRunHooks(ProxyCallback): ...@@ -63,10 +63,13 @@ class PeriodicRunHooks(ProxyCallback):
def _before_run(self, ctx): def _before_run(self, ctx):
if self.global_step % self._every_k_steps == 0: if self.global_step % self._every_k_steps == 0:
self._enabled = True
return self.cb._before_run(ctx) return self.cb._before_run(ctx)
else:
self._enabled = False
def _after_run(self, ctx, rv): def _after_run(self, ctx, rv):
if self.global_step % self._every_k_steps == 0: if self._enabled:
self.cb._after_run(ctx, rv) self.cb._after_run(ctx, rv)
def __str__(self): def __str__(self):
...@@ -80,8 +83,8 @@ class EnableCallbackIf(ProxyCallback): ...@@ -80,8 +83,8 @@ class EnableCallbackIf(ProxyCallback):
The other methods will be called the same. The other methods will be called the same.
Note: Note:
If you need to use ``{before,after}_run``, make sure If you use ``{before,after}_run``,
that ``pred`` will eval to the same results in both methods every step. ``pred`` will be evaluated only in ``before_run``.
""" """
def __init__(self, callback, pred): def __init__(self, callback, pred):
""" """
...@@ -94,10 +97,13 @@ class EnableCallbackIf(ProxyCallback): ...@@ -94,10 +97,13 @@ class EnableCallbackIf(ProxyCallback):
def _before_run(self, ctx): def _before_run(self, ctx):
if self._pred(self): if self._pred(self):
self._enabled = True
return super(EnableCallbackIf, self)._before_run(ctx) return super(EnableCallbackIf, self)._before_run(ctx)
else:
self._enabled = False
def _after_run(self, ctx, rv): def _after_run(self, ctx, rv):
if self._pred(self): if self._enabled:
super(EnableCallbackIf, self)._after_run(ctx, rv) super(EnableCallbackIf, self)._after_run(ctx, rv)
def _before_epoch(self): def _before_epoch(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment