Commit cba97f75 authored by Yuxin Wu

bbox and tf_func

parent 0ec586b0
......@@ -38,13 +38,12 @@ def get_data(train_or_test):
if isTrain:
augmentors = [
imgaug.CenterPaste((40, 40)),
imgaug.RandomCrop((32, 32)),
#imgaug.Flip(horiz=True),
imgaug.Brightness(10),
imgaug.Contrast((0.8,1.2)),
imgaug.GaussianDeform( # this is slow
imgaug.GaussianDeform( # this is slow. without it, can only reach 1.9% error
[(0.2, 0.2), (0.2, 0.8), (0.8,0.8), (0.8,0.2)],
(32, 32), 0.2, 3),
(40, 40), 0.2, 3),
imgaug.RandomCrop((32, 32)),
imgaug.MapImage(lambda x: x - pp_mean),
]
else:
......
......@@ -7,10 +7,12 @@ import tarfile
import cv2
import numpy as np
from six.moves import range
import xml.etree.ElementTree as ET
from ...utils import logger, get_rng, get_dataset_dir, memoized
from ...utils.loadcaffe import get_caffe_pb
from ...utils.fs import mkdir_p, download
from ...utils.timer import timed_operation
from ..base import RNGDataFlow
__all__ = ['ILSVRCMeta', 'ILSVRC12']
......@@ -20,7 +22,6 @@ def log_once(s): logger.warn(s)
CAFFE_ILSVRC12_URL = "http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz"
# TODO move caffe_pb outside
class ILSVRCMeta(object):
"""
Some metadata for ILSVRC dataset.
......@@ -90,15 +91,16 @@ class ILSVRCMeta(object):
class ILSVRC12(RNGDataFlow):
def __init__(self, dir, name, meta_dir=None, shuffle=True,
dir_structure='original'):
dir_structure='original', include_bb=False):
"""
:param dir: A directory containing a subdir named `name`, where the
original ILSVRC12_`name`.tar gets decompressed.
:param name: 'train' or 'val' or 'test'
:param dir_structure: the dir structure of 'val' or 'test'.
if is 'original' then keep the original decompressed dir with list
of image files. if equals to 'train', use the `train/` dir
:param dir_structure: The dir structure of 'val' or 'test'.
If is 'original' then keep the original decompressed dir with list
of image files (as below). If equals to 'train', use the `train/` dir
structure with class name as subdirectories.
:param include_bb: Include the bounding box. Useful in training.
Dir should have the following structure:
......@@ -116,6 +118,11 @@ class ILSVRC12(RNGDataFlow):
test/
ILSVRC2012_test_00000001.JPEG
...
bbox/
n02134418/
n02134418_198.xml
...
...
After decompress ILSVRC12_img_train.tar, you can use the following
command to build the above structure for `train/`:
......@@ -125,6 +132,7 @@ class ILSVRC12(RNGDataFlow):
find -type f -name '*.tar' | parallel -P 10 'echo {} && mkdir -p {/.} && tar xf {} -C {/.}'
Or:
for i in *.tar; do dir=${i%.tar}; echo $dir; mkdir -p $dir; tar xf $i -C $dir; done
"""
assert name in ['train', 'test', 'val']
self.full_dir = os.path.join(dir, name)
......@@ -136,12 +144,19 @@ class ILSVRC12(RNGDataFlow):
self.dir_structure = dir_structure
self.synset = meta.get_synset_1000()
if include_bb:
assert name == 'train', 'Bounding box only available for training'
self.bblist = ILSVRC12.get_training_bbox(
os.path.join(dir, 'bbox'), self.imglist)
self.include_bb = include_bb
def size(self):
    """
    :returns: the number of entries in `self.imglist`.
    """
    num_images = len(self.imglist)
    return num_images
def get_data(self):
"""
Produce original images or shape [h, w, 3], and label
Produce original images of shape [h, w, 3], and label,
and optionally a bbox of [xmin, ymin, xmax, ymax] in [0, 1]
"""
idxs = np.arange(len(self.imglist))
add_label_to_fname = (self.name != 'train' and self.dir_structure != 'original')
......@@ -157,15 +172,55 @@ class ILSVRC12(RNGDataFlow):
assert im is not None, fname
if im.ndim == 2:
im = np.expand_dims(im, 2).repeat(3,2)
if self.include_bb:
bb = self.bblist[k]
if not bb:
bb = [0, 0, 1, 1]
yield [im, label, bb]
else:
yield [im, label]
@staticmethod
def get_training_bbox(bbox_dir, imglist):
    """
    Load one bounding box per image from the annotation XML files.

    :param bbox_dir: directory containing the XML files. Each file is looked
        up at the image's relative path with the extension replaced by
        `.xml` (e.g. `n02134418/n02134418_198.xml`).
    :param imglist: indexable sequence whose items carry the relative image
        file name as their first element.
    :returns: a list with the same length as `imglist`. Each entry is either
        a float32 array holding the first object's box with its first four
        values divided by the image size, or None when the XML file is
        missing or cannot be parsed. Entries are arrays, so callers must
        test them with `is None` rather than by truthiness.
    """
    ret = []

    def parse_bbox(fname):
        root = ET.parse(fname).getroot()
        # Iterate the elements directly: Element.getchildren() was removed
        # in Python 3.9.
        size = list(root.find('size'))
        # assumes the first two children of <size> are width, height -- TODO confirm
        width, height = int(size[0].text), int(size[1].text)
        # Build real lists: on Python 3 map() returns an iterator, which
        # supports neither indexing nor in-place division.
        # assumes <bndbox> children are ordered xmin, ymin, xmax, ymax -- TODO confirm
        box = [float(c.text) for c in root.find('object').find('bndbox')]
        box[0] /= width
        box[1] /= height
        box[2] /= width
        box[3] /= height
        return np.asarray(box, dtype='float32')

    with timed_operation('Loading Bounding Boxes ...'):
        cnt = 0
        import tqdm
        for k in tqdm.trange(len(imglist)):
            fname = os.path.splitext(imglist[k][0])[0] + '.xml'
            fname = os.path.join(bbox_dir, fname)
            try:
                box = parse_bbox(fname)
            except Exception:
                # Best effort: not every image has an annotation file.
                # KeyboardInterrupt is not an Exception subclass, so it
                # still propagates.
                box = None
            else:
                cnt += 1
            ret.append(box)
        logger.info("{}/{} images have bounding box.".format(cnt, len(imglist)))
    return ret
if __name__ == '__main__':
meta = ILSVRCMeta()
print(meta.get_per_pixel_mean())
#print(meta.get_synset_words_1000())
#ds = ILSVRC12('/home/wyx/data/imagenet', 'val')
ds = ILSVRC12('/home/wyx/data/fake_ilsvrc/', 'train', include_bb=True,
shuffle=False)
ds.reset_state()
for k in ds.get_data():
from IPython import embed; embed()
......
......@@ -16,7 +16,6 @@ class ImageAugmentor(object):
self.reset_state()
def _init(self, params=None):
self.reset_state()
if params:
for k, v in params.items():
if k != 'self':
......
......@@ -22,7 +22,7 @@ class Flip(ImageAugmentor):
:param prob: probability of flip.
"""
if horiz and vert:
raise ValueError("Please use two Flip, with both 0.5 prob")
raise ValueError("Please use two Flip instead.")
elif horiz:
self.code = 1
elif vert:
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: tf_func.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from .base import ProxyDataFlow
from ..utils import logger
try:
import tensorflow as tf
except ImportError:
logger.warn("Cannot import tensorflow. TFFuncMapper won't be available.")
__all__ = []
else:
__all__ = ['TFFuncMapper']
class TFFuncMapper(ProxyDataFlow):
    """
    Map the datapoints of another DataFlow through a TensorFlow symbolic
    function, evaluated in a private graph and session.
    """

    def __init__(self, ds,
                 get_placeholders, symbf, apply_symbf_on_dp, device='/cpu:0'):
        """
        :param ds: the DataFlow whose datapoints get mapped
        :param get_placeholders: a function returning the placeholders
        :param symbf: a symbolic function taking the placeholders
        :param apply_symbf_on_dp: called as `apply_symbf_on_dp(dp, run_func)`
            for every datapoint; a falsy return value drops the datapoint
        :param device: device string under which the graph is built
        """
        super(TFFuncMapper, self).__init__(ds)
        self.device = device
        self.get_placeholders = get_placeholders
        self.symbf = symbf
        self.apply_symbf_on_dp = apply_symbf_on_dp

    def reset_state(self):
        """
        Reset the wrapped DataFlow, then build a fresh graph, a session on
        it, and the `run_func` callable handed to `apply_symbf_on_dp`.
        """
        super(TFFuncMapper, self).reset_state()
        self.graph = tf.Graph()
        with self.graph.as_default():
            with tf.device(self.device):
                self.placeholders = self.get_placeholders()
                self.output_vars = self.symbf(self.placeholders)
                # Created while self.graph is the default graph, so the
                # session runs ops of this graph.
                self.sess = tf.Session()

        def run_func(vals):
            # Pair placeholders with the given values positionally.
            feed = dict(zip(self.placeholders, vals))
            return self.sess.run(self.output_vars, feed_dict=feed)

        self.run_func = run_func

    def get_data(self):
        """
        Yield `apply_symbf_on_dp(dp, run_func)` for each datapoint of the
        wrapped DataFlow, skipping falsy results.
        """
        for dp in self.ds.get_data():
            mapped = self.apply_symbf_on_dp(dp, self.run_func)
            if mapped:
                yield mapped
if __name__ == '__main__':
    # Benchmark script: pull 100000 datapoints of fake 224x224x3 data
    # through TFFuncMapper and let tqdm report the speed.
    from .raw import FakeData
    from .prefetch import PrefetchDataZMQ
    from .image import AugmentImageComponent
    from . import imgaug

    ds = FakeData([[224, 224, 3]], 100000, random=False)

    def make_placeholders():
        return [tf.placeholder(tf.float32, [224, 224, 3], name='img')]

    def tf_aug(placeholders):
        img = placeholders[0]
        img = tf.image.random_brightness(img, 0.1)
        img = tf.image.random_contrast(img, 0.8, 1.2)
        return tf.image.random_flip_left_right(img)

    def map_datapoint(dp, run):
        return [run([dp[0]])[0]]

    ds = TFFuncMapper(ds, make_placeholders, tf_aug, map_datapoint)
    # Alternative pipeline using imgaug, kept for comparison:
    #ds = AugmentImageComponent(ds,
    #[imgaug.Brightness(0.1, clip=False),
    #imgaug.Contrast((0.8, 1.2), clip=False),
    #imgaug.Flip(horiz=True)
    #])
    #ds = PrefetchDataZMQ(ds, 4)
    ds.reset_state()

    import tqdm
    itr = ds.get_data()
    for _ in tqdm.trange(100000):
        next(itr)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment