Commit 7237a1c8 authored by Yuxin Wu's avatar Yuxin Wu

summary / dataflow fix

parent d8a647f2
...@@ -130,7 +130,7 @@ class FakeData(DataFlow): ...@@ -130,7 +130,7 @@ class FakeData(DataFlow):
class MapData(ProxyDataFlow): class MapData(ProxyDataFlow):
""" Map a function to the datapoint""" """ Map a function to the datapoint"""
def __init__(self, ds, func): def __init__(self, ds, func):
super(MapData, self).__init_(ds) super(MapData, self).__init__(ds)
self.func = func self.func = func
def get_data(self): def get_data(self):
......
...@@ -7,6 +7,7 @@ import numpy as np ...@@ -7,6 +7,7 @@ import numpy as np
import cv2 import cv2
import copy import copy
from .base import DataFlow, ProxyDataFlow from .base import DataFlow, ProxyDataFlow
from .common import MapDataComponent
from .imgaug import AugmentorList, Image from .imgaug import AugmentorList, Image
__all__ = ['ImageFromFile', 'AugmentImageComponent'] __all__ = ['ImageFromFile', 'AugmentImageComponent']
...@@ -37,7 +38,7 @@ class ImageFromFile(DataFlow): ...@@ -37,7 +38,7 @@ class ImageFromFile(DataFlow):
yield [im] yield [im]
class AugmentImageComponent(ProxyDataFlow): class AugmentImageComponent(MapDataComponent):
""" """
Augment image in each data point Augment image in each data point
Args: Args:
...@@ -46,16 +47,10 @@ class AugmentImageComponent(ProxyDataFlow): ...@@ -46,16 +47,10 @@ class AugmentImageComponent(ProxyDataFlow):
index: the index of image in each data point. default to be 0 index: the index of image in each data point. default to be 0
""" """
def __init__(self, ds, augmentors, index=0): def __init__(self, ds, augmentors, index=0):
super(AugmentImageComponent, self).__init__(ds)
self.augs = AugmentorList(augmentors) self.augs = AugmentorList(augmentors)
self.index = index super(AugmentImageComponent, self).__init__(
ds, lambda x: self.augs.augment(Image(x)).arr, index)
def reset_state(self): def reset_state(self):
self.ds.reset_state() self.ds.reset_state()
self.augs.reset_state() self.augs.reset_state()
def get_data(self):
for dp in self.ds.get_data():
dp = copy.deepcopy(dp)
dp[self.index] = self.augs.augment(Image(dp[self.index])).arr
yield dp
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
import tensorflow as tf import tensorflow as tf
from . import logger from . import logger, get_global_step_var
from .naming import * from .naming import *
def create_summary(name, v): def create_summary(name, v):
...@@ -30,8 +30,8 @@ def add_activation_summary(x, name=None): ...@@ -30,8 +30,8 @@ def add_activation_summary(x, name=None):
"Summary a scalar with histogram? Maybe use scalar instead. FIXME!" "Summary a scalar with histogram? Maybe use scalar instead. FIXME!"
if name is None: if name is None:
name = x.name name = x.name
tf.histogram_summary(name + '/activations', x) tf.histogram_summary(name + '/activation', x)
tf.scalar_summary(name + '/sparsity', tf.nn.zero_fraction(x)) tf.scalar_summary(name + '/activation_sparsity', tf.nn.zero_fraction(x))
def add_param_summary(regex): def add_param_summary(regex):
""" """
...@@ -45,6 +45,7 @@ def add_param_summary(regex): ...@@ -45,6 +45,7 @@ def add_param_summary(regex):
if p.get_shape().ndims == 0: if p.get_shape().ndims == 0:
tf.scalar_summary(name, p) tf.scalar_summary(name, p)
else: else:
tf.scalar_summary(name + '/sparsity', tf.nn.zero_fraction(p))
tf.histogram_summary(name, p) tf.histogram_summary(name, p)
def summary_moving_average(cost_var): def summary_moving_average(cost_var):
...@@ -52,7 +53,7 @@ def summary_moving_average(cost_var): ...@@ -52,7 +53,7 @@ def summary_moving_average(cost_var):
MOVING_SUMMARY_VARS_KEY, as well as the argument MOVING_SUMMARY_VARS_KEY, as well as the argument
Return a op to maintain these average Return a op to maintain these average
""" """
global_step_var = tf.get_default_graph().get_tensor_by_name(GLOBAL_STEP_VAR_NAME) global_step_var = get_global_step_var()
averager = tf.train.ExponentialMovingAverage( averager = tf.train.ExponentialMovingAverage(
0.99, num_updates=global_step_var, name='moving_averages') 0.99, num_updates=global_step_var, name='moving_averages')
vars_to_summary = [cost_var] + \ vars_to_summary = [cost_var] + \
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment