Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
c44150d8
Commit
c44150d8
authored
Jan 27, 2018
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
update docs
parent
97cce4a2
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
22 additions
and
5 deletions
+22
-5
examples/cifar-convnet.py
examples/cifar-convnet.py
+0
-1
tensorpack/dataflow/image.py
tensorpack/dataflow/image.py
+13
-3
tensorpack/dataflow/imgaug/base.py
tensorpack/dataflow/imgaug/base.py
+9
-1
No files found.
examples/cifar-convnet.py
View file @
c44150d8
...
@@ -6,7 +6,6 @@ import tensorflow as tf
...
@@ -6,7 +6,6 @@ import tensorflow as tf
import
argparse
import
argparse
import
os
import
os
from
tensorpack
import
*
from
tensorpack
import
*
from
tensorpack.tfutils.summary
import
*
from
tensorpack.tfutils.summary
import
*
from
tensorpack.dataflow
import
dataset
from
tensorpack.dataflow
import
dataset
...
...
tensorpack/dataflow/image.py
View file @
c44150d8
...
@@ -13,7 +13,13 @@ from ..utils.argtools import shape2d
...
@@ -13,7 +13,13 @@ from ..utils.argtools import shape2d
__all__
=
[
'ImageFromFile'
,
'AugmentImageComponent'
,
'AugmentImageCoordinates'
,
'AugmentImageComponents'
]
__all__
=
[
'ImageFromFile'
,
'AugmentImageComponent'
,
'AugmentImageCoordinates'
,
'AugmentImageComponents'
]
def
_valid_coords
(
coords
):
def
check_dtype
(
img
):
if
isinstance
(
img
.
dtype
,
np
.
integer
):
assert
img
.
dtype
==
np
.
uint8
,
\
"[Augmentor] Got image of type {}, use uint8 or floating points instead!"
.
format
(
img
.
dtype
)
def
validate_coords
(
coords
):
assert
coords
.
ndim
==
2
,
coords
.
ndim
assert
coords
.
ndim
==
2
,
coords
.
ndim
assert
coords
.
shape
[
1
]
==
2
,
coords
.
shape
assert
coords
.
shape
[
1
]
==
2
,
coords
.
shape
assert
np
.
issubdtype
(
coords
.
dtype
,
np
.
float
),
coords
.
dtype
assert
np
.
issubdtype
(
coords
.
dtype
,
np
.
float
),
coords
.
dtype
...
@@ -99,6 +105,7 @@ class AugmentImageComponent(MapDataComponent):
...
@@ -99,6 +105,7 @@ class AugmentImageComponent(MapDataComponent):
exception_handler
=
ExceptionHandler
(
catch_exceptions
)
exception_handler
=
ExceptionHandler
(
catch_exceptions
)
def
func
(
x
):
def
func
(
x
):
check_dtype
(
x
)
with
exception_handler
.
catch
():
with
exception_handler
.
catch
():
if
copy
:
if
copy
:
x
=
copy_mod
.
deepcopy
(
x
)
x
=
copy_mod
.
deepcopy
(
x
)
...
@@ -138,7 +145,8 @@ class AugmentImageCoordinates(MapData):
...
@@ -138,7 +145,8 @@ class AugmentImageCoordinates(MapData):
def
func
(
dp
):
def
func
(
dp
):
with
exception_handler
.
catch
():
with
exception_handler
.
catch
():
img
,
coords
=
dp
[
img_index
],
dp
[
coords_index
]
img
,
coords
=
dp
[
img_index
],
dp
[
coords_index
]
_valid_coords
(
coords
)
check_dtype
(
img
)
validate_coords
(
coords
)
if
copy
:
if
copy
:
img
,
coords
=
copy_mod
.
deepcopy
((
img
,
coords
))
img
,
coords
=
copy_mod
.
deepcopy
((
img
,
coords
))
img
,
prms
=
self
.
augs
.
_augment_return_params
(
img
)
img
,
prms
=
self
.
augs
.
_augment_return_params
(
img
)
...
@@ -191,14 +199,16 @@ class AugmentImageComponents(MapData):
...
@@ -191,14 +199,16 @@ class AugmentImageComponents(MapData):
copy_func
=
copy_mod
.
deepcopy
if
copy
else
lambda
x
:
x
# noqa
copy_func
=
copy_mod
.
deepcopy
if
copy
else
lambda
x
:
x
# noqa
with
exception_handler
.
catch
():
with
exception_handler
.
catch
():
major_image
=
index
[
0
]
# image to be used to get params. TODO better design?
major_image
=
index
[
0
]
# image to be used to get params. TODO better design?
check_dtype
(
major_image
)
im
=
copy_func
(
dp
[
major_image
])
im
=
copy_func
(
dp
[
major_image
])
im
,
prms
=
self
.
augs
.
_augment_return_params
(
im
)
im
,
prms
=
self
.
augs
.
_augment_return_params
(
im
)
dp
[
major_image
]
=
im
dp
[
major_image
]
=
im
for
idx
in
index
[
1
:]:
for
idx
in
index
[
1
:]:
check_dtype
(
dp
[
idx
])
dp
[
idx
]
=
self
.
augs
.
_augment
(
copy_func
(
dp
[
idx
]),
prms
)
dp
[
idx
]
=
self
.
augs
.
_augment
(
copy_func
(
dp
[
idx
]),
prms
)
for
idx
in
coords_index
:
for
idx
in
coords_index
:
coords
=
copy_func
(
dp
[
idx
])
coords
=
copy_func
(
dp
[
idx
])
_valid
_coords
(
coords
)
validate
_coords
(
coords
)
dp
[
idx
]
=
self
.
augs
.
_augment_coords
(
coords
,
prms
)
dp
[
idx
]
=
self
.
augs
.
_augment_coords
(
coords
,
prms
)
return
dp
return
dp
...
...
tensorpack/dataflow/imgaug/base.py
View file @
c44150d8
...
@@ -5,10 +5,12 @@
...
@@ -5,10 +5,12 @@
import
inspect
import
inspect
import
pprint
import
pprint
from
abc
import
abstractmethod
,
ABCMeta
from
abc
import
abstractmethod
,
ABCMeta
from
...utils.utils
import
get_rng
import
six
import
six
from
six.moves
import
zip
from
six.moves
import
zip
from
...utils.utils
import
get_rng
from
..image
import
check_dtype
__all__
=
[
'Augmentor'
,
'ImageAugmentor'
,
'AugmentorList'
]
__all__
=
[
'Augmentor'
,
'ImageAugmentor'
,
'AugmentorList'
]
...
@@ -101,6 +103,10 @@ class Augmentor(object):
...
@@ -101,6 +103,10 @@ class Augmentor(object):
class
ImageAugmentor
(
Augmentor
):
class
ImageAugmentor
(
Augmentor
):
"""
ImageAugmentor should take images of type uint8 in range [0, 255], or
floating point images in range [0, 1] or [0, 255].
"""
def
augment_coords
(
self
,
coords
,
param
):
def
augment_coords
(
self
,
coords
,
param
):
return
self
.
_augment_coords
(
coords
,
param
)
return
self
.
_augment_coords
(
coords
,
param
)
...
@@ -137,6 +143,7 @@ class AugmentorList(ImageAugmentor):
...
@@ -137,6 +143,7 @@ class AugmentorList(ImageAugmentor):
raise
RuntimeError
(
"Cannot simply get all parameters of a AugmentorList without running the augmentation!"
)
raise
RuntimeError
(
"Cannot simply get all parameters of a AugmentorList without running the augmentation!"
)
def
_augment_return_params
(
self
,
img
):
def
_augment_return_params
(
self
,
img
):
check_dtype
(
img
)
assert
img
.
ndim
in
[
2
,
3
],
img
.
ndim
assert
img
.
ndim
in
[
2
,
3
],
img
.
ndim
prms
=
[]
prms
=
[]
...
@@ -146,6 +153,7 @@ class AugmentorList(ImageAugmentor):
...
@@ -146,6 +153,7 @@ class AugmentorList(ImageAugmentor):
return
img
,
prms
return
img
,
prms
def
_augment
(
self
,
img
,
param
):
def
_augment
(
self
,
img
,
param
):
check_dtype
(
img
)
assert
img
.
ndim
in
[
2
,
3
],
img
.
ndim
assert
img
.
ndim
in
[
2
,
3
],
img
.
ndim
for
aug
,
prm
in
zip
(
self
.
augs
,
param
):
for
aug
,
prm
in
zip
(
self
.
augs
,
param
):
img
=
aug
.
_augment
(
img
,
prm
)
img
=
aug
.
_augment
(
img
,
prm
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment