Commit 0bdc9985 authored by Yuxin Wu's avatar Yuxin Wu

use ValidationCallback to monitor test cost only

parent 1cc356ea
...@@ -19,7 +19,7 @@ class ValidationCallback(PeriodicCallback): ...@@ -19,7 +19,7 @@ class ValidationCallback(PeriodicCallback):
""" """
Basic routine for validation callbacks. Basic routine for validation callbacks.
""" """
def __init__(self, ds, prefix, period, cost_var_name='cost:0'): def __init__(self, ds, prefix, period=1, cost_var_name='cost:0'):
super(ValidationCallback, self).__init__(period) super(ValidationCallback, self).__init__(period)
self.ds = ds self.ds = ds
self.prefix = prefix self.prefix = prefix
...@@ -68,6 +68,10 @@ class ValidationCallback(PeriodicCallback): ...@@ -68,6 +68,10 @@ class ValidationCallback(PeriodicCallback):
'{}_cost'.format(self.prefix), cost_avg), self.global_step) '{}_cost'.format(self.prefix), cost_avg), self.global_step)
logger.info("{}_cost: {:.4f}".format(self.prefix, cost_avg)) logger.info("{}_cost: {:.4f}".format(self.prefix, cost_avg))
def _trigger(self):
for dp, outputs in self._run_validation():
pass
class ValidationError(ValidationCallback): class ValidationError(ValidationCallback):
running_graph = 'test' running_graph = 'test'
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment