Commit f6b502d7 authored by Yuxin Wu

bsds dataset

parent 1c3d8741
......@@ -2,6 +2,7 @@
# File: __init__.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import cv2 # fix https://github.com/tensorflow/tensorflow/issues/1924
from . import models
from . import train
from . import utils
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: bsds500.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import os, glob
import cv2
import numpy as np
from scipy.io import loadmat
from ...utils import logger, get_rng
from ...utils.fs import download
from ..base import DataFlow
__all__ = ['BSDS500']
DATA_URL = "http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/BSR_bsds500.tgz"
IMG_W, IMG_H = 481, 321
class BSDS500(DataFlow):
    """
    `Berkeley Segmentation Data Set and Benchmarks 500
    <http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/resources.html#bsds500>`_.

    Produce (image, label) pair, where image has shape (321, 481, 3) and
    ranges in [0,255]. Label is binary and has shape (321, 481).
    Those pixels annotated as boundaries by >= 3 out of 5 annotators are
    considered positive examples. This is used in `Holistically-Nested Edge Detection
    <http://arxiv.org/abs/1504.06375>`_.

    Portrait images (and their labels) are transposed so that every
    datapoint has the same landscape shape.
    """

    # Number of annotations summed per image and the vote count a pixel
    # needs in order to be labelled as a boundary.
    _MAX_ANNOTATORS = 5
    _MIN_VOTES = 3

    def __init__(self, name, data_dir=None, shuffle=True):
        """
        :param name: 'train', 'test', 'val'
        :param data_dir: a directory containing the original 'BSR' directory.
            Defaults to 'bsds500_data' next to this file; the archive is
            downloaded and extracted there when 'BSR' is missing.
        :param shuffle: whether `get_data` yields the datapoints in random order.
        """
        # Reject a bad split name before paying for a download.
        assert name in ['train', 'test', 'val'], name

        # check and download data
        if data_dir is None:
            data_dir = os.path.join(os.path.dirname(__file__), 'bsds500_data')
        if not os.path.isdir(os.path.join(data_dir, 'BSR')):
            download(DATA_URL, data_dir)
            filename = DATA_URL.split('/')[-1]
            filepath = os.path.join(data_dir, filename)
            import tarfile
            # The context manager closes the archive even if extraction fails.
            with tarfile.open(filepath, 'r:gz') as archive:
                archive.extractall(data_dir)
        self.data_root = os.path.join(data_dir, 'BSR', 'BSDS500', 'data')
        assert os.path.isdir(self.data_root), self.data_root

        self.shuffle = shuffle
        self._load(name)
        self.rng = get_rng(self)

    def reset_state(self):
        """Re-create the random generator used for shuffling."""
        self.rng = get_rng(self)

    def _load(self, name):
        """
        Read every image of split `name` and its ground truth into
        `self.data` (uint8, N x H x W x 3) and `self.label` (bool, N x H x W).
        """
        image_glob = os.path.join(self.data_root, 'images', name, '*.jpg')
        # glob order depends on the filesystem; sort so that the
        # non-shuffled iteration order is reproducible.
        image_files = sorted(glob.glob(image_glob))
        gt_dir = os.path.join(self.data_root, 'groundTruth', name)
        self.data = np.zeros((len(image_files), IMG_H, IMG_W, 3), dtype='uint8')
        self.label = np.zeros((len(image_files), IMG_H, IMG_W), dtype='bool')

        for idx, f in enumerate(image_files):
            im = cv2.imread(f, cv2.IMREAD_COLOR)
            assert im is not None, f
            if im.shape[0] > im.shape[1]:
                # portrait image: swap the spatial axes to get landscape
                im = np.transpose(im, (1, 0, 2))
            assert im.shape[:2] == (IMG_H, IMG_W), \
                "{} != {}".format(im.shape[:2], (IMG_H, IMG_W))

            imgid = os.path.basename(f).split('.')[0]
            gt_file = os.path.join(gt_dir, imgid)
            gt = loadmat(gt_file)['groundTruth'][0]
            # Indexing a fixed range(5) raises IndexError when a file holds
            # fewer annotations, so only sum the ones that are present
            # (still at most the first five).
            n_annot = min(len(gt), self._MAX_ANNOTATORS)
            assert n_annot > 0, gt_file
            votes = sum(gt[k]['Boundaries'][0][0] for k in range(n_annot))
            gt = votes >= self._MIN_VOTES
            if gt.shape[0] > gt.shape[1]:
                gt = gt.transpose()
            assert gt.shape == (IMG_H, IMG_W), \
                "{} != {}".format(gt.shape, (IMG_H, IMG_W))

            self.data[idx] = im
            self.label[idx] = gt

    def size(self):
        """:returns: number of images loaded for this split."""
        return self.data.shape[0]

    def get_data(self):
        """Yield [image, label] for every loaded image, shuffled if requested."""
        idxs = np.arange(self.data.shape[0])
        if self.shuffle:
            self.rng.shuffle(idxs)
        for k in idxs:
            yield [self.data[k], self.label[k]]
if __name__ == '__main__':
    # Manual smoke test: display each validation boundary map for one second.
    dataset = BSDS500('val')
    for datapoint in dataset.get_data():
        boundary = datapoint[1]
        # bool -> {0, 255} so that positives show up as white pixels
        preview = boundary.astype('uint8')*255
        cv2.imshow("haha", preview)
        cv2.waitKey(1000)
......@@ -9,7 +9,6 @@ import random
import six
from six.moves import urllib, range
import copy
import tarfile
import logging
from ...utils import logger, get_rng
......@@ -31,6 +30,7 @@ def maybe_download_and_extract(dest_directory):
download(DATA_URL, dest_directory)
filename = DATA_URL.split('/')[-1]
filepath = os.path.join(dest_directory, filename)
import tarfile
tarfile.open(filepath, 'r:gz').extractall(dest_directory)
def read_cifar10(filenames):
......
......@@ -117,7 +117,7 @@ class ILSVRC12(DataFlow):
for k in idxs:
tp = self.imglist[k]
fname = os.path.join(self.dir, self.name, tp[0]).strip()
im = cv2.imread(fname)
im = cv2.imread(fname, cv2.IMREAD_COLOR)
assert im is not None, fname
if im.ndim == 2:
im = np.expand_dims(im, 2).repeat(3,2)
......
......@@ -26,8 +26,8 @@ class SVHNDigit(DataFlow):
def __init__(self, name, data_dir=None, shuffle=True):
"""
name: 'train', 'test', or 'extra'
data_dir: a directory containing the original {train,test,extra}_32x32.mat
:param name: 'train', 'test', or 'extra'
:param data_dir: a directory containing the original {train,test,extra}_32x32.mat
"""
self.shuffle = shuffle
self.rng = get_rng(self)
......
......@@ -54,3 +54,5 @@ class AugmentImageComponent(MapDataComponent):
def reset_state(self):
    """Reset the wrapped `self.ds` first, then `self.augs`."""
    source, augmentors = self.ds, self.augs
    source.reset_state()
    augmentors.reset_state()
......@@ -54,3 +54,28 @@ def logSoftmax(x):
return logprob
def class_balanced_binary_class_cross_entropy(pred, label, name='cross_entropy_loss'):
    """
    The class-balanced cross entropy loss for binary classification,
    as in `Holistically-Nested Edge Detection
    <http://arxiv.org/abs/1504.06375>`_.

    With ``beta`` the fraction of negative labels in the whole batch, the
    per-sample loss is
    ``-beta * sum(y * log(z)) - (1 - beta) * sum((1 - y) * log(1 - z))``
    and the result is its mean over the batch.

    :param pred: size: b x ANYTHING. the predictions in [0,1].
    :param label: size: b x ANYTHING. the ground truth in {0,1}.
        May be of a non-float type (e.g. bool); it is cast to the dtype of `pred`.
    :param name: name given to the final mean op.
    :returns: class-balanced binary classification cross entropy loss
    """
    z = batch_flatten(pred)
    # Binary masks commonly arrive as bool/int tensors, on which `1. - y`
    # and the products below are not defined; bring them to pred's dtype.
    y = tf.cast(batch_flatten(label), z.dtype)

    # Class frequencies are taken over the whole batch, not per sample.
    count_neg = tf.reduce_sum(1. - y)
    count_pos = tf.reduce_sum(y)
    beta = tf.truediv(count_neg, count_neg + count_pos)

    eps = 1e-8  # keeps log() finite when a prediction is exactly 0 or 1
    loss_pos = -beta * tf.reduce_sum(tf.log(tf.abs(z) + eps) * y, 1)
    loss_neg = (1. - beta) * tf.reduce_sum(tf.log(tf.abs(1. - z) + eps) * (1. - y), 1)
    cost = tf.reduce_mean(loss_pos - loss_neg, name=name)
    return cost
......@@ -98,7 +98,8 @@ class Trainer(object):
for step in tqdm.trange(
self.config.step_per_epoch,
leave=True, mininterval=0.5,
dynamic_ncols=True, ascii=True):
dynamic_ncols=True, ascii=True,
bar_format='{l_bar}{bar}|{n_fmt}/{total_fmt} [{elapsed}<{remaining},{rate_noinv_fmt}]'):
if self.coord.should_stop():
return
self.run_step()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment