Commit f585913b authored by Yuxin Wu's avatar Yuxin Wu

Add rgb=None in "imgaug.Contrast"

parent 63b8fb00
......@@ -62,7 +62,7 @@ follow these conventions and will need some workarounds if used within tensorpac
#### What You Can Do Inside Tower Function
1. Call any symbolic functions as long as they follow the above rules.
2. The function will be called under a
2. The tower function will be called under a
[TowerContext](../modules/tfutils.html#tensorpack.tfutils.tower.BaseTowerContext),
which can be accessed by [get_current_tower_context()](../modules/tfutils.html#tensorpack.tfutils.tower.get_current_tower_context).
The context contains information about training/inference mode, scope name, etc.
......@@ -93,7 +93,7 @@ Note some __common problems__ when using these trainers:
inputs on each GPU needs to have consistent shapes.
```
2. The tower function (your model code) will get called multipile times on each GPU.
2. The tower function (your model code) will get called once on each GPU.
You must follow the abovementieond rules of tower function.
### Distributed Trainers
......
......@@ -51,7 +51,7 @@ class Brightness(ImageAugmentor):
"""
Args:
delta (float): Randomly add a value within [-delta,delta]
clip (bool): clip results to [0,255].
clip (bool): clip results to [0,255] if data type is uint8.
"""
super(Brightness, self).__init__()
assert delta > 0
......@@ -78,7 +78,7 @@ class BrightnessScale(ImageAugmentor):
"""
Args:
range (tuple): Randomly scale the image by a factor in (range[0], range[1])
clip (bool): clip results to [0,255].
clip (bool): clip results to [0,255] if data type is uint8.
"""
super(BrightnessScale, self).__init__()
self._init(locals())
......@@ -101,11 +101,12 @@ class Contrast(ImageAugmentor):
Apply ``x = (x - mean) * contrast_factor + mean`` to each channel.
"""
def __init__(self, factor_range, rgb=True, clip=True):
def __init__(self, factor_range, rgb=None, clip=True):
"""
Args:
factor_range (list or tuple): an interval to randomly sample the `contrast_factor`.
clip (bool): clip to [0, 255] if True.
rgb (bool or None): if None, use the mean per-channel.
clip (bool): clip to [0, 255] if data type is uint8.
"""
super(Contrast, self).__init__()
self._init(locals())
......@@ -117,9 +118,12 @@ class Contrast(ImageAugmentor):
old_dtype = img.dtype
if img.ndim == 3:
if self.rgb is not None:
m = cv2.COLOR_RGB2GRAY if self.rgb else cv2.COLOR_BGR2GRAY
grey = cv2.cvtColor(img, m)
grey = cv2.cvtColor(img.astype('float32'), m)
mean = np.mean(grey)
else:
mean = np.mean(img, axis=(0, 1), keepdims=True)
else:
mean = np.mean(img)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment