Commit 965aa953 authored by Yuxin Wu's avatar Yuxin Wu

some fixes

parent 86d4c589
...@@ -6,7 +6,7 @@ Still in development, but usable. ...@@ -6,7 +6,7 @@ Still in development, but usable.
See some interesting [examples](examples) to learn about the framework: See some interesting [examples](examples) to learn about the framework:
+ [DoReFa-Net: training binary / low bitwidth CNN](examples/DoReFa-Net) + [DoReFa-Net: training binary / low bitwidth CNN](examples/DoReFa-Net)
+ [Double-DQN for playing Atari games](examples/Atari2600) + [Double-DQN and A3C for playing Atari games](examples/Atari2600)
+ [ResNet for Cifar10 classification](examples/ResNet) + [ResNet for Cifar10 classification](examples/ResNet)
+ [InceptionV3 on ImageNet](examples/Inception/inceptionv3.py) + [InceptionV3 on ImageNet](examples/Inception/inceptionv3.py)
+ [char-rnn language model](examples/char-rnn) + [char-rnn language model](examples/char-rnn)
......
...@@ -88,17 +88,17 @@ class Model(ModelDesc): ...@@ -88,17 +88,17 @@ class Model(ModelDesc):
l = Conv2D('conv5_3', l, 512) l = Conv2D('conv5_3', l, 512)
b5 = branch('branch5', l, 16) b5 = branch('branch5', l, 16)
final_map = Conv2D('convfcweight', #final_map = Conv2D('convfcweight',
tf.concat(3, [b1, b2, b3, b4, b5]), 1, 1, #tf.concat(3, [b1, b2, b3, b4, b5]), 1, 1,
W_init=tf.constant_initializer(0.2), use_bias=False) #W_init=tf.constant_initializer(0.2), use_bias=False)
final_map = tf.squeeze(final_map, [3], name='predmap') #final_map = tf.squeeze(final_map, [3], name='predmap')
#final_map = tf.squeeze(tf.mul(0.2, b1 + b2 + b3 + b4 + b5), final_map = tf.squeeze(tf.mul(0.2, b1 + b2 + b3 + b4 + b5),
#[3], name='predmap') [3], name='predmap')
costs = [] costs = []
for idx, b in enumerate([b1, b2, b3, b4, b5, final_map]): for idx, b in enumerate([b1, b2, b3, b4, b5, final_map]):
output = tf.nn.sigmoid(b, name='output{}'.format(idx+1)) output = tf.nn.sigmoid(b, name='output{}'.format(idx+1))
xentropy = class_balanced_binary_class_cross_entropy( xentropy = class_balanced_sigmoid_binary_class_cross_entropy(
output, edgemap, b, edgemap,
name='xentropy{}'.format(idx+1)) name='xentropy{}'.format(idx+1))
costs.append(xentropy) costs.append(xentropy)
...@@ -138,6 +138,10 @@ def get_data(name): ...@@ -138,6 +138,10 @@ def get_data(name):
h0, w0, newh, neww = param h0, w0, newh, neww = param
return img[h0:h0+newh,w0:w0+neww] return img[h0:h0+newh,w0:w0+neww]
def f(m):
m[m>=0.50] = 1
m[m<0.50] = 0
return m
if isTrain: if isTrain:
shape_aug = [ shape_aug = [
imgaug.RandomResize(xrange=(0.7,1.5), yrange=(0.7,1.5), imgaug.RandomResize(xrange=(0.7,1.5), yrange=(0.7,1.5),
...@@ -152,18 +156,12 @@ def get_data(name): ...@@ -152,18 +156,12 @@ def get_data(name):
IMAGE_SHAPE = (320, 480) IMAGE_SHAPE = (320, 480)
shape_aug = [imgaug.CenterCrop(IMAGE_SHAPE)] shape_aug = [imgaug.CenterCrop(IMAGE_SHAPE)]
ds = AugmentImageComponents(ds, shape_aug, (0, 1)) ds = AugmentImageComponents(ds, shape_aug, (0, 1))
def f(m):
m[m>=0.51] = 1
m[m<0.51] = 0
return m
ds = MapDataComponent(ds, f, 1) ds = MapDataComponent(ds, f, 1)
if isTrain: if isTrain:
augmentors = [ augmentors = [
imgaug.Brightness(63, clip=False), imgaug.Brightness(63, clip=False),
imgaug.Contrast((0.4,1.5)), imgaug.Contrast((0.4,1.5)),
imgaug.GaussianNoise(),
] ]
ds = AugmentImageComponent(ds, augmentors) ds = AugmentImageComponent(ds, augmentors)
ds = BatchDataByShape(ds, 8, idx=0) ds = BatchDataByShape(ds, 8, idx=0)
...@@ -212,7 +210,7 @@ def get_config(): ...@@ -212,7 +210,7 @@ def get_config():
def run(model_path, image_path): def run(model_path, image_path):
pred_config = PredictConfig( pred_config = PredictConfig(
model=Model(False), model=Model(),
input_data_mapping=[0], input_data_mapping=[0],
session_init=get_model_loader(model_path), session_init=get_model_loader(model_path),
output_var_names=['output' + str(k) for k in range(1, 7)]) output_var_names=['output' + str(k) for k in range(1, 7)])
......
...@@ -5,7 +5,7 @@ Examples with __reproducible__ and meaningful performance. ...@@ -5,7 +5,7 @@ Examples with __reproducible__ and meaningful performance.
+ [An illustrative mnist example](mnist-convnet.py) + [An illustrative mnist example](mnist-convnet.py)
+ [A tiny SVHN ConvNet with 97.5% accuracy](svhn-digit-convnet.py) + [A tiny SVHN ConvNet with 97.5% accuracy](svhn-digit-convnet.py)
+ [Reproduce some reinforcement learning papers](Atari2600) + Reinforcement learning (DQN, A3C) on [Atari games](Atari2600) and [demos on OpenAI Gym](OpenAIGym).
+ [char-rnn for fun](char-rnn) + [char-rnn for fun](char-rnn)
+ [DisturbLabel, because I don't believe the paper](DisturbLabel) + [DisturbLabel, because I don't believe the paper](DisturbLabel)
+ [DoReFa-Net: binary / low-bitwidth CNN on ImageNet](DoReFa-Net) + [DoReFa-Net: binary / low-bitwidth CNN on ImageNet](DoReFa-Net)
......
...@@ -47,11 +47,11 @@ class ModelSaver(Callback): ...@@ -47,11 +47,11 @@ class ModelSaver(Callback):
if name not in var_dict: if name not in var_dict:
if name != v.name: if name != v.name:
logger.info( logger.info(
"{} renamed to {} when saving model.".format(v.name, name)) "[ModelSaver] {} renamed to {} when saving model.".format(v.name, name))
var_dict[name] = v var_dict[name] = v
else: else:
logger.warn("Variable {} won't be saved \ logger.info("[ModelSaver] Variable {} won't be saved \
because {} will be saved".format(v.name, var_dict[name].name)) due to an alternative in a different tower".format(v.name, var_dict[name].name))
return var_dict return var_dict
def _trigger_epoch(self): def _trigger_epoch(self):
......
...@@ -23,9 +23,9 @@ class JpegNoise(ImageAugmentor): ...@@ -23,9 +23,9 @@ class JpegNoise(ImageAugmentor):
class GaussianNoise(ImageAugmentor): class GaussianNoise(ImageAugmentor):
def __init__(self, scale=10, clip=True): def __init__(self, sigma=1, clip=True):
""" """
Add a gaussian noise of the same shape to img. Add a gaussian noise N(0, sigma^2) of the same shape to img.
""" """
super(GaussianNoise, self).__init__() super(GaussianNoise, self).__init__()
self._init(locals()) self._init(locals())
...@@ -34,7 +34,7 @@ class GaussianNoise(ImageAugmentor): ...@@ -34,7 +34,7 @@ class GaussianNoise(ImageAugmentor):
return self.rng.randn(*img.shape) return self.rng.randn(*img.shape)
def _augment(self, img, noise): def _augment(self, img, noise):
ret = img + noise ret = img + noise * self.sigma
if self.clip: if self.clip:
ret = np.clip(ret, 0, 255) ret = np.clip(ret, 0, 255)
return ret return ret
......
...@@ -44,10 +44,9 @@ class Flip(ImageAugmentor): ...@@ -44,10 +44,9 @@ class Flip(ImageAugmentor):
def _fprop_coord(self, coord, param): def _fprop_coord(self, coord, param):
raise NotImplementedError() raise NotImplementedError()
class Resize(ImageAugmentor): class Resize(ImageAugmentor):
""" Resize image to a target size""" """ Resize image to a target size"""
def __init__(self, shape): def __init__(self, shape, interp=cv2.INTER_CUBIC):
""" """
:param shape: shape in (h, w) :param shape: shape in (h, w)
""" """
...@@ -56,7 +55,7 @@ class Resize(ImageAugmentor): ...@@ -56,7 +55,7 @@ class Resize(ImageAugmentor):
def _augment(self, img, _): def _augment(self, img, _):
return cv2.resize( return cv2.resize(
img, self.shape[::-1], img, self.shape[::-1],
interpolation=cv2.INTER_CUBIC) interpolation=self.interp)
class RandomResize(ImageAugmentor): class RandomResize(ImageAugmentor):
""" randomly rescale w and h of the image""" """ randomly rescale w and h of the image"""
......
...@@ -42,9 +42,10 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5): ...@@ -42,9 +42,10 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
n_out = shape[-1] # channel n_out = shape[-1] # channel
assert n_out is not None assert n_out is not None
beta = tf.get_variable('beta', [n_out]) beta = tf.get_variable('beta', [n_out],
initializer=tf.zeros_initializer)
gamma = tf.get_variable('gamma', [n_out], gamma = tf.get_variable('gamma', [n_out],
initializer=tf.ones_initializer) initializer=tf.ones_initializer)
if len(shape) == 2: if len(shape) == 2:
batch_mean, batch_var = tf.nn.moments(x, [0], keep_dims=False) batch_mean, batch_var = tf.nn.moments(x, [0], keep_dims=False)
......
...@@ -46,12 +46,38 @@ def class_balanced_binary_class_cross_entropy(pred, label, name='cross_entropy_l ...@@ -46,12 +46,38 @@ def class_balanced_binary_class_cross_entropy(pred, label, name='cross_entropy_l
count_pos = tf.reduce_sum(y) count_pos = tf.reduce_sum(y)
beta = count_neg / (count_neg + count_pos) beta = count_neg / (count_neg + count_pos)
eps = 1e-8 eps = 1e-12
loss_pos = -beta * tf.reduce_mean(y * tf.log(z + eps)) loss_pos = -beta * tf.reduce_mean(y * tf.log(z + eps))
loss_neg = (1. - beta) * tf.reduce_mean((1. - y) * tf.log(1. - z + eps)) loss_neg = (1. - beta) * tf.reduce_mean((1. - y) * tf.log(1. - z + eps))
cost = tf.sub(loss_pos, loss_neg, name=name) cost = tf.sub(loss_pos, loss_neg, name=name)
return cost return cost
def class_balanced_sigmoid_binary_class_cross_entropy(pred, label, name='cross_entropy_loss'):
"""
The class-balanced cross entropy loss for binary classification,
as in `Holistically-Nested Edge Detection
<http://arxiv.org/abs/1504.06375>`_.
:param pred: size: b x ANYTHING. the logits.
:param label: size: b x ANYTHING. the ground truth in {0,1}.
:returns: class-balanced binary classification cross entropy loss
"""
z = batch_flatten(pred)
y = tf.cast(batch_flatten(label), tf.float32)
count_neg = tf.reduce_sum(1. - y)
count_pos = tf.reduce_sum(y)
beta = count_neg / (count_neg + count_pos)
#eps = 1e-12
logstable = tf.log(1 + tf.exp(-tf.abs(z)))
loss_pos = -beta * tf.reduce_mean(-y *
(logstable - tf.minimum(0, z)))
loss_neg = (1. - beta) * tf.reduce_mean((y - 1.) *
(logstable + tf.maximum(z, 0)))
cost = tf.sub(loss_pos, loss_neg, name=name)
return cost
def print_stat(x, message=None): def print_stat(x, message=None):
""" a simple print op. """ a simple print op.
Use it like: x = print_stat(x) Use it like: x = print_stat(x)
......
...@@ -112,7 +112,6 @@ def get_dataset_path(*args): ...@@ -112,7 +112,6 @@ def get_dataset_path(*args):
assert os.path.isdir(d), d assert os.path.isdir(d), d
return os.path.join(d, *args) return os.path.join(d, *args)
def get_tqdm_kwargs(**kwargs): def get_tqdm_kwargs(**kwargs):
default = dict( default = dict(
smoothing=0.5, smoothing=0.5,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment