Commit ed702d1d authored by Yuxin Wu's avatar Yuxin Wu

Add GoogleNetResize into imgaug

parent b097e7d5
......@@ -53,13 +53,13 @@ Speed:
1. After warmup, the training speed will slowly decrease due to more accurate proposals.
1. This implementation is about 10% slower than detectron,
probably due to the lack of specialized ops (e.g. AffineChannel, ROIAlign) in TensorFlow.
It's certainly faster than other TF implementation.
1. The code should have around 70% GPU utilization on V100s, and 85%~90% scaling
efficiency from 1 V100 to 8 V100s.
1. This implementation does not contain specialized CUDA ops (e.g. AffineChannel, ROIAlign),
so it can be slightly (~10%) slower than Detectron (Caffe2) and
maskrcnn-benchmark (PyTorch).
Possible Future Enhancements:
1. Define a better interface to load custom dataset.
......
......@@ -25,44 +25,13 @@ from tensorpack.utils.stats import RatioCounter
"""
class GoogleNetResize(imgaug.ImageAugmentor):
"""
crop 8%~100% of the original image
See `Going Deeper with Convolutions` by Google.
"""
def __init__(self, crop_area_fraction=0.08,
aspect_ratio_low=0.75, aspect_ratio_high=1.333,
target_shape=224):
self._init(locals())
def _augment(self, img, _):
h, w = img.shape[:2]
area = h * w
for _ in range(10):
targetArea = self.rng.uniform(self.crop_area_fraction, 1.0) * area
aspectR = self.rng.uniform(self.aspect_ratio_low, self.aspect_ratio_high)
ww = int(np.sqrt(targetArea * aspectR) + 0.5)
hh = int(np.sqrt(targetArea / aspectR) + 0.5)
if self.rng.uniform() < 0.5:
ww, hh = hh, ww
if hh <= h and ww <= w:
x1 = 0 if w == ww else self.rng.randint(0, w - ww)
y1 = 0 if h == hh else self.rng.randint(0, h - hh)
out = img[y1:y1 + hh, x1:x1 + ww]
out = cv2.resize(out, (self.target_shape, self.target_shape), interpolation=cv2.INTER_CUBIC)
return out
out = imgaug.ResizeShortestEdge(self.target_shape, interp=cv2.INTER_CUBIC).augment(img)
out = imgaug.CenterCrop(self.target_shape).augment(out)
return out
def fbresnet_augmentor(isTrain):
"""
Augmentor used in fb.resnet.torch, for BGR images in range [0,255].
"""
if isTrain:
augmentors = [
GoogleNetResize(),
imgaug.GoogleNetRandomCropAndResize(),
# It's OK to remove the following augs if your CPU is not fast enough.
# Removing brightness/contrast/saturation does not have a significant effect on accuracy.
# Removing lighting leads to a tiny drop in accuracy.
......
......@@ -16,7 +16,7 @@ from tensorpack.tfutils.scope_utils import under_name_scope
from tensorpack.utils import logger
from tensorpack.utils.gpu import get_num_gpu
from imagenet_utils import GoogleNetResize, ImageNetModel, eval_on_ILSVRC12, get_imagenet_dataflow
from imagenet_utils import ImageNetModel, eval_on_ILSVRC12, get_imagenet_dataflow
@layer_register(log_shape=True)
......@@ -160,7 +160,7 @@ def get_data(name, batch):
if isTrain:
augmentors = [
# use lighter augs if model is too small
GoogleNetResize(crop_area_fraction=0.49 if args.ratio < 1 else 0.08),
imgaug.GoogleNetRandomCropAndResize(crop_area_fraction=(0.49 if args.ratio < 1 else 0.08, 1.)),
imgaug.RandomOrderAug(
[imgaug.BrightnessScale((0.6, 1.4), clip=False),
imgaug.Contrast((0.6, 1.4), clip=False),
......
# -*- coding: utf-8 -*-
# File: crop.py
import numpy as np
import cv2
from ...utils.argtools import shape2d
from .base import ImageAugmentor
from .transform import CropTransform, TransformAugmentorBase
from .misc import ResizeShortestEdge
__all__ = ['RandomCrop', 'CenterCrop', 'RandomCropRandomShape']
__all__ = ['RandomCrop', 'CenterCrop', 'RandomCropRandomShape', 'GoogleNetRandomCropAndResize']
class RandomCrop(TransformAugmentorBase):
......@@ -80,3 +84,48 @@ class RandomCropRandomShape(TransformAugmentorBase):
y0 = 0 if diffh == 0 else self.rng.randint(diffh)
x0 = 0 if diffw == 0 else self.rng.randint(diffw)
return CropTransform(y0, x0, h, w)
class GoogleNetRandomCropAndResize(ImageAugmentor):
"""
The random crop and resize augmentation proposed in
Sec. 6 of `Going Deeper with Convolutions` by Google.
This implementation follows the details in `fb.resnet.torch`.
It attempts to crop a random rectangle with 8%~100% area of the original image,
and keep the aspect ratio between 3/4 to 4/3. Then it resize this crop to the target shape.
If such crop cannot be found in 10 iterations, it will to a ResizeShortestEdge + CenterCrop.
"""
def __init__(self, crop_area_fraction=(0.08, 1.),
aspect_ratio_range=(0.75, 1.333),
target_shape=224, interp=cv2.INTER_LINEAR):
"""
Args:
crop_area_fraction (tuple(float)): Defaults to crop 8%-100% area.
aspect_ratio_range (tuple(float)): Defaults to make aspect ratio in 3/4-4/3.
target_shape (int): Defaults to 224, the standard ImageNet image shape.
"""
self._init(locals())
def _augment(self, img, _):
h, w = img.shape[:2]
area = h * w
for _ in range(10):
targetArea = self.rng.uniform(*self.crop_area_fraction) * area
aspectR = self.rng.uniform(*self.aspect_ratio_range)
ww = int(np.sqrt(targetArea * aspectR) + 0.5)
hh = int(np.sqrt(targetArea / aspectR) + 0.5)
if self.rng.uniform() < 0.5:
ww, hh = hh, ww
if hh <= h and ww <= w:
x1 = 0 if w == ww else self.rng.randint(0, w - ww)
y1 = 0 if h == hh else self.rng.randint(0, h - hh)
out = img[y1:y1 + hh, x1:x1 + ww]
out = cv2.resize(out, (self.target_shape, self.target_shape), interpolation=self.interp)
return out
out = ResizeShortestEdge(self.target_shape, interp=cv2.INTER_CUBIC).augment(img)
out = CenterCrop(self.target_shape).augment(out)
return out
def _augment_coords(self, coords, param):
raise NotImplementedError()
......@@ -79,11 +79,10 @@ class LeastLoadedDeviceSetter(object):
self.ps_sizes = [0] * len(self.ps_devices)
def __call__(self, op):
if get_tf_version_tuple() >= (1, 8):
from tensorflow.python.training.device_util import canonicalize
else:
def canonicalize(name): # tensorflow/tensorflow#11484
return tf.DeviceSpec.from_string(name).to_string()
# from tensorflow.python.training.device_util import canonicalize
# from tensorflow.python.distribute.device_util import canonicalize
def canonicalize(name): # tensorflow/tensorflow#11484
return tf.DeviceSpec.from_string(name).to_string()
if op.device:
return op.device
......
......@@ -179,7 +179,7 @@ def _pick_tqdm_interval(file):
return 15
if 'OMPI_COMM_WORLD_SIZE' in os.environ:
if int(os.environ['OMPI_COMM_WORLD_SIZE']) > 1:
if int(os.environ['OMPI_COMM_WORLD_SIZE']) > 8:
return 60
# If not a tty, don't refresh progress bar that often
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment