Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
a1da74af
Commit
a1da74af
authored
Dec 02, 2018
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
more docs about augmentor (#996)
parent
7a0b15d5
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
33 additions
and
17 deletions
+33
-17
examples/FasterRCNN/config.py
examples/FasterRCNN/config.py
+1
-1
examples/ImageNetModels/imagenet_utils.py
examples/ImageNetModels/imagenet_utils.py
+1
-1
tensorpack/dataflow/image.py
tensorpack/dataflow/image.py
+7
-8
tensorpack/dataflow/imgaug/base.py
tensorpack/dataflow/imgaug/base.py
+20
-3
tensorpack/dataflow/parallel.py
tensorpack/dataflow/parallel.py
+3
-1
tensorpack/dataflow/parallel_map.py
tensorpack/dataflow/parallel_map.py
+1
-3
No files found.
examples/FasterRCNN/config.py
View file @
a1da74af
...
@@ -123,7 +123,7 @@ _C.TRAIN.LR_SCHEDULE = [240000, 320000, 360000] # "2x" schedule in detectro
...
@@ -123,7 +123,7 @@ _C.TRAIN.LR_SCHEDULE = [240000, 320000, 360000] # "2x" schedule in detectro
_C
.
TRAIN
.
EVAL_PERIOD
=
25
# period (epochs) to run eva
_C
.
TRAIN
.
EVAL_PERIOD
=
25
# period (epochs) to run eva
# preprocessing --------------------
# preprocessing --------------------
# Alternative old (worse & faster) setting: 600
, 1024
# Alternative old (worse & faster) setting: 600
_C
.
PREPROC
.
TRAIN_SHORT_EDGE_SIZE
=
[
800
,
800
]
# [min, max] to sample from
_C
.
PREPROC
.
TRAIN_SHORT_EDGE_SIZE
=
[
800
,
800
]
# [min, max] to sample from
_C
.
PREPROC
.
TEST_SHORT_EDGE_SIZE
=
800
_C
.
PREPROC
.
TEST_SHORT_EDGE_SIZE
=
800
_C
.
PREPROC
.
MAX_SIZE
=
1333
_C
.
PREPROC
.
MAX_SIZE
=
1333
...
...
examples/ImageNetModels/imagenet_utils.py
View file @
a1da74af
...
@@ -205,7 +205,7 @@ def get_imagenet_tfdata(datadir, name, batch_size, mapper=None, parallel=None):
...
@@ -205,7 +205,7 @@ def get_imagenet_tfdata(datadir, name, batch_size, mapper=None, parallel=None):
def
fbresnet_mapper
(
isTrain
):
def
fbresnet_mapper
(
isTrain
):
"""
"""
Note: compared to fbresnet_augmentor, it
Note: compared to fbresnet_augmentor, it
lacks some photometric augmentation that may have a small effect on accuracy.
lacks some photometric augmentation that may have a small effect
(0.1~0.2
%
)
on accuracy.
"""
"""
JPEG_OPT
=
{
'fancy_upscaling'
:
True
,
'dct_method'
:
'INTEGER_ACCURATE'
}
JPEG_OPT
=
{
'fancy_upscaling'
:
True
,
'dct_method'
:
'INTEGER_ACCURATE'
}
...
...
tensorpack/dataflow/image.py
View file @
a1da74af
...
@@ -15,9 +15,8 @@ __all__ = ['ImageFromFile', 'AugmentImageComponent', 'AugmentImageCoordinates',
...
@@ -15,9 +15,8 @@ __all__ = ['ImageFromFile', 'AugmentImageComponent', 'AugmentImageCoordinates',
def
check_dtype
(
img
):
def
check_dtype
(
img
):
assert
isinstance
(
img
,
np
.
ndarray
),
"[Augmentor] Needs an numpy array, but got a {}!"
.
format
(
type
(
img
))
assert
isinstance
(
img
,
np
.
ndarray
),
"[Augmentor] Needs an numpy array, but got a {}!"
.
format
(
type
(
img
))
if
isinstance
(
img
.
dtype
,
np
.
integer
):
assert
not
isinstance
(
img
.
dtype
,
np
.
integer
)
or
(
img
.
dtype
==
np
.
uint8
),
\
assert
img
.
dtype
==
np
.
uint8
,
\
"[Augmentor] Got image of type {}, use uint8 or floating points instead!"
.
format
(
img
.
dtype
)
"[Augmentor] Got image of type {}, use uint8 or floating points instead!"
.
format
(
img
.
dtype
)
def
validate_coords
(
coords
):
def
validate_coords
(
coords
):
...
@@ -161,9 +160,9 @@ class AugmentImageCoordinates(MapData):
...
@@ -161,9 +160,9 @@ class AugmentImageCoordinates(MapData):
validate_coords
(
coords
)
validate_coords
(
coords
)
if
self
.
_copy
:
if
self
.
_copy
:
img
,
coords
=
copy_mod
.
deepcopy
((
img
,
coords
))
img
,
coords
=
copy_mod
.
deepcopy
((
img
,
coords
))
img
,
prms
=
self
.
augs
.
_
augment_return_params
(
img
)
img
,
prms
=
self
.
augs
.
augment_return_params
(
img
)
dp
[
self
.
_img_index
]
=
img
dp
[
self
.
_img_index
]
=
img
coords
=
self
.
augs
.
_
augment_coords
(
coords
,
prms
)
coords
=
self
.
augs
.
augment_coords
(
coords
,
prms
)
dp
[
self
.
_coords_index
]
=
coords
dp
[
self
.
_coords_index
]
=
coords
return
dp
return
dp
...
@@ -207,15 +206,15 @@ class AugmentImageComponents(MapData):
...
@@ -207,15 +206,15 @@ class AugmentImageComponents(MapData):
major_image
=
index
[
0
]
# image to be used to get params. TODO better design?
major_image
=
index
[
0
]
# image to be used to get params. TODO better design?
im
=
copy_func
(
dp
[
major_image
])
im
=
copy_func
(
dp
[
major_image
])
check_dtype
(
im
)
check_dtype
(
im
)
im
,
prms
=
self
.
augs
.
_
augment_return_params
(
im
)
im
,
prms
=
self
.
augs
.
augment_return_params
(
im
)
dp
[
major_image
]
=
im
dp
[
major_image
]
=
im
for
idx
in
index
[
1
:]:
for
idx
in
index
[
1
:]:
check_dtype
(
dp
[
idx
])
check_dtype
(
dp
[
idx
])
dp
[
idx
]
=
self
.
augs
.
_augment
(
copy_func
(
dp
[
idx
]),
prms
)
dp
[
idx
]
=
self
.
augs
.
augment_with_params
(
copy_func
(
dp
[
idx
]),
prms
)
for
idx
in
coords_index
:
for
idx
in
coords_index
:
coords
=
copy_func
(
dp
[
idx
])
coords
=
copy_func
(
dp
[
idx
])
validate_coords
(
coords
)
validate_coords
(
coords
)
dp
[
idx
]
=
self
.
augs
.
_
augment_coords
(
coords
,
prms
)
dp
[
idx
]
=
self
.
augs
.
augment_coords
(
coords
,
prms
)
return
dp
return
dp
super
(
AugmentImageComponents
,
self
)
.
__init__
(
ds
,
func
)
super
(
AugmentImageComponents
,
self
)
.
__init__
(
ds
,
func
)
...
...
tensorpack/dataflow/imgaug/base.py
View file @
a1da74af
...
@@ -35,15 +35,22 @@ class Augmentor(object):
...
@@ -35,15 +35,22 @@ class Augmentor(object):
def
augment
(
self
,
d
):
def
augment
(
self
,
d
):
"""
"""
Perform augmentation on the data.
Perform augmentation on the data.
Returns:
augmented data
"""
"""
d
,
params
=
self
.
_augment_return_params
(
d
)
d
,
params
=
self
.
_augment_return_params
(
d
)
return
d
return
d
def
augment_return_params
(
self
,
d
):
def
augment_return_params
(
self
,
d
):
"""
"""
Augment the data and return the augmentation parameters.
The returned parameters can be used to augment another data with identical transformation.
This can be used in, e.g. augmentation for image, masks, keypoints altogether.
Returns:
Returns:
augmented data
augmented data
augmentation params
augmentation params
: can be any type
"""
"""
return
self
.
_augment_return_params
(
d
)
return
self
.
_augment_return_params
(
d
)
...
@@ -54,6 +61,15 @@ class Augmentor(object):
...
@@ -54,6 +61,15 @@ class Augmentor(object):
prms
=
self
.
_get_augment_params
(
d
)
prms
=
self
.
_get_augment_params
(
d
)
return
(
self
.
_augment
(
d
,
prms
),
prms
)
return
(
self
.
_augment
(
d
,
prms
),
prms
)
def
augment_with_params
(
self
,
d
,
param
):
"""
Augment the data with the given param.
Returns:
augmented data
"""
return
self
.
_augment
(
d
,
param
)
@
abstractmethod
@
abstractmethod
def
_augment
(
self
,
d
,
param
):
def
_augment
(
self
,
d
,
param
):
"""
"""
...
@@ -115,8 +131,9 @@ class ImageAugmentor(Augmentor):
...
@@ -115,8 +131,9 @@ class ImageAugmentor(Augmentor):
def
augment_coords
(
self
,
coords
,
param
):
def
augment_coords
(
self
,
coords
,
param
):
"""
"""
Augment the coordinates given the param.
Augment the coordinates given the param.
By default, an augmentor keeps coordinates unchanged.
By default, an augmentor keeps coordinates unchanged.
If a subclass changes coordinates but couldn't implement this method,
If a subclass
of :class:`ImageAugmentor`
changes coordinates but couldn't implement this method,
it should ``raise NotImplementedError()``.
it should ``raise NotImplementedError()``.
Args:
Args:
...
@@ -132,7 +149,7 @@ class ImageAugmentor(Augmentor):
...
@@ -132,7 +149,7 @@ class ImageAugmentor(Augmentor):
class
AugmentorList
(
ImageAugmentor
):
class
AugmentorList
(
ImageAugmentor
):
"""
"""
Augment by a list of augmentors
Augment
an image
by a list of augmentors
"""
"""
def
__init__
(
self
,
augmentors
):
def
__init__
(
self
,
augmentors
):
...
...
tensorpack/dataflow/parallel.py
View file @
a1da74af
...
@@ -378,7 +378,9 @@ class MultiThreadPrefetchData(DataFlow):
...
@@ -378,7 +378,9 @@ class MultiThreadPrefetchData(DataFlow):
def
__init__
(
self
,
get_df
,
nr_prefetch
,
nr_thread
):
def
__init__
(
self
,
get_df
,
nr_prefetch
,
nr_thread
):
"""
"""
Args:
Args:
get_df ( -> DataFlow): a callable which returns a DataFlow
get_df ( -> DataFlow): a callable which returns a DataFlow.
Each thread will call this function to get the DataFlow to use.
Therefore do not return the same DataFlow for each call.
nr_prefetch (int): size of the queue
nr_prefetch (int): size of the queue
nr_thread (int): number of threads
nr_thread (int): number of threads
"""
"""
...
...
tensorpack/dataflow/parallel_map.py
View file @
a1da74af
...
@@ -11,7 +11,6 @@ import zmq
...
@@ -11,7 +11,6 @@ import zmq
from
.base
import
DataFlow
,
ProxyDataFlow
,
DataFlowReentrantGuard
from
.base
import
DataFlow
,
ProxyDataFlow
,
DataFlowReentrantGuard
from
.common
import
RepeatedData
from
.common
import
RepeatedData
from
..utils.concurrency
import
StoppableThread
,
enable_death_signal
from
..utils.concurrency
import
StoppableThread
,
enable_death_signal
from
..utils
import
logger
from
..utils.serialize
import
loads
,
dumps
from
..utils.serialize
import
loads
,
dumps
from
.parallel
import
(
from
.parallel
import
(
...
@@ -59,10 +58,9 @@ class _ParallelMapData(ProxyDataFlow):
...
@@ -59,10 +58,9 @@ class _ParallelMapData(ProxyDataFlow):
dp
=
next
(
self
.
_iter
)
dp
=
next
(
self
.
_iter
)
self
.
_send
(
dp
)
self
.
_send
(
dp
)
except
StopIteration
:
except
StopIteration
:
logger
.
e
rror
(
raise
RuntimeE
rror
(
"[{}] buffer_size cannot be larger than the size of the DataFlow when strict=True!"
.
format
(
"[{}] buffer_size cannot be larger than the size of the DataFlow when strict=True!"
.
format
(
type
(
self
)
.
__name__
))
type
(
self
)
.
__name__
))
raise
self
.
_buffer_occupancy
+=
cnt
self
.
_buffer_occupancy
+=
cnt
def
get_data_non_strict
(
self
):
def
get_data_non_strict
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment