Commit 795f016a authored by Yuxin Wu's avatar Yuxin Wu

Move more symbolic functions to examples

parent 3e30bda4
...@@ -13,13 +13,38 @@ import sys ...@@ -13,13 +13,38 @@ import sys
os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon
from tensorpack import * from tensorpack import *
import tensorpack.tfutils.symbolic_functions as symbf
from tensorpack.dataflow import dataset from tensorpack.dataflow import dataset
from tensorpack.utils.gpu import get_nr_gpu from tensorpack.utils.gpu import get_nr_gpu
from tensorpack.tfutils import optimizer from tensorpack.tfutils import optimizer
from tensorpack.tfutils.summary import * from tensorpack.tfutils.summary import *
def class_balanced_sigmoid_cross_entropy(logits, label, name='cross_entropy_loss'):
"""
The class-balanced cross entropy loss,
as in `Holistically-Nested Edge Detection
<http://arxiv.org/abs/1504.06375>`_.
Args:
logits: of shape (b, ...).
label: of the same shape. the ground truth in {0,1}.
Returns:
class-balanced cross entropy loss.
"""
with tf.name_scope('class_balanced_sigmoid_cross_entropy'):
y = tf.cast(label, tf.float32)
count_neg = tf.reduce_sum(1. - y)
count_pos = tf.reduce_sum(y)
beta = count_neg / (count_neg + count_pos)
pos_weight = beta / (1 - beta)
cost = tf.nn.weighted_cross_entropy_with_logits(logits=logits, targets=y, pos_weight=pos_weight)
cost = tf.reduce_mean(cost * (1 - beta))
zero = tf.equal(count_pos, 0.0)
return tf.where(zero, 0.0, cost, name=name)
class Model(ModelDesc): class Model(ModelDesc):
def _get_inputs(self): def _get_inputs(self):
return [InputDesc(tf.float32, [None, None, None, 3], 'image'), return [InputDesc(tf.float32, [None, None, None, 3], 'image'),
...@@ -76,7 +101,7 @@ class Model(ModelDesc): ...@@ -76,7 +101,7 @@ class Model(ModelDesc):
costs = [] costs = []
for idx, b in enumerate([b1, b2, b3, b4, b5, final_map]): for idx, b in enumerate([b1, b2, b3, b4, b5, final_map]):
output = tf.nn.sigmoid(b, name='output{}'.format(idx + 1)) output = tf.nn.sigmoid(b, name='output{}'.format(idx + 1))
xentropy = symbf.class_balanced_sigmoid_cross_entropy( xentropy = class_balanced_sigmoid_cross_entropy(
b, edgemap, b, edgemap,
name='xentropy{}'.format(idx + 1)) name='xentropy{}'.format(idx + 1))
costs.append(xentropy) costs.append(xentropy)
......
...@@ -4,6 +4,8 @@ ...@@ -4,6 +4,8 @@
import cv2 import cv2
import sys import sys
import os import os
from contextlib import contextmanager
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
import tensorflow.contrib.slim as slim import tensorflow.contrib.slim as slim
...@@ -15,6 +17,28 @@ import tensorpack.utils.viz as viz ...@@ -15,6 +17,28 @@ import tensorpack.utils.viz as viz
IMAGE_SIZE = 224 IMAGE_SIZE = 224
@contextmanager
def guided_relu():
"""
Returns:
A context where the gradient of :meth:`tf.nn.relu` is replaced by
guided back-propagation, as described in the paper:
`Striving for Simplicity: The All Convolutional Net
<https://arxiv.org/abs/1412.6806>`_
"""
from tensorflow.python.ops import gen_nn_ops # noqa
@tf.RegisterGradient("GuidedReLU")
def GuidedReluGrad(op, grad):
return tf.where(0. < grad,
gen_nn_ops._relu_grad(grad, op.outputs[0]),
tf.zeros(grad.get_shape()))
g = tf.get_default_graph()
with g.gradient_override_map({'Relu': 'GuidedReLU'}):
yield
class Model(tp.ModelDesc): class Model(tp.ModelDesc):
def _get_inputs(self): def _get_inputs(self):
return [tp.InputDesc(tf.float32, (IMAGE_SIZE, IMAGE_SIZE, 3), 'image')] return [tp.InputDesc(tf.float32, (IMAGE_SIZE, IMAGE_SIZE, 3), 'image')]
...@@ -22,7 +46,7 @@ class Model(tp.ModelDesc): ...@@ -22,7 +46,7 @@ class Model(tp.ModelDesc):
def _build_graph(self, inputs): def _build_graph(self, inputs):
orig_image = inputs[0] orig_image = inputs[0]
mean = tf.get_variable('resnet_v1_50/mean_rgb', shape=[3]) mean = tf.get_variable('resnet_v1_50/mean_rgb', shape=[3])
with tp.symbolic_functions.guided_relu(): with guided_relu():
with slim.arg_scope(resnet_v1.resnet_arg_scope(is_training=False)): with slim.arg_scope(resnet_v1.resnet_arg_scope(is_training=False)):
image = tf.expand_dims(orig_image - mean, 0) image = tf.expand_dims(orig_image - mean, 0)
logits, _ = resnet_v1.resnet_v1_50(image, 1000) logits, _ = resnet_v1.resnet_v1_50(image, 1000)
......
...@@ -15,7 +15,6 @@ os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon ...@@ -15,7 +15,6 @@ os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon
from tensorpack import * from tensorpack import *
from tensorpack.dataflow import dataset from tensorpack.dataflow import dataset
import tensorflow as tf import tensorflow as tf
import tensorpack.tfutils.symbolic_functions as symbf
IMAGE_SIZE = 28 IMAGE_SIZE = 28
...@@ -106,8 +105,7 @@ class Model(ModelDesc): ...@@ -106,8 +105,7 @@ class Model(ModelDesc):
cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label) cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label)
cost = tf.reduce_mean(cost, name='cross_entropy_loss') cost = tf.reduce_mean(cost, name='cross_entropy_loss')
wrong = symbf.prediction_incorrect(logits, label, name='incorrect') accuracy = tf.reduce_mean(tf.to_float(tf.nn.in_top_k(logits, label, 1)), name='accuracy')
accuracy = symbf.accuracy(logits, label)
wd_cost = tf.multiply(1e-5, wd_cost = tf.multiply(1e-5,
regularize_cost('fc.*/W', tf.nn.l2_loss), regularize_cost('fc.*/W', tf.nn.l2_loss),
...@@ -144,7 +142,7 @@ def get_config(): ...@@ -144,7 +142,7 @@ def get_config():
callbacks=[ callbacks=[
ModelSaver(), ModelSaver(),
InferenceRunner( InferenceRunner(
dataset_test, [ScalarStats('cross_entropy_loss'), ClassificationError('incorrect')]), dataset_test, ScalarStats(['cross_entropy_loss', 'accuracy'])),
], ],
steps_per_epoch=dataset_train.size(), steps_per_epoch=dataset_train.size(),
max_epoch=100, max_epoch=100,
......
...@@ -9,7 +9,6 @@ import os ...@@ -9,7 +9,6 @@ import os
os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon os.environ['TENSORPACK_TRAIN_API'] = 'v2' # will become default soon
from tensorpack import * from tensorpack import *
from tensorpack.tfutils.symbolic_functions import prediction_incorrect
from tensorpack.dataflow import dataset from tensorpack.dataflow import dataset
from tensorpack.tfutils.summary import * from tensorpack.tfutils.summary import *
import tensorflow as tf import tensorflow as tf
......
...@@ -17,6 +17,7 @@ def prediction_incorrect(logits, label, topk=1, name='incorrect_vector'): ...@@ -17,6 +17,7 @@ def prediction_incorrect(logits, label, topk=1, name='incorrect_vector'):
tf.float32, name=name) tf.float32, name=name)
@deprecated("Please implement it by yourself.", "2018-02-28")
def accuracy(logits, label, topk=1, name='accuracy'): def accuracy(logits, label, topk=1, name='accuracy'):
""" """
Args: Args:
...@@ -46,6 +47,7 @@ def batch_flatten(x): ...@@ -46,6 +47,7 @@ def batch_flatten(x):
return tf.reshape(x, tf.stack([tf.shape(x)[0], -1])) return tf.reshape(x, tf.stack([tf.shape(x)[0], -1]))
@deprecated("Please implement it by yourself.", "2018-02-28")
def class_balanced_cross_entropy(pred, label, name='cross_entropy_loss'): def class_balanced_cross_entropy(pred, label, name='cross_entropy_loss'):
""" """
The class-balanced cross entropy loss, The class-balanced cross entropy loss,
...@@ -73,6 +75,7 @@ def class_balanced_cross_entropy(pred, label, name='cross_entropy_loss'): ...@@ -73,6 +75,7 @@ def class_balanced_cross_entropy(pred, label, name='cross_entropy_loss'):
return cost return cost
@deprecated("Please implement it by yourself.", "2018-02-28")
def class_balanced_sigmoid_cross_entropy(logits, label, name='cross_entropy_loss'): def class_balanced_sigmoid_cross_entropy(logits, label, name='cross_entropy_loss'):
""" """
This function accepts logits rather than predictions, and is more numerically stable than This function accepts logits rather than predictions, and is more numerically stable than
...@@ -203,6 +206,7 @@ def psnr(prediction, ground_truth, maxp=None, name='psnr'): ...@@ -203,6 +206,7 @@ def psnr(prediction, ground_truth, maxp=None, name='psnr'):
@contextmanager @contextmanager
@deprecated("Please implement it by yourself.", "2018-02-28")
def guided_relu(): def guided_relu():
""" """
Returns: Returns:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment