Commit b06fa732 authored by Yuxin Wu's avatar Yuxin Wu

util updates

parent b6c75ae5
......@@ -73,9 +73,7 @@ class Model(ModelDesc):
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost)
# compute the number of failed samples, for ValidationError to use at test time
wrong = tf.not_equal(
tf.cast(tf.argmax(prob, 1), tf.int32), label)
wrong = tf.cast(wrong, tf.float32)
wrong = prediction_incorrect(logits, label)
nr_wrong = tf.reduce_sum(wrong, name='wrong')
# monitor training error
tf.add_to_collection(
......
......@@ -64,9 +64,7 @@ class Model(ModelDesc):
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost)
# compute the number of failed samples, for ValidationError to use at test time
wrong = tf.not_equal(
tf.cast(tf.argmax(prob, 1), tf.int32), label)
wrong = tf.cast(wrong, tf.float32)
wrong = prediction_incorrect(logits, label)
nr_wrong = tf.reduce_sum(wrong, name='wrong')
# monitor training error
tf.add_to_collection(
......@@ -90,7 +88,6 @@ def get_config():
dataset_train = BatchData(dataset.Mnist('train'), 128)
dataset_test = BatchData(dataset.Mnist('test'), 256, remainder=True)
step_per_epoch = dataset_train.size()
step_per_epoch = 30
# prepare session
sess_config = get_default_sess_config()
......
......@@ -3,13 +3,14 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import numpy as np
import random
import copy
from six.moves import range
from .base import DataFlow, ProxyDataFlow
from ..utils import *
__all__ = ['BatchData', 'FixedSizeData', 'FakeData', 'MapData',
'MapDataComponent', 'RandomChooseData' ]
'MapDataComponent', 'RandomChooseData', 'RandomMixData']
class BatchData(ProxyDataFlow):
def __init__(self, ds, batch_size, remainder=False):
......@@ -182,3 +183,35 @@ class RandomChooseData(DataFlow):
yield next(itr)
except StopIteration:
return
class RandomMixData(DataFlow):
"""
Randomly choose from several dataflow, will eventually exhaust all dataflow.
So it's a perfect mix.
"""
def __init__(self, df_lists):
"""
df_lists: list of dataflow
all DataFlow in df_lists must have size() implemented
"""
self.df_lists = df_lists
self.sizes = [k.size() for k in self.df_lists]
def reset_state(self):
for d in self.df_lists:
d.reset_state()
def size(self):
return sum(self.sizes)
def get_data(self):
sums = np.cumsum(self.sizes)
idxs = np.arange(self.size())
np.random.shuffle(idxs)
idxs = np.array(map(
lambda x: np.searchsorted(sums, x, 'right'), idxs))
itrs = [k.get_data() for k in self.df_lists]
assert idxs.max() == len(itrs) - 1, "{}!={}".format(idxs.max(), len(itrs)-1)
for k in idxs:
yield next(itrs[k])
......@@ -3,6 +3,7 @@
# File: conv2d.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import numpy as np
import tensorflow as tf
import math
from ._common import *
......@@ -21,6 +22,7 @@ def Conv2D(x, out_channel, kernel_shape,
split: split channels. used in Alexnet
"""
in_shape = x.get_shape().as_list()
num_in = np.prod(in_shape[1:])
in_channel = in_shape[-1]
assert in_channel % split == 0
assert out_channel % split == 0
......@@ -31,7 +33,8 @@ def Conv2D(x, out_channel, kernel_shape,
stride = shape4d(stride)
if W_init is None:
W_init = tf.truncated_normal_initializer(stddev=1e-2)
#W_init = tf.truncated_normal_initializer(stddev=3e-2)
W_init = tf.contrib.layers.xavier_initializer_conv2d()
if b_init is None:
b_init = tf.constant_initializer()
......
......@@ -3,6 +3,7 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import threading
import multiprocessing
from contextlib import contextmanager
import tensorflow as tf
import atexit
......
......@@ -16,6 +16,19 @@ def one_hot(y, num_labels):
onehot_labels.set_shape([None, num_labels])
return tf.cast(onehot_labels, tf.float32)
def prediction_incorrect(logits, label):
"""
logits: batchxN
label: batch
return a binary vector with 1 means incorrect prediction
"""
with tf.op_scope([logits, label], 'incorrect'):
wrong = tf.not_equal(
tf.argmax(logits, 1),
tf.cast(label, tf.int64))
wrong = tf.cast(wrong, tf.float32)
return wrong
def flatten(x):
return tf.reshape(x, [-1])
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment