Commit b06fa732 authored by Yuxin Wu

util updates

parent b6c75ae5
...@@ -73,9 +73,7 @@ class Model(ModelDesc): ...@@ -73,9 +73,7 @@ class Model(ModelDesc):
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost) tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost)
# compute the number of failed samples, for ValidationError to use at test time # compute the number of failed samples, for ValidationError to use at test time
wrong = tf.not_equal( wrong = prediction_incorrect(logits, label)
tf.cast(tf.argmax(prob, 1), tf.int32), label)
wrong = tf.cast(wrong, tf.float32)
nr_wrong = tf.reduce_sum(wrong, name='wrong') nr_wrong = tf.reduce_sum(wrong, name='wrong')
# monitor training error # monitor training error
tf.add_to_collection( tf.add_to_collection(
......
...@@ -64,9 +64,7 @@ class Model(ModelDesc): ...@@ -64,9 +64,7 @@ class Model(ModelDesc):
tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost) tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost)
# compute the number of failed samples, for ValidationError to use at test time # compute the number of failed samples, for ValidationError to use at test time
wrong = tf.not_equal( wrong = prediction_incorrect(logits, label)
tf.cast(tf.argmax(prob, 1), tf.int32), label)
wrong = tf.cast(wrong, tf.float32)
nr_wrong = tf.reduce_sum(wrong, name='wrong') nr_wrong = tf.reduce_sum(wrong, name='wrong')
# monitor training error # monitor training error
tf.add_to_collection( tf.add_to_collection(
...@@ -90,7 +88,6 @@ def get_config(): ...@@ -90,7 +88,6 @@ def get_config():
dataset_train = BatchData(dataset.Mnist('train'), 128) dataset_train = BatchData(dataset.Mnist('train'), 128)
dataset_test = BatchData(dataset.Mnist('test'), 256, remainder=True) dataset_test = BatchData(dataset.Mnist('test'), 256, remainder=True)
step_per_epoch = dataset_train.size() step_per_epoch = dataset_train.size()
step_per_epoch = 30
# prepare session # prepare session
sess_config = get_default_sess_config() sess_config = get_default_sess_config()
......
...@@ -3,13 +3,14 @@ ...@@ -3,13 +3,14 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
import numpy as np import numpy as np
import random
import copy import copy
from six.moves import range from six.moves import range
from .base import DataFlow, ProxyDataFlow from .base import DataFlow, ProxyDataFlow
from ..utils import * from ..utils import *
__all__ = ['BatchData', 'FixedSizeData', 'FakeData', 'MapData', __all__ = ['BatchData', 'FixedSizeData', 'FakeData', 'MapData',
'MapDataComponent', 'RandomChooseData' ] 'MapDataComponent', 'RandomChooseData', 'RandomMixData']
class BatchData(ProxyDataFlow): class BatchData(ProxyDataFlow):
def __init__(self, ds, batch_size, remainder=False): def __init__(self, ds, batch_size, remainder=False):
...@@ -182,3 +183,35 @@ class RandomChooseData(DataFlow): ...@@ -182,3 +183,35 @@ class RandomChooseData(DataFlow):
yield next(itr) yield next(itr)
except StopIteration: except StopIteration:
return return
class RandomMixData(DataFlow):
    """
    Interleave several dataflows in random order until all are exhausted.

    Each call to get_data() draws size() datapoints in total, taking exactly
    sizes[i] of them from the i-th dataflow; only the order in which the
    dataflows are drawn from is random.
    """

    def __init__(self, df_lists):
        """
        df_lists: list of dataflow
            all DataFlow in df_lists must have size() implemented; the sizes
            are queried once here and cached.
        """
        self.df_lists = df_lists
        self.sizes = [k.size() for k in self.df_lists]

    def reset_state(self):
        """Reset the state of every underlying dataflow."""
        for d in self.df_lists:
            d.reset_state()

    def size(self):
        """Return the total number of datapoints over all dataflows."""
        return sum(self.sizes)

    def get_data(self):
        """Yield datapoints drawn from the dataflows in a shuffled order."""
        # sums[i] is the number of datapoints in dataflows 0..i, so a global
        # position p belongs to the first dataflow whose cumulative size
        # exceeds p -- which is what searchsorted(..., side='right') returns.
        sums = np.cumsum(self.sizes)
        idxs = np.arange(self.size())
        np.random.shuffle(idxs)
        # Vectorized lookup: np.array(map(...)) only builds an index array on
        # Python 2, because map() is lazy on Python 3.
        idxs = np.searchsorted(sums, idxs, side='right')
        itrs = [k.get_data() for k in self.df_lists]
        # Every position is < sums[-1], so every index is a valid position in
        # itrs. No check on idxs.max() is made: it would wrongly reject an
        # empty trailing dataflow and fail outright when the total size is 0.
        for k in idxs:
            yield next(itrs[k])
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
# File: conv2d.py # File: conv2d.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
import numpy as np
import tensorflow as tf import tensorflow as tf
import math import math
from ._common import * from ._common import *
...@@ -21,6 +22,7 @@ def Conv2D(x, out_channel, kernel_shape, ...@@ -21,6 +22,7 @@ def Conv2D(x, out_channel, kernel_shape,
split: split channels. used in Alexnet split: split channels. used in Alexnet
""" """
in_shape = x.get_shape().as_list() in_shape = x.get_shape().as_list()
num_in = np.prod(in_shape[1:])
in_channel = in_shape[-1] in_channel = in_shape[-1]
assert in_channel % split == 0 assert in_channel % split == 0
assert out_channel % split == 0 assert out_channel % split == 0
...@@ -31,7 +33,8 @@ def Conv2D(x, out_channel, kernel_shape, ...@@ -31,7 +33,8 @@ def Conv2D(x, out_channel, kernel_shape,
stride = shape4d(stride) stride = shape4d(stride)
if W_init is None: if W_init is None:
W_init = tf.truncated_normal_initializer(stddev=1e-2) #W_init = tf.truncated_normal_initializer(stddev=3e-2)
W_init = tf.contrib.layers.xavier_initializer_conv2d()
if b_init is None: if b_init is None:
b_init = tf.constant_initializer() b_init = tf.constant_initializer()
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
import threading import threading
import multiprocessing
from contextlib import contextmanager from contextlib import contextmanager
import tensorflow as tf import tensorflow as tf
import atexit import atexit
......
...@@ -16,6 +16,19 @@ def one_hot(y, num_labels): ...@@ -16,6 +16,19 @@ def one_hot(y, num_labels):
onehot_labels.set_shape([None, num_labels]) onehot_labels.set_shape([None, num_labels])
return tf.cast(onehot_labels, tf.float32) return tf.cast(onehot_labels, tf.float32)
def prediction_incorrect(logits, label):
    """
    logits: batchxN
    label: batch
    return a float32 vector where 1 marks a sample whose argmax over
    axis 1 of logits differs from its label, and 0 marks a match
    """
    with tf.op_scope([logits, label], 'incorrect'):
        predicted = tf.argmax(logits, 1)
        # cast the labels to int64 so they can be compared with the argmax output
        expected = tf.cast(label, tf.int64)
        mismatch = tf.not_equal(predicted, expected)
        return tf.cast(mismatch, tf.float32)
def flatten(x): def flatten(x):
return tf.reshape(x, [-1]) return tf.reshape(x, [-1])
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment