Commit f6b502d7 authored May 10, 2016 by Yuxin Wu

    bsds dataset

parent 1c3d8741

Showing 8 changed files with 133 additions and 5 deletions (+133, -5)
tensorpack/__init__.py                      +1   -0
tensorpack/dataflow/dataset/bsds500.py      +99  -0
tensorpack/dataflow/dataset/cifar10.py      +1   -1
tensorpack/dataflow/dataset/ilsvrc.py       +1   -1
tensorpack/dataflow/dataset/svhn.py         +2   -2
tensorpack/dataflow/image.py                +2   -0
tensorpack/tfutils/symbolic_functions.py    +25  -0
tensorpack/train/base.py                    +2   -1
tensorpack/__init__.py

@@ -2,6 +2,7 @@
 # File: __init__.py
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>
+import cv2  # fix https://github.com/tensorflow/tensorflow/issues/1924
 from . import models
 from . import train
 from . import utils
tensorpack/dataflow/dataset/bsds500.py  (new file, 0 → 100644)

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: bsds500.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import os, glob
import cv2
import numpy as np
from scipy.io import loadmat

from ...utils import logger, get_rng
from ...utils.fs import download
from ..base import DataFlow

__all__ = ['BSDS500']

DATA_URL = "http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/BSR_bsds500.tgz"
IMG_W, IMG_H = 481, 321

class BSDS500(DataFlow):
    """
    `Berkeley Segmentation Data Set and Benchmarks 500
    <http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/resources.html#bsds500>`_.

    Produce (image, label) pair, where image has shape (321, 481, 3) and
    ranges in [0,255]. Label is binary and has shape (321, 481).
    Those pixels annotated as boundaries by >= 3 out of 5 annotators are
    considered positive examples. This is used in `Holistically-Nested Edge Detection
    <http://arxiv.org/abs/1504.06375>`_.
    """

    def __init__(self, name, data_dir=None, shuffle=True):
        """
        :param name: 'train', 'test', 'val'
        :param data_dir: a directory containing the original 'BSR' directory.
        """
        # check and download data
        if data_dir is None:
            data_dir = os.path.join(os.path.dirname(__file__), 'bsds500_data')
        if not os.path.isdir(os.path.join(data_dir, 'BSR')):
            download(DATA_URL, data_dir)
            filename = DATA_URL.split('/')[-1]
            filepath = os.path.join(data_dir, filename)
            import tarfile
            tarfile.open(filepath, 'r:gz').extractall(data_dir)
        self.data_root = os.path.join(data_dir, 'BSR', 'BSDS500', 'data')
        assert os.path.isdir(self.data_root)

        self.shuffle = shuffle
        assert name in ['train', 'test', 'val']
        self._load(name)
        self.rng = get_rng(self)

    def reset_state(self):
        self.rng = get_rng(self)

    def _load(self, name):
        image_glob = os.path.join(self.data_root, 'images', name, '*.jpg')
        image_files = glob.glob(image_glob)
        gt_dir = os.path.join(self.data_root, 'groundTruth', name)
        self.data = np.zeros((len(image_files), IMG_H, IMG_W, 3), dtype='uint8')
        self.label = np.zeros((len(image_files), IMG_H, IMG_W), dtype='bool')

        for idx, f in enumerate(image_files):
            im = cv2.imread(f, cv2.IMREAD_COLOR)
            assert im is not None
            if im.shape[0] > im.shape[1]:
                im = np.transpose(im, (1, 0, 2))
            assert im.shape[:2] == (IMG_H, IMG_W), "{} != {}".format(im.shape[:2], (IMG_H, IMG_W))

            imgid = os.path.basename(f).split('.')[0]
            gt_file = os.path.join(gt_dir, imgid)
            gt = loadmat(gt_file)['groundTruth'][0]
            gt = sum(gt[k]['Boundaries'][0][0] for k in range(5))
            gt[gt < 3] = 0
            gt[gt != 0] = 1
            if gt.shape[0] > gt.shape[1]:
                gt = gt.transpose()
            assert gt.shape == (IMG_H, IMG_W)

            self.data[idx] = im
            self.label[idx] = gt

    def size(self):
        return self.data.shape[0]

    def get_data(self):
        idxs = np.arange(self.data.shape[0])
        if self.shuffle:
            self.rng.shuffle(idxs)
        for k in idxs:
            yield [self.data[k], self.label[k]]

if __name__ == '__main__':
    a = BSDS500('val')
    for k in a.get_data():
        cv2.imshow("haha", k[1].astype('uint8') * 255)
        cv2.waitKey(1000)
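A quick way to exercise the new DataFlow is to iterate it once and check the datapoint shapes. The sketch below is not part of the commit; it assumes BSDS500 is re-exported from tensorpack.dataflow.dataset the same way the other dataset classes are, and that the BSR tarball can be downloaded to the default directory.

# Usage sketch (assumption: BSDS500 is importable from tensorpack.dataflow.dataset).
from tensorpack.dataflow.dataset import BSDS500

ds = BSDS500('val', shuffle=False)   # downloads and extracts BSR_bsds500.tgz on first use
ds.reset_state()
print(ds.size())                     # number of images in the 'val' split
img, label = next(ds.get_data())     # each datapoint is [image, boundary label]
print(img.shape, img.dtype)          # (321, 481, 3) uint8, BGR as read by cv2
print(label.shape, label.dtype)      # (321, 481) bool boundary map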
tensorpack/dataflow/dataset/cifar10.py

@@ -9,7 +9,6 @@ import random
 import six
 from six.moves import urllib, range
 import copy
-import tarfile
 import logging
 from ...utils import logger, get_rng
@@ -31,6 +30,7 @@ def maybe_download_and_extract(dest_directory):
         download(DATA_URL, dest_directory)
         filename = DATA_URL.split('/')[-1]
         filepath = os.path.join(dest_directory, filename)
+        import tarfile
         tarfile.open(filepath, 'r:gz').extractall(dest_directory)

 def read_cifar10(filenames):
tensorpack/dataflow/dataset/ilsvrc.py

@@ -117,7 +117,7 @@ class ILSVRC12(DataFlow):
         for k in idxs:
             tp = self.imglist[k]
             fname = os.path.join(self.dir, self.name, tp[0]).strip()
-            im = cv2.imread(fname)
+            im = cv2.imread(fname, cv2.IMREAD_COLOR)
             assert im is not None, fname
             if im.ndim == 2:
                 im = np.expand_dims(im, 2).repeat(3, 2)
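For context on the ilsvrc.py change: cv2.IMREAD_COLOR (which is also cv2.imread's default) always decodes to a 3-channel BGR array, so the explicit flag mainly documents the intent, while the im.ndim == 2 branch stays as a safety net. A small check, purely illustrative and not part of the commit:

import cv2
import numpy as np

# Write a tiny grayscale image, then read it back with different flags.
cv2.imwrite('/tmp/gray.png', np.zeros((4, 4), dtype='uint8'))
print(cv2.imread('/tmp/gray.png', cv2.IMREAD_UNCHANGED).shape)  # (4, 4): single channel kept
print(cv2.imread('/tmp/gray.png', cv2.IMREAD_COLOR).shape)      # (4, 4, 3): forced 3-channel BGR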
tensorpack/dataflow/dataset/svhn.py

@@ -26,8 +26,8 @@ class SVHNDigit(DataFlow):
     def __init__(self, name, data_dir=None, shuffle=True):
         """
-        name: 'train', 'test', or 'extra'
-        data_dir: a directory containing the original {train,test,extra}_32x32.mat
+        :param name: 'train', 'test', or 'extra'
+        :param data_dir: a directory containing the original {train,test,extra}_32x32.mat
         """
         self.shuffle = shuffle
         self.rng = get_rng(self)
tensorpack/dataflow/image.py

@@ -54,3 +54,5 @@ class AugmentImageComponent(MapDataComponent):

     def reset_state(self):
         self.ds.reset_state()
         self.augs.reset_state()
tensorpack/tfutils/symbolic_functions.py

@@ -54,3 +54,28 @@ def logSoftmax(x):
     return logprob

+def class_balanced_binary_class_cross_entropy(pred, label, name='cross_entropy_loss'):
+    """
+    The class-balanced cross entropy loss for binary classification,
+    as in `Holistically-Nested Edge Detection
+    <http://arxiv.org/abs/1504.06375>`_.
+
+    :param pred: size: b x ANYTHING. the predictions in [0,1].
+    :param label: size: b x ANYTHING. the ground truth in {0,1}.
+    :returns: class-balanced binary classification cross entropy loss
+    """
+    z = batch_flatten(pred)
+    y = batch_flatten(label)
+
+    count_neg = tf.reduce_sum(1. - y)
+    count_pos = tf.reduce_sum(y)
+    total = tf.add(count_neg, count_pos)
+    beta = tf.truediv(count_neg, total)
+
+    eps = 1e-8
+    loss_pos = tf.mul(-beta, tf.reduce_sum(tf.mul(tf.log(tf.abs(z) + eps), y), 1))
+    loss_neg = tf.mul(1. - beta, tf.reduce_sum(tf.mul(tf.log(tf.abs(1. - z) + eps), 1. - y), 1))
+    cost = tf.sub(loss_pos, loss_neg)
+    cost = tf.reduce_mean(cost, name=name)
+    return cost
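The balancing comes from beta = count_neg / (count_neg + count_pos): the rare positive (boundary) pixels are weighted by beta and the common negative pixels by 1 - beta, which keeps the loss from being dominated by background. A NumPy restatement of the same arithmetic, for intuition only (the function name and toy inputs below are illustrative, not part of the commit):

import numpy as np

def class_balanced_bce_np(pred, label, eps=1e-8):
    # Mirrors the TF graph above: flatten per sample, weight the two terms by beta and 1 - beta.
    z = pred.reshape(pred.shape[0], -1)    # b x N, predictions in [0, 1]
    y = label.reshape(label.shape[0], -1)  # b x N, labels in {0, 1}
    count_neg = np.sum(1. - y)
    count_pos = np.sum(y)
    beta = count_neg / (count_neg + count_pos)
    loss_pos = -beta * np.sum(np.log(np.abs(z) + eps) * y, axis=1)
    loss_neg = (1. - beta) * np.sum(np.log(np.abs(1. - z) + eps) * (1. - y), axis=1)
    return np.mean(loss_pos - loss_neg)

# Toy batch: roughly 5% positive pixels, as in boundary ground truth.
pred = np.random.uniform(0.01, 0.99, size=(2, 321, 481))
label = (np.random.uniform(size=(2, 321, 481)) > 0.95).astype('float32')
print(class_balanced_bce_np(pred, label))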
tensorpack/train/base.py

@@ -98,7 +98,8 @@ class Trainer(object):
             for step in tqdm.trange(
                     self.config.step_per_epoch,
                     leave=True, mininterval=0.5,
-                    dynamic_ncols=True, ascii=True):
+                    dynamic_ncols=True, ascii=True,
+                    bar_format='{l_bar}{bar}|{n_fmt}/{total_fmt} [{elapsed}<{remaining},{rate_noinv_fmt}]'):
                 if self.coord.should_stop():
                     return
                 self.run_step()
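The added bar_format is standard tqdm templating; {rate_noinv_fmt} keeps the rate displayed as iterations per second instead of flipping to seconds per iteration when steps are slow. A standalone sketch with the same progress-bar options (the sleep is a stand-in for the trainer's run_step, not the commit's code):

import time
import tqdm

# Same tqdm options as the Trainer's inner loop, driven by a dummy workload.
for _ in tqdm.trange(100, leave=True, mininterval=0.5,
                     dynamic_ncols=True, ascii=True,
                     bar_format='{l_bar}{bar}|{n_fmt}/{total_fmt} [{elapsed}<{remaining},{rate_noinv_fmt}]'):
    time.sleep(0.01)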