Commit f6b502d7 authored by Yuxin Wu

bsds dataset

parent 1c3d8741
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
# File: __init__.py # File: __init__.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
import cv2 # fix https://github.com/tensorflow/tensorflow/issues/1924
from . import models from . import models
from . import train from . import train
from . import utils from . import utils
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: bsds500.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import os, glob
import cv2
import numpy as np
from scipy.io import loadmat
from ...utils import logger, get_rng
from ...utils.fs import download
from ..base import DataFlow
# Names exported by ``from <this module> import *``.
__all__ = ['BSDS500']
# URL of the gzipped tarball that BSDS500.__init__ downloads and extracts
# when no local 'BSR' directory is found.
DATA_URL = "http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/BSR_bsds500.tgz"
# Width and height (in pixels) that every image and label map is asserted to
# have after being rotated to landscape orientation.
IMG_W, IMG_H = 481, 321
class BSDS500(DataFlow):
    """
    `Berkeley Segmentation Data Set and Benchmarks 500
    <http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/resources.html#bsds500>`_.

    Produce (image, label) pair, where image has shape (321, 481, 3) and
    ranges in [0,255]. Label is binary and has shape (321, 481).
    Those pixels annotated as boundaries by >= 3 out of 5 annotators are
    considered positive examples. This is used in `Holistically-Nested Edge Detection
    <http://arxiv.org/abs/1504.06375>`_.

    Portrait images (and their labels) are transposed so that every sample
    is stored in landscape orientation.
    """

    def __init__(self, name, data_dir=None, shuffle=True):
        """
        :param name: 'train', 'test', 'val'
        :param data_dir: a directory containing the original 'BSR' directory.
            Defaults to a 'bsds500_data' directory next to this file; the
            archive is downloaded and extracted there if 'BSR' is missing.
        :param shuffle: whether :meth:`get_data` yields samples in random order.
        """
        # Validate the cheap argument first, so that a typo does not trigger
        # a large download before failing.
        assert name in ['train', 'test', 'val'], name

        # check and download data
        if data_dir is None:
            data_dir = os.path.join(os.path.dirname(__file__), 'bsds500_data')
        if not os.path.isdir(os.path.join(data_dir, 'BSR')):
            download(DATA_URL, data_dir)
            filename = DATA_URL.split('/')[-1]
            filepath = os.path.join(data_dir, filename)
            import tarfile
            # Use a context manager so the archive handle is always closed.
            # NOTE(review): extractall() trusts the member paths stored in the
            # downloaded archive -- confirm the source is trusted.
            with tarfile.open(filepath, 'r:gz') as tar:
                tar.extractall(data_dir)
        self.data_root = os.path.join(data_dir, 'BSR', 'BSDS500', 'data')
        assert os.path.isdir(self.data_root), self.data_root

        self.shuffle = shuffle
        self._load(name)
        self.rng = get_rng(self)

    def reset_state(self):
        """Replace the random generator with a fresh one from ``get_rng``."""
        self.rng = get_rng(self)

    def _load(self, name):
        """
        Read every image of the split `name` and its ground truth into
        ``self.data`` (uint8, N x 321 x 481 x 3) and ``self.label``
        (bool, N x 321 x 481).
        """
        image_glob = os.path.join(self.data_root, 'images', name, '*.jpg')
        # glob order depends on the filesystem; sort so that iteration with
        # shuffle=False is reproducible.
        image_files = sorted(glob.glob(image_glob))
        gt_dir = os.path.join(self.data_root, 'groundTruth', name)

        num_images = len(image_files)
        self.data = np.zeros((num_images, IMG_H, IMG_W, 3), dtype='uint8')
        self.label = np.zeros((num_images, IMG_H, IMG_W), dtype='bool')

        for idx, f in enumerate(image_files):
            im = cv2.imread(f, cv2.IMREAD_COLOR)
            assert im is not None, f
            if im.shape[0] > im.shape[1]:
                # portrait -> landscape
                im = np.transpose(im, (1, 0, 2))
            assert im.shape[:2] == (IMG_H, IMG_W), "{} != {}".format(im.shape[:2], (IMG_H, IMG_W))

            # The ground-truth file shares the image's base name
            # (loadmat is given the name without an extension).
            imgid = os.path.basename(f).split('.')[0]
            gt_file = os.path.join(gt_dir, imgid)
            annotations = loadmat(gt_file)['groundTruth'][0]
            assert len(annotations) > 0, gt_file
            # Vote over at most the first 5 annotators. A hard-coded
            # range(5) raises IndexError when a file holds fewer annotations.
            # NOTE(review): per-image annotator count is assumed to vary --
            # confirm against the dataset.
            n_annot = min(len(annotations), 5)
            votes = sum(annotations[k]['Boundaries'][0][0] for k in range(n_annot))
            # A pixel is a positive example when at least 3 annotators marked it.
            gt = votes >= 3
            if gt.shape[0] > gt.shape[1]:
                gt = gt.transpose()
            assert gt.shape == (IMG_H, IMG_W), "{} != {}".format(gt.shape, (IMG_H, IMG_W))

            self.data[idx] = im
            self.label[idx] = gt

    def size(self):
        """:returns: the number of (image, label) pairs loaded."""
        return self.data.shape[0]

    def get_data(self):
        """
        Yield ``[image, label]`` for every loaded sample, in random order
        if ``shuffle`` was set and in load order otherwise.
        """
        idxs = np.arange(self.data.shape[0])
        if self.shuffle:
            self.rng.shuffle(idxs)
        for k in idxs:
            yield [self.data[k], self.label[k]]
if __name__ == '__main__':
    # Visual smoke test: show each validation label map as a black/white
    # image for one second.
    dataset = BSDS500('val')
    for _image, edge_map in dataset.get_data():
        as_gray = edge_map.astype('uint8') * 255
        cv2.imshow("haha", as_gray)
        cv2.waitKey(1000)
...@@ -9,7 +9,6 @@ import random ...@@ -9,7 +9,6 @@ import random
import six import six
from six.moves import urllib, range from six.moves import urllib, range
import copy import copy
import tarfile
import logging import logging
from ...utils import logger, get_rng from ...utils import logger, get_rng
...@@ -31,6 +30,7 @@ def maybe_download_and_extract(dest_directory): ...@@ -31,6 +30,7 @@ def maybe_download_and_extract(dest_directory):
download(DATA_URL, dest_directory) download(DATA_URL, dest_directory)
filename = DATA_URL.split('/')[-1] filename = DATA_URL.split('/')[-1]
filepath = os.path.join(dest_directory, filename) filepath = os.path.join(dest_directory, filename)
import tarfile
tarfile.open(filepath, 'r:gz').extractall(dest_directory) tarfile.open(filepath, 'r:gz').extractall(dest_directory)
def read_cifar10(filenames): def read_cifar10(filenames):
......
...@@ -117,7 +117,7 @@ class ILSVRC12(DataFlow): ...@@ -117,7 +117,7 @@ class ILSVRC12(DataFlow):
for k in idxs: for k in idxs:
tp = self.imglist[k] tp = self.imglist[k]
fname = os.path.join(self.dir, self.name, tp[0]).strip() fname = os.path.join(self.dir, self.name, tp[0]).strip()
im = cv2.imread(fname) im = cv2.imread(fname, cv2.IMREAD_COLOR)
assert im is not None, fname assert im is not None, fname
if im.ndim == 2: if im.ndim == 2:
im = np.expand_dims(im, 2).repeat(3,2) im = np.expand_dims(im, 2).repeat(3,2)
......
...@@ -26,8 +26,8 @@ class SVHNDigit(DataFlow): ...@@ -26,8 +26,8 @@ class SVHNDigit(DataFlow):
def __init__(self, name, data_dir=None, shuffle=True): def __init__(self, name, data_dir=None, shuffle=True):
""" """
name: 'train', 'test', or 'extra' :param name: 'train', 'test', or 'extra'
data_dir: a directory containing the original {train,test,extra}_32x32.mat :param data_dir: a directory containing the original {train,test,extra}_32x32.mat
""" """
self.shuffle = shuffle self.shuffle = shuffle
self.rng = get_rng(self) self.rng = get_rng(self)
......
...@@ -54,3 +54,5 @@ class AugmentImageComponent(MapDataComponent): ...@@ -54,3 +54,5 @@ class AugmentImageComponent(MapDataComponent):
def reset_state(self): def reset_state(self):
self.ds.reset_state() self.ds.reset_state()
self.augs.reset_state() self.augs.reset_state()
...@@ -54,3 +54,28 @@ def logSoftmax(x): ...@@ -54,3 +54,28 @@ def logSoftmax(x):
return logprob return logprob
def class_balanced_binary_class_cross_entropy(pred, label, name='cross_entropy_loss'):
    """
    The class-balanced cross entropy loss for binary classification,
    as in `Holistically-Nested Edge Detection
    <http://arxiv.org/abs/1504.06375>`_.

    The positive term is weighted by the fraction of negative labels in the
    whole batch (``beta``) and the negative term by ``1 - beta``; the
    per-sample losses are then averaged.

    :param pred: size: b x ANYTHING. the predictions in [0,1].
    :param label: size: b x ANYTHING. the ground truth in {0,1}.
    :param name: name given to the final (mean) loss op.
    :returns: class-balanced binary classification cross entropy loss
    """
    flat_pred = batch_flatten(pred)
    flat_label = batch_flatten(label)

    # Class frequencies, counted over the entire batch.
    num_neg = tf.reduce_sum(1. - flat_label)
    num_pos = tf.reduce_sum(flat_label)
    num_all = tf.add(num_neg, num_pos)
    beta = tf.truediv(num_neg, num_all)

    # Small constant keeping log() away from zero.
    eps = 1e-8

    # Positive examples: -beta * sum(y * log(z)) per sample.
    pos_weight = -beta
    log_pred = tf.log(tf.abs(flat_pred) + eps)
    pos_sum = tf.reduce_sum(tf.mul(log_pred, flat_label), 1)
    loss_pos = tf.mul(pos_weight, pos_sum)

    # Negative examples: (1 - beta) * sum((1 - y) * log(1 - z)) per sample.
    neg_weight = 1. - beta
    log_one_minus_pred = tf.log(tf.abs(1. - flat_pred) + eps)
    neg_sum = tf.reduce_sum(tf.mul(log_one_minus_pred, 1. - flat_label), 1)
    loss_neg = tf.mul(neg_weight, neg_sum)

    per_sample = tf.sub(loss_pos, loss_neg)
    return tf.reduce_mean(per_sample, name=name)
...@@ -98,7 +98,8 @@ class Trainer(object): ...@@ -98,7 +98,8 @@ class Trainer(object):
for step in tqdm.trange( for step in tqdm.trange(
self.config.step_per_epoch, self.config.step_per_epoch,
leave=True, mininterval=0.5, leave=True, mininterval=0.5,
dynamic_ncols=True, ascii=True): dynamic_ncols=True, ascii=True,
bar_format='{l_bar}{bar}|{n_fmt}/{total_fmt} [{elapsed}<{remaining},{rate_noinv_fmt}]'):
if self.coord.should_stop(): if self.coord.should_stop():
return return
self.run_step() self.run_step()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment