Commit 004e3048 authored by Yuxin Wu's avatar Yuxin Wu

don't add scope for new summary module

parent 9fc5d856
...@@ -258,6 +258,7 @@ class StatMonitorParamSetter(HyperParamSetter): ...@@ -258,6 +258,7 @@ class StatMonitorParamSetter(HyperParamSetter):
if hist_max > hist_first + self.threshold: # large enough if hist_max > hist_first + self.threshold: # large enough
return None return None
self.last_changed_epoch = self.epoch_num self.last_changed_epoch = self.epoch_num
logger.info("[StatMonitorParamSetter] Triggered, history: " + ','.join(hist)) logger.info("[StatMonitorParamSetter] Triggered, history: " +
','.join(map(str, hist)))
return self.value_func(self.get_current_value()) return self.value_func(self.get_current_value())
...@@ -110,14 +110,14 @@ def summary_moving_average(tensors=None): ...@@ -110,14 +110,14 @@ def summary_moving_average(tensors=None):
""" """
if tensors is None: if tensors is None:
tensors = tf.get_collection(MOVING_SUMMARY_VARS_KEY) tensors = tf.get_collection(MOVING_SUMMARY_VARS_KEY)
with tf.name_scope('EMA-summary'):
# TODO will produce EMA_summary/tower0/xxx. not elegant # TODO will produce tower0/xxx. not elegant
with tf.name_scope(None): with tf.name_scope(None):
averager = tf.train.ExponentialMovingAverage( averager = tf.train.ExponentialMovingAverage(
0.99, num_updates=get_global_step_var(), name='EMA') 0.90, num_updates=get_global_step_var(), name='EMA')
avg_maintain_op = averager.apply(tensors) avg_maintain_op = averager.apply(tensors)
for idx, c in enumerate(tensors): for idx, c in enumerate(tensors):
name = re.sub('tower[p0-9]+/', '', c.op.name) name = re.sub('tower[p0-9]+/', '', c.op.name)
tf.summary.scalar(name, averager.average(c)) tf.summary.scalar(name + '-summary', averager.average(c))
return avg_maintain_op return avg_maintain_op
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: stat.py # File: stats.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
import numpy as np import numpy as np
__all__ = ['StatCounter', 'Accuracy', 'BinaryStatistics', 'RatioCounter'] __all__ = ['StatCounter', 'Accuracy', 'BinaryStatistics', 'RatioCounter',
'OnlineMoments']
class StatCounter(object): class StatCounter(object):
""" A simple counter""" """ A simple counter"""
...@@ -116,3 +117,32 @@ class BinaryStatistics(object): ...@@ -116,3 +117,32 @@ class BinaryStatistics(object):
if self.nr_pos == 0: if self.nr_pos == 0:
return 0 return 0
return 1 - self.recall return 1 - self.recall
class OnlineMoments(object):
def __init__(self):
self._mean = None
self._var = None
self._n = 0
def feed(self, x):
self._n += 1
if self._mean is None:
self._mean = x
self._var = 0
else:
diff = (x - self._mean)
ninv = 1.0 / self._n
self._mean += diff * ninv
self._var = (self._n-2.0)/(self._n-1) * self._var + diff * diff * ninv
@property
def mean(self):
return self._mean
@property
def variance(self):
return self._var
@property
def std(self):
return np.sqrt(self._var)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment