Commit cba97f75 authored by Yuxin Wu's avatar Yuxin Wu

bbox and tf_func

parent 0ec586b0
...@@ -38,13 +38,12 @@ def get_data(train_or_test): ...@@ -38,13 +38,12 @@ def get_data(train_or_test):
if isTrain: if isTrain:
augmentors = [ augmentors = [
imgaug.CenterPaste((40, 40)), imgaug.CenterPaste((40, 40)),
imgaug.RandomCrop((32, 32)),
#imgaug.Flip(horiz=True),
imgaug.Brightness(10), imgaug.Brightness(10),
imgaug.Contrast((0.8,1.2)), imgaug.Contrast((0.8,1.2)),
imgaug.GaussianDeform( # this is slow imgaug.GaussianDeform( # this is slow. without it, can only reach 1.9% error
[(0.2, 0.2), (0.2, 0.8), (0.8,0.8), (0.8,0.2)], [(0.2, 0.2), (0.2, 0.8), (0.8,0.8), (0.8,0.2)],
(32, 32), 0.2, 3), (40, 40), 0.2, 3),
imgaug.RandomCrop((32, 32)),
imgaug.MapImage(lambda x: x - pp_mean), imgaug.MapImage(lambda x: x - pp_mean),
] ]
else: else:
......
...@@ -7,10 +7,12 @@ import tarfile ...@@ -7,10 +7,12 @@ import tarfile
import cv2 import cv2
import numpy as np import numpy as np
from six.moves import range from six.moves import range
import xml.etree.ElementTree as ET
from ...utils import logger, get_rng, get_dataset_dir, memoized from ...utils import logger, get_rng, get_dataset_dir, memoized
from ...utils.loadcaffe import get_caffe_pb from ...utils.loadcaffe import get_caffe_pb
from ...utils.fs import mkdir_p, download from ...utils.fs import mkdir_p, download
from ...utils.timer import timed_operation
from ..base import RNGDataFlow from ..base import RNGDataFlow
__all__ = ['ILSVRCMeta', 'ILSVRC12'] __all__ = ['ILSVRCMeta', 'ILSVRC12']
...@@ -20,7 +22,6 @@ def log_once(s): logger.warn(s) ...@@ -20,7 +22,6 @@ def log_once(s): logger.warn(s)
CAFFE_ILSVRC12_URL = "http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz" CAFFE_ILSVRC12_URL = "http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz"
# TODO move caffe_pb outside
class ILSVRCMeta(object): class ILSVRCMeta(object):
""" """
Some metadata for ILSVRC dataset. Some metadata for ILSVRC dataset.
...@@ -90,15 +91,16 @@ class ILSVRCMeta(object): ...@@ -90,15 +91,16 @@ class ILSVRCMeta(object):
class ILSVRC12(RNGDataFlow): class ILSVRC12(RNGDataFlow):
def __init__(self, dir, name, meta_dir=None, shuffle=True, def __init__(self, dir, name, meta_dir=None, shuffle=True,
dir_structure='original'): dir_structure='original', include_bb=False):
""" """
:param dir: A directory containing a subdir named `name`, where the :param dir: A directory containing a subdir named `name`, where the
original ILSVRC12_`name`.tar gets decompressed. original ILSVRC12_`name`.tar gets decompressed.
:param name: 'train' or 'val' or 'test' :param name: 'train' or 'val' or 'test'
:param dir_structure: the dir structure of 'val' or 'test'. :param dir_structure: The dir structure of 'val' or 'test'.
if is 'original' then keep the original decompressed dir with list If is 'original' then keep the original decompressed dir with list
of image files. if equals to 'train', use the `train/` dir of image files (as below). If equals to 'train', use the `train/` dir
structure with class name as subdirectories. structure with class name as subdirectories.
:param include_bb: Include the bounding box. Useful in training.
Dir should have the following structure: Dir should have the following structure:
...@@ -116,6 +118,11 @@ class ILSVRC12(RNGDataFlow): ...@@ -116,6 +118,11 @@ class ILSVRC12(RNGDataFlow):
test/ test/
ILSVRC2012_test_00000001.JPEG ILSVRC2012_test_00000001.JPEG
... ...
bbox/
n02134418/
n02134418_198.xml
...
...
After decompress ILSVRC12_img_train.tar, you can use the following After decompress ILSVRC12_img_train.tar, you can use the following
command to build the above structure for `train/`: command to build the above structure for `train/`:
...@@ -125,6 +132,7 @@ class ILSVRC12(RNGDataFlow): ...@@ -125,6 +132,7 @@ class ILSVRC12(RNGDataFlow):
find -type f -name '*.tar' | parallel -P 10 'echo {} && mkdir -p {/.} && tar xf {} -C {/.}' find -type f -name '*.tar' | parallel -P 10 'echo {} && mkdir -p {/.} && tar xf {} -C {/.}'
Or: Or:
for i in *.tar; do dir=${i%.tar}; echo $dir; mkdir -p $dir; tar xf $i -C $dir; done for i in *.tar; do dir=${i%.tar}; echo $dir; mkdir -p $dir; tar xf $i -C $dir; done
""" """
assert name in ['train', 'test', 'val'] assert name in ['train', 'test', 'val']
self.full_dir = os.path.join(dir, name) self.full_dir = os.path.join(dir, name)
...@@ -136,12 +144,19 @@ class ILSVRC12(RNGDataFlow): ...@@ -136,12 +144,19 @@ class ILSVRC12(RNGDataFlow):
self.dir_structure = dir_structure self.dir_structure = dir_structure
self.synset = meta.get_synset_1000() self.synset = meta.get_synset_1000()
if include_bb:
assert name == 'train', 'Bounding box only available for training'
self.bblist = ILSVRC12.get_training_bbox(
os.path.join(dir, 'bbox'), self.imglist)
self.include_bb = include_bb
def size(self): def size(self):
return len(self.imglist) return len(self.imglist)
def get_data(self): def get_data(self):
""" """
Produce original images or shape [h, w, 3], and label Produce original images of shape [h, w, 3], and label,
and optionally a bbox of [xmin, ymin, xmax, ymax] in [0, 1]
""" """
idxs = np.arange(len(self.imglist)) idxs = np.arange(len(self.imglist))
add_label_to_fname = (self.name != 'train' and self.dir_structure != 'original') add_label_to_fname = (self.name != 'train' and self.dir_structure != 'original')
...@@ -157,15 +172,55 @@ class ILSVRC12(RNGDataFlow): ...@@ -157,15 +172,55 @@ class ILSVRC12(RNGDataFlow):
assert im is not None, fname assert im is not None, fname
if im.ndim == 2: if im.ndim == 2:
im = np.expand_dims(im, 2).repeat(3,2) im = np.expand_dims(im, 2).repeat(3,2)
yield [im, label] if self.include_bb:
bb = self.bblist[k]
if not bb:
bb = [0, 0, 1, 1]
yield [im, label, bb]
else:
yield [im, label]
@staticmethod
def get_training_bbox(bbox_dir, imglist):
    """
    Load one bounding box per image from the XML annotation files.

    The annotation path of an image is built by replacing the last four
    characters of its file name with 'xml' and joining the result to
    `bbox_dir`.

    :param bbox_dir: directory containing the XML annotation files.
    :param imglist: sequence of entries whose element 0 is the image file name.
    :returns: a list with the same length as `imglist`. Each element is a
        float32 array holding the first four children of the first
        object's `bndbox` node, with elements 0 and 2 divided by the first
        child of `size` and elements 1 and 3 divided by the second child of
        `size`; or None when the annotation is missing or cannot be parsed.
    """
    def parse_bbox(fname):
        root = ET.parse(fname).getroot()
        # Iterate the elements directly: Element.getchildren() was removed
        # in Python 3.9.
        size = list(root.find('size'))
        width, height = int(size[0].text), int(size[1].text)
        # Build real lists: on Python 3 `map` returns an iterator, which
        # can be neither indexed nor updated in place.
        coords = [float(child.text)
                  for child in root.find('object').find('bndbox')]
        return np.asarray(
            [coords[0] / width, coords[1] / height,
             coords[2] / width, coords[3] / height],
            dtype='float32')

    ret = []
    with timed_operation('Loading Bounding Boxes ...'):
        cnt = 0
        import tqdm
        for k in tqdm.trange(len(imglist)):
            fname = imglist[k][0]
            fname = os.path.join(bbox_dir, fname[:-4] + 'xml')
            try:
                box = parse_bbox(fname)
            except Exception:
                # Best effort: many images have no annotation file at all.
                # `except Exception` lets KeyboardInterrupt / SystemExit
                # propagate without an explicit re-raise.
                ret.append(None)
            else:
                ret.append(box)
                cnt += 1
        logger.info("{}/{} images have bounding box.".format(cnt, len(imglist)))
    return ret
if __name__ == '__main__': if __name__ == '__main__':
meta = ILSVRCMeta() meta = ILSVRCMeta()
print(meta.get_per_pixel_mean())
#print(meta.get_synset_words_1000()) #print(meta.get_synset_words_1000())
#ds = ILSVRC12('/home/wyx/data/imagenet', 'val') ds = ILSVRC12('/home/wyx/data/fake_ilsvrc/', 'train', include_bb=True,
shuffle=False)
ds.reset_state()
for k in ds.get_data(): for k in ds.get_data():
from IPython import embed; embed() from IPython import embed; embed()
......
...@@ -16,7 +16,6 @@ class ImageAugmentor(object): ...@@ -16,7 +16,6 @@ class ImageAugmentor(object):
self.reset_state() self.reset_state()
def _init(self, params=None): def _init(self, params=None):
self.reset_state()
if params: if params:
for k, v in params.items(): for k, v in params.items():
if k != 'self': if k != 'self':
......
...@@ -22,7 +22,7 @@ class Flip(ImageAugmentor): ...@@ -22,7 +22,7 @@ class Flip(ImageAugmentor):
:param prob: probability of flip. :param prob: probability of flip.
""" """
if horiz and vert: if horiz and vert:
raise ValueError("Please use two Flip, with both 0.5 prob") raise ValueError("Please use two Flip instead.")
elif horiz: elif horiz:
self.code = 1 self.code = 1
elif vert: elif vert:
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: tf_func.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from .base import ProxyDataFlow
from ..utils import logger
try:
import tensorflow as tf
except ImportError:
logger.warn("Cannot import tensorflow. TFFuncMapper won't be available.")
__all__ = []
else:
__all__ = ['TFFuncMapper']
class TFFuncMapper(ProxyDataFlow):
    """
    A DataFlow that maps every datapoint of `ds` through a symbolic
    TensorFlow function, evaluated in a graph and session owned by this
    object.
    """

    def __init__(self, ds,
                 get_placeholders, symbf, apply_symbf_on_dp, device='/cpu:0'):
        """
        :param ds: the DataFlow whose datapoints get mapped.
        :param get_placeholders: a function returning the placeholders
        :param symbf: a symbolic function taking the placeholders
        :param apply_symbf_on_dp: called as `apply_symbf_on_dp(dp, run_func)`
            for every datapoint; a falsy return value drops the datapoint.
        :param device: device string handed to `tf.device` while the graph
            is being built.
        """
        super(TFFuncMapper, self).__init__(ds)
        self.device = device
        self.apply_symbf_on_dp = apply_symbf_on_dp
        self.symbf = symbf
        self.get_placeholders = get_placeholders

    def reset_state(self):
        """
        Reset the wrapped DataFlow, then build a fresh graph and session
        and install `self.run_func` to evaluate the symbolic function.
        """
        super(TFFuncMapper, self).reset_state()
        self.graph = tf.Graph()
        with self.graph.as_default():
            with tf.device(self.device):
                self.placeholders = self.get_placeholders()
                self.output_vars = self.symbf(self.placeholders)
                # Created while self.graph is the default graph, so the
                # session is bound to it.
                self.sess = tf.Session()

        def _evaluate(vals):
            # Pair each placeholder with the value at the same position.
            feed = dict(zip(self.placeholders, vals))
            return self.sess.run(self.output_vars, feed_dict=feed)

        self.run_func = _evaluate

    def get_data(self):
        """
        Yield the mapped datapoints, skipping those for which
        `apply_symbf_on_dp` returned a falsy value.
        """
        for raw_dp in self.ds.get_data():
            mapped = self.apply_symbf_on_dp(raw_dp, self.run_func)
            if not mapped:
                continue
            yield mapped
if __name__ == '__main__':
    # Speed check: pull 100000 datapoints through TFFuncMapper, with tqdm
    # reporting the iteration rate.
    from .raw import FakeData
    from .prefetch import PrefetchDataZMQ
    from .image import AugmentImageComponent
    from . import imgaug

    source = FakeData([[224, 224, 3]], 100000, random=False)

    def make_placeholders():
        return [tf.placeholder(tf.float32, [224, 224, 3], name='img')]

    def tf_aug(v):
        # `v` is the placeholder list; only its first entry is used.
        img = v[0]
        img = tf.image.random_brightness(img, 0.1)
        img = tf.image.random_contrast(img, 0.8, 1.2)
        return tf.image.random_flip_left_right(img)

    def map_datapoint(dp, f):
        return [f([dp[0]])[0]]

    ds = TFFuncMapper(source, make_placeholders, tf_aug, map_datapoint)

    # Alternative pipeline kept for comparison against the imgaug
    # implementation:
    #ds = AugmentImageComponent(ds,
        #[imgaug.Brightness(0.1, clip=False),
        #imgaug.Contrast((0.8, 1.2), clip=False),
        #imgaug.Flip(horiz=True)
        #])
    #ds = PrefetchDataZMQ(ds, 4)
    ds.reset_state()

    import tqdm
    datapoints = ds.get_data()
    for _ in tqdm.trange(100000):
        next(datapoints)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment