Commit f585913b authored by Yuxin Wu's avatar Yuxin Wu

Add rgb=None in "imgaug.Contrast"

parent 63b8fb00
...@@ -62,7 +62,7 @@ follow these conventions and will need some workarounds if used within tensorpac ...@@ -62,7 +62,7 @@ follow these conventions and will need some workarounds if used within tensorpac
#### What You Can Do Inside Tower Function #### What You Can Do Inside Tower Function
1. Call any symbolic functions as long as they follow the above rules. 1. Call any symbolic functions as long as they follow the above rules.
2. The function will be called under a 2. The tower function will be called under a
[TowerContext](../modules/tfutils.html#tensorpack.tfutils.tower.BaseTowerContext), [TowerContext](../modules/tfutils.html#tensorpack.tfutils.tower.BaseTowerContext),
which can be accessed by [get_current_tower_context()](../modules/tfutils.html#tensorpack.tfutils.tower.get_current_tower_context). which can be accessed by [get_current_tower_context()](../modules/tfutils.html#tensorpack.tfutils.tower.get_current_tower_context).
The context contains information about training/inference mode, scope name, etc. The context contains information about training/inference mode, scope name, etc.
...@@ -93,7 +93,7 @@ Note some __common problems__ when using these trainers: ...@@ -93,7 +93,7 @@ Note some __common problems__ when using these trainers:
inputs on each GPU needs to have consistent shapes. inputs on each GPU needs to have consistent shapes.
``` ```
2. The tower function (your model code) will get called multipile times on each GPU. 2. The tower function (your model code) will get called once on each GPU.
You must follow the abovementieond rules of tower function. You must follow the abovementieond rules of tower function.
### Distributed Trainers ### Distributed Trainers
......
...@@ -51,7 +51,7 @@ class Brightness(ImageAugmentor): ...@@ -51,7 +51,7 @@ class Brightness(ImageAugmentor):
""" """
Args: Args:
delta (float): Randomly add a value within [-delta,delta] delta (float): Randomly add a value within [-delta,delta]
clip (bool): clip results to [0,255]. clip (bool): clip results to [0,255] if data type is uint8.
""" """
super(Brightness, self).__init__() super(Brightness, self).__init__()
assert delta > 0 assert delta > 0
...@@ -78,7 +78,7 @@ class BrightnessScale(ImageAugmentor): ...@@ -78,7 +78,7 @@ class BrightnessScale(ImageAugmentor):
""" """
Args: Args:
range (tuple): Randomly scale the image by a factor in (range[0], range[1]) range (tuple): Randomly scale the image by a factor in (range[0], range[1])
clip (bool): clip results to [0,255]. clip (bool): clip results to [0,255] if data type is uint8.
""" """
super(BrightnessScale, self).__init__() super(BrightnessScale, self).__init__()
self._init(locals()) self._init(locals())
...@@ -101,11 +101,12 @@ class Contrast(ImageAugmentor): ...@@ -101,11 +101,12 @@ class Contrast(ImageAugmentor):
Apply ``x = (x - mean) * contrast_factor + mean`` to each channel. Apply ``x = (x - mean) * contrast_factor + mean`` to each channel.
""" """
def __init__(self, factor_range, rgb=True, clip=True): def __init__(self, factor_range, rgb=None, clip=True):
""" """
Args: Args:
factor_range (list or tuple): an interval to randomly sample the `contrast_factor`. factor_range (list or tuple): an interval to randomly sample the `contrast_factor`.
clip (bool): clip to [0, 255] if True. rgb (bool or None): if None, use the mean per-channel.
clip (bool): clip to [0, 255] if data type is uint8.
""" """
super(Contrast, self).__init__() super(Contrast, self).__init__()
self._init(locals()) self._init(locals())
...@@ -117,9 +118,12 @@ class Contrast(ImageAugmentor): ...@@ -117,9 +118,12 @@ class Contrast(ImageAugmentor):
old_dtype = img.dtype old_dtype = img.dtype
if img.ndim == 3: if img.ndim == 3:
if self.rgb is not None:
m = cv2.COLOR_RGB2GRAY if self.rgb else cv2.COLOR_BGR2GRAY m = cv2.COLOR_RGB2GRAY if self.rgb else cv2.COLOR_BGR2GRAY
grey = cv2.cvtColor(img, m) grey = cv2.cvtColor(img.astype('float32'), m)
mean = np.mean(grey) mean = np.mean(grey)
else:
mean = np.mean(img, axis=(0, 1), keepdims=True)
else: else:
mean = np.mean(img) mean = np.mean(img)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment