Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
8c0106e4
You need to sign in or sign up before continuing.
Commit
8c0106e4
authored
Aug 17, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
update docs in imgaug
parent
34357e77
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
42 additions
and
12 deletions
+42
-12
tensorpack/dataflow/image.py
tensorpack/dataflow/image.py
+29
-9
tensorpack/dataflow/imgaug/base.py
tensorpack/dataflow/imgaug/base.py
+13
-3
No files found.
tensorpack/dataflow/image.py
View file @
8c0106e4
...
@@ -12,6 +12,12 @@ from ..utils.argtools import shape2d
...
@@ -12,6 +12,12 @@ from ..utils.argtools import shape2d
__all__
=
[
'ImageFromFile'
,
'AugmentImageComponent'
,
'AugmentImageCoordinates'
,
'AugmentImageComponents'
]
__all__
=
[
'ImageFromFile'
,
'AugmentImageComponent'
,
'AugmentImageCoordinates'
,
'AugmentImageComponents'
]
def
_valid_coords
(
coords
):
assert
coords
.
ndim
==
2
,
coords
.
ndim
assert
coords
.
shape
[
1
]
==
2
,
coords
.
shape
assert
np
.
issubdtype
(
coords
.
dtype
,
np
.
float
),
coords
.
dtype
class
ImageFromFile
(
RNGDataFlow
):
class
ImageFromFile
(
RNGDataFlow
):
""" Produce images read from a list of files. """
""" Produce images read from a list of files. """
def
__init__
(
self
,
files
,
channel
=
3
,
resize
=
None
,
shuffle
=
False
):
def
__init__
(
self
,
files
,
channel
=
3
,
resize
=
None
,
shuffle
=
False
):
...
@@ -49,14 +55,14 @@ class ImageFromFile(RNGDataFlow):
...
@@ -49,14 +55,14 @@ class ImageFromFile(RNGDataFlow):
class
AugmentImageComponent
(
MapDataComponent
):
class
AugmentImageComponent
(
MapDataComponent
):
"""
"""
Apply image augmentors on 1 component.
Apply image augmentors on 1
image
component.
"""
"""
def
__init__
(
self
,
ds
,
augmentors
,
index
=
0
,
copy
=
True
):
def
__init__
(
self
,
ds
,
augmentors
,
index
=
0
,
copy
=
True
):
"""
"""
Args:
Args:
ds (DataFlow): input DataFlow.
ds (DataFlow): input DataFlow.
augmentors (AugmentorList): a list of :class:`imgaug.ImageAugmentor` to be applied in order.
augmentors (AugmentorList): a list of :class:`imgaug.ImageAugmentor` to be applied in order.
index (int): the index of the image component to be augmented.
index (int): the index of the image component to be augmented
in the datapoint
.
copy (bool): Some augmentors modify the input images. When copy is
copy (bool): Some augmentors modify the input images. When copy is
True, a copy will be made before any augmentors are applied,
True, a copy will be made before any augmentors are applied,
to keep the original images not modified.
to keep the original images not modified.
...
@@ -117,9 +123,7 @@ class AugmentImageCoordinates(MapData):
...
@@ -117,9 +123,7 @@ class AugmentImageCoordinates(MapData):
def
func
(
dp
):
def
func
(
dp
):
try
:
try
:
img
,
coords
=
dp
[
img_index
],
dp
[
coords_index
]
img
,
coords
=
dp
[
img_index
],
dp
[
coords_index
]
assert
coords
.
ndim
==
2
,
coords
.
ndim
_valid_coords
(
coords
)
assert
coords
.
shape
[
1
]
==
2
,
coords
.
shape
assert
np
.
issubdtype
(
coords
.
dtype
,
np
.
float
),
coords
.
dtype
if
copy
:
if
copy
:
img
,
coords
=
copy_mod
.
deepcopy
((
img
,
coords
))
img
,
coords
=
copy_mod
.
deepcopy
((
img
,
coords
))
img
,
prms
=
self
.
augs
.
_augment_return_params
(
img
)
img
,
prms
=
self
.
augs
.
_augment_return_params
(
img
)
...
@@ -145,14 +149,25 @@ class AugmentImageCoordinates(MapData):
...
@@ -145,14 +149,25 @@ class AugmentImageCoordinates(MapData):
class
AugmentImageComponents
(
MapData
):
class
AugmentImageComponents
(
MapData
):
"""
"""
Apply image augmentors on several components, with shared augmentation parameters.
Apply image augmentors on several components, with shared augmentation parameters.
Example:
.. code-block:: python
ds = MyDataFlow() # produce [image(HWC), segmask(HW), keypoint(Nx2)]
ds = AugmentImageComponents(
ds, augs,
index=(0,1), coords_index=(2,))
"""
"""
def
__init__
(
self
,
ds
,
augmentors
,
index
=
(
0
,
1
),
copy
=
True
):
def
__init__
(
self
,
ds
,
augmentors
,
index
=
(
0
,
1
),
co
ords_index
=
(),
co
py
=
True
):
"""
"""
Args:
Args:
ds (DataFlow): input DataFlow.
ds (DataFlow): input DataFlow.
augmentors (AugmentorList): a list of :class:`imgaug.ImageAugmentor` instance to be applied in order.
augmentors (AugmentorList): a list of :class:`imgaug.ImageAugmentor` instance to be applied in order.
index: tuple of indices of components.
index: tuple of indices of the image components.
coords_index: tuple of indices of the coordinates components.
copy (bool): Some augmentors modify the input images. When copy is
copy (bool): Some augmentors modify the input images. When copy is
True, a copy will be made before any augmentors are applied,
True, a copy will be made before any augmentors are applied,
to keep the original images not modified.
to keep the original images not modified.
...
@@ -169,11 +184,16 @@ class AugmentImageComponents(MapData):
...
@@ -169,11 +184,16 @@ class AugmentImageComponents(MapData):
dp
=
copy_mod
.
copy
(
dp
)
# always do a shallow copy, make sure the list is intact
dp
=
copy_mod
.
copy
(
dp
)
# always do a shallow copy, make sure the list is intact
copy_func
=
copy_mod
.
deepcopy
if
copy
else
lambda
x
:
x
# noqa
copy_func
=
copy_mod
.
deepcopy
if
copy
else
lambda
x
:
x
# noqa
try
:
try
:
im
=
copy_func
(
dp
[
index
[
0
]])
major_image
=
index
[
0
]
# image to be used to get params. TODO better design?
im
=
copy_func
(
dp
[
major_image
])
im
,
prms
=
self
.
augs
.
_augment_return_params
(
im
)
im
,
prms
=
self
.
augs
.
_augment_return_params
(
im
)
dp
[
index
[
0
]
]
=
im
dp
[
major_image
]
=
im
for
idx
in
index
[
1
:]:
for
idx
in
index
[
1
:]:
dp
[
idx
]
=
self
.
augs
.
_augment
(
copy_func
(
dp
[
idx
]),
prms
)
dp
[
idx
]
=
self
.
augs
.
_augment
(
copy_func
(
dp
[
idx
]),
prms
)
for
idx
in
coords_index
:
coords
=
copy_func
(
dp
[
idx
])
_valid_coords
(
coords
)
dp
[
idx
]
=
self
.
augs
.
_augment_coords
(
coords
,
prms
)
return
dp
return
dp
except
KeyboardInterrupt
:
except
KeyboardInterrupt
:
raise
raise
...
...
tensorpack/dataflow/imgaug/base.py
View file @
8c0106e4
...
@@ -44,16 +44,20 @@ class Augmentor(object):
...
@@ -44,16 +44,20 @@ class Augmentor(object):
@
abstractmethod
@
abstractmethod
def
_augment
(
self
,
d
,
param
):
def
_augment
(
self
,
d
,
param
):
"""
"""
augment with the given param and return the new image
Augment with the given param and return the new data.
The augmentor is allowed to modify data in-place.
"""
"""
def
_get_augment_params
(
self
,
d
):
def
_get_augment_params
(
self
,
d
):
"""
"""
get the augmentor parameters
Get the augmentor parameters.
"""
"""
return
None
return
None
def
_rand_range
(
self
,
low
=
1.0
,
high
=
None
,
size
=
None
):
def
_rand_range
(
self
,
low
=
1.0
,
high
=
None
,
size
=
None
):
"""
Uniform float random number between low and high.
"""
if
high
is
None
:
if
high
is
None
:
low
,
high
=
0
,
low
low
,
high
=
0
,
low
if
size
is
None
:
if
size
is
None
:
...
@@ -64,9 +68,15 @@ class Augmentor(object):
...
@@ -64,9 +68,15 @@ class Augmentor(object):
class
ImageAugmentor
(
Augmentor
):
class
ImageAugmentor
(
Augmentor
):
def
_augment_coords
(
self
,
coords
,
param
):
def
_augment_coords
(
self
,
coords
,
param
):
"""
"""
Augment the coordinates given the param.
By default, keeps coordinates unchanged.
By default, keeps coordinates unchanged.
If a subclass changes coordinates but couldn't implement this method,
If a subclass changes coordinates but couldn't implement this method,
it should ``raise NotImplementedError()``.
it should ``raise NotImplementedError()``.
Args:
coords: Nx2 floating point nparray where each row is (x, y)
Returns:
new coords
"""
"""
return
coords
return
coords
...
@@ -86,7 +96,7 @@ class AugmentorList(ImageAugmentor):
...
@@ -86,7 +96,7 @@ class AugmentorList(ImageAugmentor):
def
_get_augment_params
(
self
,
img
):
def
_get_augment_params
(
self
,
img
):
# the next augmentor requires the previous one to finish
# the next augmentor requires the previous one to finish
raise
RuntimeError
(
"Cannot simply get
parameters of a AugmentorList
!"
)
raise
RuntimeError
(
"Cannot simply get
all parameters of a AugmentorList without running the augmentation
!"
)
def
_augment_return_params
(
self
,
img
):
def
_augment_return_params
(
self
,
img
):
assert
img
.
ndim
in
[
2
,
3
],
img
.
ndim
assert
img
.
ndim
in
[
2
,
3
],
img
.
ndim
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment