Commit 6ca57de4 authored by Yuxin Wu's avatar Yuxin Wu

fix bug in PeriodicRunHooks

parent 4eb2be09
......@@ -63,10 +63,13 @@ class PeriodicRunHooks(ProxyCallback):
def _before_run(self, ctx):
if self.global_step % self._every_k_steps == 0:
self._enabled = True
return self.cb._before_run(ctx)
else:
self._enabled = False
def _after_run(self, ctx, rv):
if self.global_step % self._every_k_steps == 0:
if self._enabled:
self.cb._after_run(ctx, rv)
def __str__(self):
......@@ -80,8 +83,8 @@ class EnableCallbackIf(ProxyCallback):
The other methods will be called the same.
Note:
If you need to use ``{before,after}_run``, make sure
that ``pred`` will eval to the same results in both methods every step.
If you use ``{before,after}_run``,
``pred`` will be evaluated only in ``before_run``.
"""
def __init__(self, callback, pred):
"""
......@@ -94,10 +97,13 @@ class EnableCallbackIf(ProxyCallback):
def _before_run(self, ctx):
if self._pred(self):
self._enabled = True
return super(EnableCallbackIf, self)._before_run(ctx)
else:
self._enabled = False
def _after_run(self, ctx, rv):
if self._pred(self):
if self._enabled:
super(EnableCallbackIf, self)._after_run(ctx, rv)
def _before_epoch(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment