Commit 7237a1c8 authored by Yuxin Wu's avatar Yuxin Wu

summary / dataflow fix

parent d8a647f2
......@@ -130,7 +130,7 @@ class FakeData(DataFlow):
class MapData(ProxyDataFlow):
""" Map a function to the datapoint"""
def __init__(self, ds, func):
super(MapData, self).__init_(ds)
super(MapData, self).__init__(ds)
self.func = func
def get_data(self):
......
......@@ -7,6 +7,7 @@ import numpy as np
import cv2
import copy
from .base import DataFlow, ProxyDataFlow
from .common import MapDataComponent
from .imgaug import AugmentorList, Image
__all__ = ['ImageFromFile', 'AugmentImageComponent']
......@@ -37,7 +38,7 @@ class ImageFromFile(DataFlow):
yield [im]
class AugmentImageComponent(ProxyDataFlow):
class AugmentImageComponent(MapDataComponent):
"""
Augment image in each data point
Args:
......@@ -46,16 +47,10 @@ class AugmentImageComponent(ProxyDataFlow):
index: the index of image in each data point. default to be 0
"""
def __init__(self, ds, augmentors, index=0):
super(AugmentImageComponent, self).__init__(ds)
self.augs = AugmentorList(augmentors)
self.index = index
super(AugmentImageComponent, self).__init__(
ds, lambda x: self.augs.augment(Image(x)).arr, index)
def reset_state(self):
self.ds.reset_state()
self.augs.reset_state()
def get_data(self):
for dp in self.ds.get_data():
dp = copy.deepcopy(dp)
dp[self.index] = self.augs.augment(Image(dp[self.index])).arr
yield dp
......@@ -5,7 +5,7 @@
import tensorflow as tf
from . import logger
from . import logger, get_global_step_var
from .naming import *
def create_summary(name, v):
......@@ -30,8 +30,8 @@ def add_activation_summary(x, name=None):
"Summary a scalar with histogram? Maybe use scalar instead. FIXME!"
if name is None:
name = x.name
tf.histogram_summary(name + '/activations', x)
tf.scalar_summary(name + '/sparsity', tf.nn.zero_fraction(x))
tf.histogram_summary(name + '/activation', x)
tf.scalar_summary(name + '/activation_sparsity', tf.nn.zero_fraction(x))
def add_param_summary(regex):
"""
......@@ -45,6 +45,7 @@ def add_param_summary(regex):
if p.get_shape().ndims == 0:
tf.scalar_summary(name, p)
else:
tf.scalar_summary(name + '/sparsity', tf.nn.zero_fraction(p))
tf.histogram_summary(name, p)
def summary_moving_average(cost_var):
......@@ -52,7 +53,7 @@ def summary_moving_average(cost_var):
MOVING_SUMMARY_VARS_KEY, as well as the argument
Return a op to maintain these average
"""
global_step_var = tf.get_default_graph().get_tensor_by_name(GLOBAL_STEP_VAR_NAME)
global_step_var = get_global_step_var()
averager = tf.train.ExponentialMovingAverage(
0.99, num_updates=global_step_var, name='moving_averages')
vars_to_summary = [cost_var] + \
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment