Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
2d923b91
Commit
2d923b91
authored
Feb 23, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
update augmentors
parent
cdd71bfe
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
15 additions
and
5 deletions
+15
-5
tensorpack/dataflow/common.py
tensorpack/dataflow/common.py
+2
-0
tensorpack/dataflow/dataset/cifar10.py
tensorpack/dataflow/dataset/cifar10.py
+2
-1
tensorpack/dataflow/imgaug/imgproc.py
tensorpack/dataflow/imgaug/imgproc.py
+2
-1
tensorpack/dataflow/imgaug/noname.py
tensorpack/dataflow/imgaug/noname.py
+7
-1
tensorpack/train/base.py
tensorpack/train/base.py
+2
-2
No files found.
tensorpack/dataflow/common.py
View file @
2d923b91
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import
numpy
as
np
import
numpy
as
np
import
copy
from
.base
import
DataFlow
from
.base
import
DataFlow
from
.imgaug
import
AugmentorList
,
Image
from
.imgaug
import
AugmentorList
,
Image
...
@@ -149,6 +150,7 @@ class MapDataComponent(DataFlow):
...
@@ -149,6 +150,7 @@ class MapDataComponent(DataFlow):
def
get_data
(
self
):
def
get_data
(
self
):
for
dp
in
self
.
ds
.
get_data
():
for
dp
in
self
.
ds
.
get_data
():
dp
=
copy
.
deepcopy
(
dp
)
# avoid modifying the original dp
dp
[
self
.
index
]
=
self
.
func
(
dp
[
self
.
index
])
dp
[
self
.
index
]
=
self
.
func
(
dp
[
self
.
index
])
yield
dp
yield
dp
...
...
tensorpack/dataflow/dataset/cifar10.py
View file @
2d923b91
...
@@ -6,6 +6,7 @@ import os, sys
...
@@ -6,6 +6,7 @@ import os, sys
import
pickle
import
pickle
import
numpy
as
np
import
numpy
as
np
from
six.moves
import
urllib
from
six.moves
import
urllib
import
copy
import
tarfile
import
tarfile
import
logging
import
logging
...
@@ -78,7 +79,7 @@ class Cifar10(DataFlow):
...
@@ -78,7 +79,7 @@ class Cifar10(DataFlow):
if
train_or_test
==
'train'
:
if
train_or_test
==
'train'
:
self
.
fs
=
fnames
[:
5
]
self
.
fs
=
fnames
[:
5
]
else
:
else
:
self
.
fs
=
fnames
[
-
1
]
self
.
fs
=
[
fnames
[
-
1
]
]
for
f
in
self
.
fs
:
for
f
in
self
.
fs
:
if
not
os
.
path
.
isfile
(
f
):
if
not
os
.
path
.
isfile
(
f
):
raise
ValueError
(
'Failed to find file: '
+
f
)
raise
ValueError
(
'Failed to find file: '
+
f
)
...
...
tensorpack/dataflow/imgaug/imgproc.py
View file @
2d923b91
...
@@ -10,7 +10,7 @@ __all__ = ['BrightnessAdd', 'Contrast', 'MeanVarianceNormalize']
...
@@ -10,7 +10,7 @@ __all__ = ['BrightnessAdd', 'Contrast', 'MeanVarianceNormalize']
class
BrightnessAdd
(
ImageAugmentor
):
class
BrightnessAdd
(
ImageAugmentor
):
"""
"""
Randomly add a value within [-delta,delta], and clip in [0,
1
]
Randomly add a value within [-delta,delta], and clip in [0,
255
]
"""
"""
def
__init__
(
self
,
delta
):
def
__init__
(
self
,
delta
):
assert
delta
>
0
assert
delta
>
0
...
@@ -24,6 +24,7 @@ class BrightnessAdd(ImageAugmentor):
...
@@ -24,6 +24,7 @@ class BrightnessAdd(ImageAugmentor):
class
Contrast
(
ImageAugmentor
):
class
Contrast
(
ImageAugmentor
):
"""
"""
Apply x = (x - mean) * contrast_factor + mean to each channel
Apply x = (x - mean) * contrast_factor + mean to each channel
and clip to [0, 255]
"""
"""
def
__init__
(
self
,
factor_range
):
def
__init__
(
self
,
factor_range
):
self
.
_init
(
locals
())
self
.
_init
(
locals
())
...
...
tensorpack/dataflow/imgaug/noname.py
View file @
2d923b91
...
@@ -7,7 +7,7 @@ from .base import ImageAugmentor
...
@@ -7,7 +7,7 @@ from .base import ImageAugmentor
import
numpy
as
np
import
numpy
as
np
import
cv2
import
cv2
__all__
=
[
'Flip'
]
__all__
=
[
'Flip'
,
'MapImage'
]
class
Flip
(
ImageAugmentor
):
class
Flip
(
ImageAugmentor
):
def
__init__
(
self
,
horiz
=
False
,
vert
=
False
,
prob
=
0.5
):
def
__init__
(
self
,
horiz
=
False
,
vert
=
False
,
prob
=
0.5
):
...
@@ -34,3 +34,9 @@ class Flip(ImageAugmentor):
...
@@ -34,3 +34,9 @@ class Flip(ImageAugmentor):
raise
NotImplementedError
()
raise
NotImplementedError
()
class
MapImage
(
ImageAugmentor
):
def
__init__
(
self
,
func
):
self
.
func
=
func
def
_augment
(
self
,
img
):
img
.
arr
=
self
.
func
(
img
.
arr
)
tensorpack/train/base.py
View file @
2d923b91
...
@@ -62,13 +62,13 @@ class Trainer(object):
...
@@ -62,13 +62,13 @@ class Trainer(object):
self
.
summary_writer
.
add_summary
(
summary
,
self
.
global_step
)
self
.
summary_writer
.
add_summary
(
summary
,
self
.
global_step
)
def
main_loop
(
self
):
def
main_loop
(
self
):
self
.
_init_summary
()
callbacks
=
self
.
config
.
callbacks
callbacks
=
self
.
config
.
callbacks
callbacks
.
before_train
(
self
)
with
self
.
sess
.
as_default
():
with
self
.
sess
.
as_default
():
try
:
try
:
self
.
_init_summary
()
self
.
global_step
=
get_global_step
()
self
.
global_step
=
get_global_step
()
logger
.
info
(
"Start training with global_step={}"
.
format
(
self
.
global_step
))
logger
.
info
(
"Start training with global_step={}"
.
format
(
self
.
global_step
))
callbacks
.
before_train
(
self
)
tf
.
get_default_graph
()
.
finalize
()
tf
.
get_default_graph
()
.
finalize
()
for
epoch
in
xrange
(
1
,
self
.
config
.
max_epoch
):
for
epoch
in
xrange
(
1
,
self
.
config
.
max_epoch
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment