Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
36bdc187
Commit
36bdc187
authored
Oct 19, 2020
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
add places365 dataset
parent
ecf525d6
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
83 additions
and
2 deletions
+83
-2
tensorpack/dataflow/dataset/ilsvrc.py
tensorpack/dataflow/dataset/ilsvrc.py
+2
-2
tensorpack/dataflow/dataset/places.py
tensorpack/dataflow/dataset/places.py
+81
-0
No files found.
tensorpack/dataflow/dataset/ilsvrc.py
View file @
36bdc187
...
...
@@ -192,8 +192,7 @@ class ILSVRC12(ILSVRC12Files):
dir (str): A directory containing a subdir named ``name``,
containing the images in a structure described below.
name (str): One of 'train' or 'val' or 'test'.
shuffle (bool): shuffle the dataset.
Defaults to True if name=='train'.
shuffle (bool): shuffle the dataset. Defaults to True if name=='train'.
dir_structure (str): One of 'original' or 'train'.
The directory structure for the 'val' directory.
'original' means the original decompressed directory, which only has list of image files (as below).
...
...
@@ -354,6 +353,7 @@ try:
except
ImportError
:
from
...utils.develop
import
create_dummy_class
ILSVRC12
=
create_dummy_class
(
'ILSVRC12'
,
'cv2'
)
# noqa
TinyImageNet
=
create_dummy_class
(
'TinyImageNet'
,
'cv2'
)
# noqa
if
__name__
==
'__main__'
:
meta
=
ILSVRCMeta
()
...
...
tensorpack/dataflow/dataset/places.py
0 → 100644
View file @
36bdc187
#-*- coding: utf-8 -*-
import
os
import
numpy
as
np
from
...utils
import
logger
from
..base
import
RNGDataFlow
class
Places365Standard
(
RNGDataFlow
):
"""
The Places365-Standard Dataset, in low resolution format only.
Produces BGR images of shape (256, 256, 3) in range [0, 255].
"""
def
__init__
(
self
,
dir
,
name
,
shuffle
=
None
):
"""
Args:
dir: path to the Places365-Standard dataset in its "easy directory
structure". See http://places2.csail.mit.edu/download.html
name: one of "train" or "val"
shuffle (bool): shuffle the dataset. Defaults to True if name=='train'.
"""
assert
name
in
[
'train'
,
'val'
],
name
dir
=
os
.
path
.
expanduser
(
dir
)
assert
os
.
path
.
isdir
(
dir
),
dir
self
.
name
=
name
if
shuffle
is
None
:
shuffle
=
name
==
'train'
self
.
shuffle
=
shuffle
label_file
=
os
.
path
.
join
(
dir
,
name
+
".txt"
)
all_files
=
[]
labels
=
set
()
with
open
(
label_file
)
as
f
:
for
line
in
f
:
filepath
=
os
.
path
.
join
(
dir
,
line
.
strip
())
line
=
line
.
strip
()
.
split
(
"/"
)
label
=
line
[
1
]
all_files
.
append
((
filepath
,
label
))
labels
.
add
(
label
)
self
.
_labels
=
sorted
(
list
(
labels
))
# class ids are sorted alphabetically:
# https://github.com/CSAILVision/places365/blob/master/categories_places365.txt
labelmap
=
{
label
:
id
for
id
,
label
in
enumerate
(
self
.
_labels
)}
self
.
_files
=
[(
path
,
labelmap
[
x
])
for
path
,
x
in
all_files
]
logger
.
info
(
"Found {} images in {}."
.
format
(
len
(
self
.
_files
),
label_file
))
def
get_label_names
(
self
):
"""
Returns:
[str]: name of each class.
"""
return
self
.
_labels
def
__len__
(
self
):
return
len
(
self
.
_files
)
def
__iter__
(
self
):
idxs
=
np
.
arange
(
len
(
self
.
_files
))
if
self
.
shuffle
:
self
.
rng
.
shuffle
(
idxs
)
for
k
in
idxs
:
fname
,
label
=
self
.
_files
[
k
]
im
=
cv2
.
imread
(
fname
,
cv2
.
IMREAD_COLOR
)
assert
im
is
not
None
,
fname
yield
[
im
,
label
]
try
:
import
cv2
except
ImportError
:
from
...utils.develop
import
create_dummy_class
Places365Standard
=
create_dummy_class
(
'Places365Standard'
,
'cv2'
)
# noqa
if
__name__
==
'__main__'
:
from
tensorpack.dataflow
import
PrintData
ds
=
Places365Standard
(
"~/data/places365_standard/"
,
'train'
)
ds
=
PrintData
(
ds
,
num
=
100
)
ds
.
reset_state
()
for
k
in
ds
:
pass
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment