Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
1e71e8f9
Commit
1e71e8f9
authored
May 10, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
augment images together
parent
ac6e140f
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
32 additions
and
7 deletions
+32
-7
tensorpack/dataflow/base.py
tensorpack/dataflow/base.py
+2
-0
tensorpack/dataflow/common.py
tensorpack/dataflow/common.py
+1
-2
tensorpack/dataflow/dataset/cifar10.py
tensorpack/dataflow/dataset/cifar10.py
+1
-1
tensorpack/dataflow/image.py
tensorpack/dataflow/image.py
+25
-2
tensorpack/models/pool.py
tensorpack/models/pool.py
+2
-1
tensorpack/tfutils/symbolic_functions.py
tensorpack/tfutils/symbolic_functions.py
+1
-1
No files found.
tensorpack/dataflow/base.py
View file @
1e71e8f9
...
...
@@ -16,6 +16,8 @@ class DataFlow(object):
def
get_data
(
self
):
"""
A generator to generate data as a list.
Datapoint should be a mutable list.
Each component should be assumed immutable.
"""
def
size
(
self
):
...
...
tensorpack/dataflow/common.py
View file @
1e71e8f9
...
...
@@ -183,8 +183,7 @@ class MapDataComponent(ProxyDataFlow):
for
dp
in
self
.
ds
.
get_data
():
repl
=
self
.
func
(
dp
[
self
.
index
])
if
repl
is
not
None
:
dp
=
copy
.
deepcopy
(
dp
)
# avoid modifying the original dp
dp
[
self
.
index
]
=
repl
dp
[
self
.
index
]
=
repl
# NOTE modifying
yield
dp
class
RandomChooseData
(
DataFlow
):
...
...
tensorpack/dataflow/dataset/cifar10.py
View file @
1e71e8f9
...
...
@@ -98,7 +98,7 @@ class Cifar10(DataFlow):
if
self
.
shuffle
:
self
.
rng
.
shuffle
(
idxs
)
for
k
in
idxs
:
yield
self
.
data
[
k
]
yield
copy
.
copy
(
self
.
data
[
k
])
def
get_per_pixel_mean
(
self
):
"""
...
...
tensorpack/dataflow/image.py
View file @
1e71e8f9
...
...
@@ -6,10 +6,10 @@ import numpy as np
import
cv2
import
copy
from
.base
import
DataFlow
,
ProxyDataFlow
from
.common
import
MapDataComponent
from
.common
import
MapDataComponent
,
MapData
from
.imgaug
import
AugmentorList
__all__
=
[
'ImageFromFile'
,
'AugmentImageComponent'
]
__all__
=
[
'ImageFromFile'
,
'AugmentImageComponent'
,
'AugmentImagesTogether'
]
class
ImageFromFile
(
DataFlow
):
""" Generate rgb images from list of files """
...
...
@@ -56,3 +56,26 @@ class AugmentImageComponent(MapDataComponent):
self
.
augs
.
reset_state
()
class
AugmentImagesTogether
(
MapData
):
def
__init__
(
self
,
ds
,
augmentors
,
index
=
(
0
,
1
)):
"""
:param ds: a `DataFlow` instance.
:param augmentors: a list of `ImageAugmentor` instance to be applied in order.
:param index: tuple of indices of the image components
"""
self
.
augs
=
AugmentorList
(
augmentors
)
self
.
ds
=
ds
def
func
(
dp
):
im
=
dp
[
index
[
0
]]
im
,
prms
=
self
.
augs
.
_augment_return_params
(
im
)
dp
[
index
[
0
]]
=
im
for
idx
in
index
[
1
:]:
dp
[
idx
]
=
self
.
augs
.
_augment
(
dp
[
idx
],
prms
)
return
dp
super
(
AugmentImagesTogether
,
self
)
.
__init__
(
ds
,
func
)
def
reset_state
(
self
):
self
.
ds
.
reset_state
()
self
.
augs
.
reset_state
()
tensorpack/models/pool.py
View file @
1e71e8f9
...
...
@@ -8,7 +8,8 @@ import numpy
from
._common
import
*
from
..tfutils.symbolic_functions
import
*
__all__
=
[
'MaxPooling'
,
'FixedUnPooling'
,
'AvgPooling'
,
'GlobalAvgPooling'
]
__all__
=
[
'MaxPooling'
,
'FixedUnPooling'
,
'AvgPooling'
,
'GlobalAvgPooling'
,
'BilinearUpSample'
]
@
layer_register
()
def
MaxPooling
(
x
,
shape
,
stride
=
None
,
padding
=
'VALID'
):
...
...
tensorpack/tfutils/symbolic_functions.py
View file @
1e71e8f9
...
...
@@ -65,7 +65,7 @@ def class_balanced_binary_class_cross_entropy(pred, label, name='cross_entropy_l
:returns: class-balanced binary classification cross entropy loss
"""
z
=
batch_flatten
(
pred
)
y
=
batch_flatten
(
label
)
y
=
tf
.
cast
(
batch_flatten
(
label
),
tf
.
float32
)
count_neg
=
tf
.
reduce_sum
(
1.
-
y
)
count_pos
=
tf
.
reduce_sum
(
y
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment