Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
cba97f75
Commit
cba97f75
authored
Jul 15, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
bbox and tf_func
parent
0ec586b0
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
150 additions
and
16 deletions
+150
-16
examples/ResNet/svhn-resnet.py
examples/ResNet/svhn-resnet.py
+3
-4
tensorpack/dataflow/dataset/ilsvrc.py
tensorpack/dataflow/dataset/ilsvrc.py
+65
-10
tensorpack/dataflow/imgaug/base.py
tensorpack/dataflow/imgaug/base.py
+0
-1
tensorpack/dataflow/imgaug/noname.py
tensorpack/dataflow/imgaug/noname.py
+1
-1
tensorpack/dataflow/tf_func.py
tensorpack/dataflow/tf_func.py
+81
-0
No files found.
examples/ResNet/svhn-resnet.py
View file @
cba97f75
...
...
@@ -38,13 +38,12 @@ def get_data(train_or_test):
if
isTrain
:
augmentors
=
[
imgaug
.
CenterPaste
((
40
,
40
)),
imgaug
.
RandomCrop
((
32
,
32
)),
#imgaug.Flip(horiz=True),
imgaug
.
Brightness
(
10
),
imgaug
.
Contrast
((
0.8
,
1.2
)),
imgaug
.
GaussianDeform
(
# this is slow
imgaug
.
GaussianDeform
(
# this is slow
. without it, can only reach 1.9% error
[(
0.2
,
0.2
),
(
0.2
,
0.8
),
(
0.8
,
0.8
),
(
0.8
,
0.2
)],
(
32
,
32
),
0.2
,
3
),
(
40
,
40
),
0.2
,
3
),
imgaug
.
RandomCrop
((
32
,
32
)),
imgaug
.
MapImage
(
lambda
x
:
x
-
pp_mean
),
]
else
:
...
...
tensorpack/dataflow/dataset/ilsvrc.py
View file @
cba97f75
...
...
@@ -7,10 +7,12 @@ import tarfile
import
cv2
import
numpy
as
np
from
six.moves
import
range
import
xml.etree.ElementTree
as
ET
from
...utils
import
logger
,
get_rng
,
get_dataset_dir
,
memoized
from
...utils.loadcaffe
import
get_caffe_pb
from
...utils.fs
import
mkdir_p
,
download
from
...utils.timer
import
timed_operation
from
..base
import
RNGDataFlow
__all__
=
[
'ILSVRCMeta'
,
'ILSVRC12'
]
...
...
@@ -20,7 +22,6 @@ def log_once(s): logger.warn(s)
CAFFE_ILSVRC12_URL
=
"http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz"
# TODO move caffe_pb outside
class
ILSVRCMeta
(
object
):
"""
Some metadata for ILSVRC dataset.
...
...
@@ -90,15 +91,16 @@ class ILSVRCMeta(object):
class
ILSVRC12
(
RNGDataFlow
):
def
__init__
(
self
,
dir
,
name
,
meta_dir
=
None
,
shuffle
=
True
,
dir_structure
=
'original'
):
dir_structure
=
'original'
,
include_bb
=
False
):
"""
:param dir: A directory containing a subdir named `name`, where the
original ILSVRC12_`name`.tar gets decompressed.
:param name: 'train' or 'val' or 'test'
:param dir_structure:
t
he dir structure of 'val' or 'test'.
i
f is 'original' then keep the original decompressed dir with list
of image files
. i
f equals to 'train', use the `train/` dir
:param dir_structure:
T
he dir structure of 'val' or 'test'.
I
f is 'original' then keep the original decompressed dir with list
of image files
(as below). I
f equals to 'train', use the `train/` dir
structure with class name as subdirectories.
:param include_bb: Include the bounding box. Useful in training.
Dir should have the following structure:
...
...
@@ -116,6 +118,11 @@ class ILSVRC12(RNGDataFlow):
test/
ILSVRC2012_test_00000001.JPEG
...
bbox/
n02134418/
n02134418_198.xml
...
...
After decompress ILSVRC12_img_train.tar, you can use the following
command to build the above structure for `train/`:
...
...
@@ -125,6 +132,7 @@ class ILSVRC12(RNGDataFlow):
find -type f -name '*.tar' | parallel -P 10 'echo {} && mkdir -p {/.} && tar xf {} -C {/.}'
Or:
for i in *.tar; do dir=${i
%
.tar}; echo $dir; mkdir -p $dir; tar xf $i -C $dir; done
"""
assert
name
in
[
'train'
,
'test'
,
'val'
]
self
.
full_dir
=
os
.
path
.
join
(
dir
,
name
)
...
...
@@ -136,12 +144,19 @@ class ILSVRC12(RNGDataFlow):
self
.
dir_structure
=
dir_structure
self
.
synset
=
meta
.
get_synset_1000
()
if
include_bb
:
assert
name
==
'train'
,
'Bounding box only available for training'
self
.
bblist
=
ILSVRC12
.
get_training_bbox
(
os
.
path
.
join
(
dir
,
'bbox'
),
self
.
imglist
)
self
.
include_bb
=
include_bb
def size(self):
    """Return the number of entries in ``self.imglist``."""
    num_images = len(self.imglist)
    return num_images
def
get_data
(
self
):
"""
Produce original images or shape [h, w, 3], and label
Produce original images of shape [h, w, 3], and label,
and optionally a bbox of [xmin, ymin, xmax, ymax] in [0, 1]
"""
idxs
=
np
.
arange
(
len
(
self
.
imglist
))
add_label_to_fname
=
(
self
.
name
!=
'train'
and
self
.
dir_structure
!=
'original'
)
...
...
@@ -157,15 +172,55 @@ class ILSVRC12(RNGDataFlow):
assert
im
is
not
None
,
fname
if
im
.
ndim
==
2
:
im
=
np
.
expand_dims
(
im
,
2
)
.
repeat
(
3
,
2
)
if
self
.
include_bb
:
bb
=
self
.
bblist
[
k
]
if
not
bb
:
bb
=
[
0
,
0
,
1
,
1
]
yield
[
im
,
label
,
bb
]
else
:
yield
[
im
,
label
]
@staticmethod
def get_training_bbox(bbox_dir, imglist):
    """
    Load one normalized bounding box per image in `imglist`.

    :param bbox_dir: directory containing the XML annotations, addressed
        with the same relative path as the image
        (e.g. `n02134418/n02134418_198.xml`).
    :param imglist: sequence whose entries carry the image file name as
        their first element.
    :returns: a list aligned with `imglist`. Each entry is either a float32
        array built from the children of the first `<object>/<bndbox>`
        element, scaled by the first two children of `<size>`, or None when
        no annotation could be read for that image.
    """
    def parse_bbox(fname):
        root = ET.parse(fname).getroot()
        # list(element) instead of the removed Element.getchildren(), and
        # real lists instead of map() so indexing and in-place division
        # also work on Python 3.
        size_fields = list(root.find('size'))
        # presumably (width, height) -- the children are read positionally
        size = [int(size_fields[0].text), int(size_fields[1].text)]

        box_fields = list(root.find('object').find('bndbox'))
        box = [float(field.text) for field in box_fields]
        box[0] /= size[0]
        box[1] /= size[1]
        box[2] /= size[0]
        box[3] /= size[1]
        return np.asarray(box, dtype='float32')

    ret = []
    with timed_operation('Loading Bounding Boxes ...'):
        cnt = 0
        import tqdm
        for entry in tqdm.tqdm(imglist):
            # Drop the last four characters of the image name and append
            # 'xml'; assumes a 4-character extension such as "JPEG".
            fname = entry[0][:-4] + 'xml'
            fname = os.path.join(bbox_dir, fname)
            try:
                box = parse_bbox(fname)
            except Exception:
                # Deliberate best-effort: a missing or malformed annotation
                # yields None. KeyboardInterrupt/SystemExit still propagate
                # because they do not derive from Exception.
                ret.append(None)
            else:
                ret.append(box)
                cnt += 1
        logger.info("{}/{} images have bounding box.".format(cnt, len(imglist)))
    return ret
if
__name__
==
'__main__'
:
meta
=
ILSVRCMeta
()
print
(
meta
.
get_per_pixel_mean
())
#print(meta.get_synset_words_1000())
#ds = ILSVRC12('/home/wyx/data/imagenet', 'val')
ds
=
ILSVRC12
(
'/home/wyx/data/fake_ilsvrc/'
,
'train'
,
include_bb
=
True
,
shuffle
=
False
)
ds
.
reset_state
()
for
k
in
ds
.
get_data
():
from
IPython
import
embed
;
embed
()
...
...
tensorpack/dataflow/imgaug/base.py
View file @
cba97f75
...
...
@@ -16,7 +16,6 @@ class ImageAugmentor(object):
self
.
reset_state
()
def
_init
(
self
,
params
=
None
):
self
.
reset_state
()
if
params
:
for
k
,
v
in
params
.
items
():
if
k
!=
'self'
:
...
...
tensorpack/dataflow/imgaug/noname.py
View file @
cba97f75
...
...
@@ -22,7 +22,7 @@ class Flip(ImageAugmentor):
:param prob: probability of flip.
"""
if
horiz
and
vert
:
raise
ValueError
(
"Please use two Flip
, with both 0.5 prob
"
)
raise
ValueError
(
"Please use two Flip
instead.
"
)
elif
horiz
:
self
.
code
=
1
elif
vert
:
...
...
tensorpack/dataflow/tf_func.py
0 → 100644
View file @
cba97f75
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: tf_func.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from
.base
import
ProxyDataFlow
from
..utils
import
logger
try
:
import
tensorflow
as
tf
except
ImportError
:
logger
.
warn
(
"Cannot import tensorflow. TFFuncMapper won't be available."
)
__all__
=
[]
else
:
__all__
=
[
'TFFuncMapper'
]
class TFFuncMapper(ProxyDataFlow):
    """
    A DataFlow that maps each datapoint of `ds` through a TensorFlow
    symbolic function.

    `reset_state` builds a fresh `tf.Graph` with the placeholders and the
    outputs of `symbf`, opens a session, and stores a `run_func` that feeds
    values to the placeholders and evaluates the outputs. `get_data` hands
    each datapoint together with `run_func` to `apply_symbf_on_dp`.
    """

    def __init__(self, ds, get_placeholders, symbf, apply_symbf_on_dp,
                 device='/cpu:0'):
        """
        :param ds: the DataFlow to wrap
        :param get_placeholders: a function returning the placeholders
        :param symbf: a symbolic function taking the placeholders
        :param apply_symbf_on_dp: called as `apply_symbf_on_dp(dp, run_func)`
            for every datapoint; its result replaces the datapoint
        :param device: device string given to `tf.device` while building
            the graph
        """
        super(TFFuncMapper, self).__init__(ds)
        self.device = device
        self.apply_symbf_on_dp = apply_symbf_on_dp
        self.symbf = symbf
        self.get_placeholders = get_placeholders

    def reset_state(self):
        """Reset the wrapped DataFlow, then build the graph and session."""
        super(TFFuncMapper, self).reset_state()
        graph = tf.Graph()
        self.graph = graph
        with graph.as_default(), tf.device(self.device):
            self.placeholders = self.get_placeholders()
            self.output_vars = self.symbf(self.placeholders)
            self.sess = tf.Session()

            def run_func(vals):
                # Pair each placeholder with the value at the same position.
                feed = dict(zip(self.placeholders, vals))
                return self.sess.run(self.output_vars, feed_dict=feed)

            self.run_func = run_func

    def get_data(self):
        """Yield the mapped datapoints, skipping any falsy result."""
        for dp in self.ds.get_data():
            mapped = self.apply_symbf_on_dp(dp, self.run_func)
            if not mapped:
                continue
            yield mapped
if __name__ == '__main__':
    # Ad-hoc speed test: push fake 224x224x3 images through a
    # TensorFlow-based augmentation and iterate under a tqdm progress bar.
    from .raw import FakeData
    from .prefetch import PrefetchDataZMQ
    from .image import AugmentImageComponent
    from . import imgaug

    ds = FakeData([[224, 224, 3]], 100000, random=False)

    def tf_aug(v):
        """Symbolic augmentation applied to the first (only) placeholder."""
        img = v[0]
        img = tf.image.random_brightness(img, 0.1)
        img = tf.image.random_contrast(img, 0.8, 1.2)
        img = tf.image.random_flip_left_right(img)
        return img

    def make_placeholders():
        """Build the single image placeholder fed by each datapoint."""
        return [tf.placeholder(tf.float32, [224, 224, 3], name='img')]

    def map_datapoint(dp, f):
        """Run the first component of `dp` through `f` and re-wrap it."""
        return [f([dp[0]])[0]]

    ds = TFFuncMapper(ds, make_placeholders, tf_aug, map_datapoint)

    # Alternative pipeline kept for comparison with the imgaug augmentors:
    #ds = AugmentImageComponent(ds,
    #[imgaug.Brightness(0.1, clip=False),
    #imgaug.Contrast((0.8, 1.2), clip=False),
    #imgaug.Flip(horiz=True)
    #])
    #ds = PrefetchDataZMQ(ds, 4)
    ds.reset_state()

    import tqdm
    itr = ds.get_data()
    for _ in tqdm.trange(100000):
        next(itr)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment