Commit 0266827f authored by Yuxin Wu's avatar Yuxin Wu

misc changes

parent ce2cc714
...@@ -133,7 +133,7 @@ def get_data(name): ...@@ -133,7 +133,7 @@ def get_data(name):
if isTrain: if isTrain:
shape_aug = [ shape_aug = [
imgaug.RandomResize(xrange=(0.7,1.5), yrange=(0.7,1.5), imgaug.RandomResize(xrange=(0.7,1.5), yrange=(0.7,1.5),
aspect_ratio_thres=0.1), aspect_ratio_thres=0.15),
imgaug.RotationAndCropValid(90), imgaug.RotationAndCropValid(90),
CropMultiple16(), CropMultiple16(),
imgaug.Flip(horiz=True), imgaug.Flip(horiz=True),
...@@ -192,8 +192,7 @@ def get_config(): ...@@ -192,8 +192,7 @@ def get_config():
ModelSaver(), ModelSaver(),
HumanHyperParamSetter('learning_rate'), HumanHyperParamSetter('learning_rate'),
InferenceRunner(dataset_val, InferenceRunner(dataset_val,
BinaryClassificationStats('prediction', BinaryClassificationStats('prediction', 'edgemap'))
'edgemap'))
]), ]),
model=Model(), model=Model(),
step_per_epoch=step_per_epoch, step_per_epoch=step_per_epoch,
......
...@@ -85,6 +85,7 @@ class RandomResize(ImageAugmentor): ...@@ -85,6 +85,7 @@ class RandomResize(ImageAugmentor):
cnt += 1 cnt += 1
if cnt > 50: if cnt > 50:
logger.warn("RandomResize failed to augment an image") logger.warn("RandomResize failed to augment an image")
return img.shape[1], img.shape[0]
def _augment(self, img, dsize): def _augment(self, img, dsize):
return cv2.resize(img, dsize, interpolation=cv2.INTER_CUBIC) return cv2.resize(img, dsize, interpolation=cv2.INTER_CUBIC)
......
...@@ -29,15 +29,6 @@ def batch_flatten(x): ...@@ -29,15 +29,6 @@ def batch_flatten(x):
return tf.reshape(x, [-1, np.prod(shape)]) return tf.reshape(x, [-1, np.prod(shape)])
return tf.reshape(x, tf.pack([tf.shape(x)[0], -1])) return tf.reshape(x, tf.pack([tf.shape(x)[0], -1]))
def logSoftmax(x):
"""
Batch log softmax.
:param x: NxC tensor.
:returns: NxC tensor.
"""
logger.warn("symbf.logSoftmax is deprecated in favor of tf.nn.log_softmax")
return tf.nn.log_softmax(x)
def class_balanced_binary_class_cross_entropy(pred, label, name='cross_entropy_loss'): def class_balanced_binary_class_cross_entropy(pred, label, name='cross_entropy_loss'):
""" """
The class-balanced cross entropy loss for binary classification, The class-balanced cross entropy loss for binary classification,
...@@ -56,10 +47,9 @@ def class_balanced_binary_class_cross_entropy(pred, label, name='cross_entropy_l ...@@ -56,10 +47,9 @@ def class_balanced_binary_class_cross_entropy(pred, label, name='cross_entropy_l
beta = count_neg / (count_neg + count_pos) beta = count_neg / (count_neg + count_pos)
eps = 1e-8 eps = 1e-8
loss_pos = -beta * tf.reduce_mean(y * tf.log(tf.abs(z) + eps), 1) loss_pos = -beta * tf.reduce_mean(y * tf.log(z + eps))
loss_neg = (1. - beta) * tf.reduce_mean((1. - y) * tf.log(tf.abs(1. - z) + eps), 1) loss_neg = (1. - beta) * tf.reduce_mean((1. - y) * tf.log(1. - z + eps))
cost = tf.sub(loss_pos, loss_neg) cost = tf.sub(loss_pos, loss_neg, name=name)
cost = tf.reduce_mean(cost, name=name)
return cost return cost
def print_stat(x, message=None): def print_stat(x, message=None):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment