Commit 32feff4e authored by Yuxin Wu's avatar Yuxin Wu

fix periodic bug

parent 9a4e6d9d
...@@ -4,8 +4,6 @@ ...@@ -4,8 +4,6 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf import tensorflow as tf
from tensorflow.python.ops import control_flow_ops
import numpy as np import numpy as np
import os, sys import os, sys
import argparse import argparse
...@@ -18,7 +16,6 @@ from tensorpack.tfutils.summary import * ...@@ -18,7 +16,6 @@ from tensorpack.tfutils.summary import *
from tensorpack.tfutils import * from tensorpack.tfutils import *
from tensorpack.callbacks import * from tensorpack.callbacks import *
from tensorpack.dataflow import * from tensorpack.dataflow import *
from IPython import embed; embed()
""" """
MNIST ConvNet example. MNIST ConvNet example.
......
...@@ -11,23 +11,21 @@ from ..utils import * ...@@ -11,23 +11,21 @@ from ..utils import *
from ..utils.stat import * from ..utils.stat import *
from ..tfutils import * from ..tfutils import *
from ..tfutils.summary import * from ..tfutils.summary import *
from .base import PeriodicCallback, Callback, TestCallbackType from .base import Callback, TestCallbackType
__all__ = ['ClassificationError', 'ValidationCallback', 'ValidationStatPrinter'] __all__ = ['ClassificationError', 'ValidationCallback', 'ValidationStatPrinter']
class ValidationCallback(PeriodicCallback): class ValidationCallback(Callback):
""" """
Base class for validation callbacks. Base class for validation callbacks.
""" """
type = TestCallbackType() type = TestCallbackType()
def __init__(self, ds, prefix, period=1): def __init__(self, ds, prefix):
""" """
:param ds: validation dataset. must be a `DataFlow` instance. :param ds: validation dataset. must be a `DataFlow` instance.
:param prefix: name to use for this validation. :param prefix: name to use for this validation.
:param period: period to perform validation.
""" """
super(ValidationCallback, self).__init__(period)
self.ds = ds self.ds = ds
self.prefix = prefix self.prefix = prefix
...@@ -63,23 +61,18 @@ class ValidationCallback(PeriodicCallback): ...@@ -63,23 +61,18 @@ class ValidationCallback(PeriodicCallback):
yield (dp, outputs) yield (dp, outputs)
pbar.update() pbar.update()
@abstractmethod
def _trigger_periodic(self):
""" Implement the actual callback"""
class ValidationStatPrinter(ValidationCallback): class ValidationStatPrinter(ValidationCallback):
""" """
Write stat and summary of some Op for a validation dataset. Write stat and summary of some Op for a validation dataset.
The result of the given Op must be a scalar, and will be averaged for all batches in the validaion set. The result of the given Op must be a scalar, and will be averaged for all batches in the validaion set.
""" """
def __init__(self, ds, names_to_print, prefix='validation', period=1): def __init__(self, ds, names_to_print, prefix='validation'):
""" """
:param ds: validation dataset. must be a `DataFlow` instance. :param ds: validation dataset. must be a `DataFlow` instance.
:param names_to_print: names of variables to print :param names_to_print: names of variables to print
:param prefix: name to use for this validation. :param prefix: name to use for this validation.
:param period: period to perform validation.
""" """
super(ValidationStatPrinter, self).__init__(ds, prefix, period) super(ValidationStatPrinter, self).__init__(ds, prefix)
self.names = names_to_print self.names = names_to_print
def _find_output_vars(self): def _find_output_vars(self):
...@@ -89,7 +82,7 @@ class ValidationStatPrinter(ValidationCallback): ...@@ -89,7 +82,7 @@ class ValidationStatPrinter(ValidationCallback):
def _get_output_vars(self): def _get_output_vars(self):
return self.vars_to_print return self.vars_to_print
def _trigger_periodic(self): def _trigger_epoch(self):
stats = [] stats = []
for dp, outputs in self._run_validation(): for dp, outputs in self._run_validation():
stats.append(outputs) stats.append(outputs)
...@@ -114,13 +107,12 @@ class ClassificationError(ValidationCallback): ...@@ -114,13 +107,12 @@ class ClassificationError(ValidationCallback):
In theory, the result could be different from what produced by ValidationStatPrinter. In theory, the result could be different from what produced by ValidationStatPrinter.
""" """
def __init__(self, ds, prefix='validation', def __init__(self, ds, prefix='validation',
period=1,
wrong_var_name='wrong:0'): wrong_var_name='wrong:0'):
""" """
:param ds: a batched `DataFlow` instance :param ds: a batched `DataFlow` instance
:param wrong_var_name: name of the `wrong` variable :param wrong_var_name: name of the `wrong` variable
""" """
super(ClassificationError, self).__init__(ds, prefix, period) super(ClassificationError, self).__init__(ds, prefix)
self.wrong_var_name = wrong_var_name self.wrong_var_name = wrong_var_name
def _find_output_vars(self): def _find_output_vars(self):
...@@ -129,7 +121,7 @@ class ClassificationError(ValidationCallback): ...@@ -129,7 +121,7 @@ class ClassificationError(ValidationCallback):
def _get_output_vars(self): def _get_output_vars(self):
return [self.wrong_var] return [self.wrong_var]
def _trigger_periodic(self): def _trigger_epoch(self):
err_stat = Accuracy() err_stat = Accuracy()
for dp, outputs in self._run_validation(): for dp, outputs in self._run_validation():
batch_size = dp[0].shape[0] # assume batched input batch_size = dp[0].shape[0] # assume batched input
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment