Commit 0bdc9985 authored by Yuxin Wu's avatar Yuxin Wu

use ValidationCallback to monitor test cost only

parent 1cc356ea
......@@ -19,7 +19,7 @@ class ValidationCallback(PeriodicCallback):
"""
Basic routine for validation callbacks.
"""
def __init__(self, ds, prefix, period, cost_var_name='cost:0'):
def __init__(self, ds, prefix, period=1, cost_var_name='cost:0'):
super(ValidationCallback, self).__init__(period)
self.ds = ds
self.prefix = prefix
......@@ -68,6 +68,10 @@ class ValidationCallback(PeriodicCallback):
'{}_cost'.format(self.prefix), cost_avg), self.global_step)
logger.info("{}_cost: {:.4f}".format(self.prefix, cost_avg))
def _trigger(self):
for dp, outputs in self._run_validation():
pass
class ValidationError(ValidationCallback):
running_graph = 'test'
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment