Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
4a59173c
Commit
4a59173c
authored
Jun 10, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
dataset dir
parent
194cda0b
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
30 additions
and
12 deletions
+30
-12
tensorpack/dataflow/dataset/bsds500.py
tensorpack/dataflow/dataset/bsds500.py
+2
-1
tensorpack/dataflow/dataset/cifar.py
tensorpack/dataflow/dataset/cifar.py
+2
-2
tensorpack/dataflow/dataset/common.py
tensorpack/dataflow/dataset/common.py
+17
-0
tensorpack/dataflow/dataset/ilsvrc.py
tensorpack/dataflow/dataset/ilsvrc.py
+5
-5
tensorpack/dataflow/dataset/mnist.py
tensorpack/dataflow/dataset/mnist.py
+2
-1
tensorpack/dataflow/dataset/svhn.py
tensorpack/dataflow/dataset/svhn.py
+2
-3
No files found.
tensorpack/dataflow/dataset/bsds500.py
View file @
4a59173c
...
...
@@ -10,6 +10,7 @@ from scipy.io import loadmat
from
...utils
import
logger
,
get_rng
from
...utils.fs
import
download
from
..base
import
DataFlow
from
.common
import
get_dataset_dir
__all__
=
[
'BSDS500'
]
...
...
@@ -36,7 +37,7 @@ class BSDS500(DataFlow):
"""
# check and download data
if
data_dir
is
None
:
data_dir
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'bsds500_data'
)
data_dir
=
get_dataset_dir
(
'bsds500_data'
)
if
not
os
.
path
.
isdir
(
os
.
path
.
join
(
data_dir
,
'BSR'
)):
download
(
DATA_URL
,
data_dir
)
filename
=
DATA_URL
.
split
(
'/'
)[
-
1
]
...
...
tensorpack/dataflow/dataset/cifar.py
View file @
4a59173c
...
...
@@ -16,6 +16,7 @@ import logging
from
...utils
import
logger
,
get_rng
from
...utils.fs
import
download
from
..base
import
DataFlow
from
.common
import
get_dataset_dir
__all__
=
[
'Cifar10'
,
'Cifar100'
]
...
...
@@ -92,8 +93,7 @@ class CifarBase(DataFlow):
assert
cifar_classnum
==
10
or
cifar_classnum
==
100
self
.
cifar_classnum
=
cifar_classnum
if
dir
is
None
:
dir
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'cifar{}_data'
.
format
(
cifar_classnum
))
dir
=
get_dataset_dir
(
'cifar{}_data'
.
format
(
cifar_classnum
))
maybe_download_and_extract
(
dir
,
self
.
cifar_classnum
)
fnames
=
get_filenames
(
dir
,
cifar_classnum
)
if
train_or_test
==
'train'
:
...
...
tensorpack/dataflow/dataset/common.py
0 → 100644
View file @
4a59173c
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: common.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import
os
__all__
=
[
'get_dataset_dir'
]
def
get_dataset_dir
(
name
):
d
=
os
.
environ
[
'TENSORPACK_DATASET'
]:
if
d
:
assert
os
.
path
.
isdir
(
d
)
else
:
d
=
os
.
path
.
dirname
(
__file__
)
return
os
.
path
.
join
(
d
,
name
)
tensorpack/dataflow/dataset/ilsvrc.py
View file @
4a59173c
...
...
@@ -8,8 +8,9 @@ import cv2
import
numpy
as
np
from
...utils
import
logger
,
get_rng
from
..base
import
DataFlow
from
...utils.fs
import
mkdir_p
,
download
from
..base
import
DataFlow
from
.common
import
get_dataset_dir
__all__
=
[
'ILSVRCMeta'
,
'ILSVRC12'
]
...
...
@@ -28,7 +29,7 @@ class ILSVRCMeta(object):
"""
def
__init__
(
self
,
dir
=
None
):
if
dir
is
None
:
dir
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'ilsvrc_metadata'
)
dir
=
get_dataset_dir
(
'ilsvrc_metadata'
)
self
.
dir
=
dir
mkdir_p
(
self
.
dir
)
self
.
caffe_pb_file
=
os
.
path
.
join
(
self
.
dir
,
'caffe_pb2.py'
)
...
...
@@ -91,8 +92,7 @@ class ILSVRC12(DataFlow):
name: 'train' or 'val' or 'test'
"""
assert
name
in
[
'train'
,
'test'
,
'val'
]
self
.
dir
=
dir
self
.
name
=
name
self
.
full_dir
=
os
.
path
.
join
(
dir
,
name
)
self
.
shuffle
=
shuffle
self
.
meta
=
ILSVRCMeta
(
meta_dir
)
self
.
imglist
=
self
.
meta
.
get_image_list
(
name
)
...
...
@@ -116,7 +116,7 @@ class ILSVRC12(DataFlow):
self
.
rng
.
shuffle
(
idxs
)
for
k
in
idxs
:
tp
=
self
.
imglist
[
k
]
fname
=
os
.
path
.
join
(
self
.
dir
,
self
.
name
,
tp
[
0
])
.
strip
()
fname
=
os
.
path
.
join
(
self
.
full_dir
,
tp
[
0
])
.
strip
()
im
=
cv2
.
imread
(
fname
,
cv2
.
IMREAD_COLOR
)
assert
im
is
not
None
,
fname
if
im
.
ndim
==
2
:
...
...
tensorpack/dataflow/dataset/mnist.py
View file @
4a59173c
...
...
@@ -12,6 +12,7 @@ from six.moves import urllib, range
from
...utils
import
logger
from
...utils.fs
import
download
from
..base
import
DataFlow
from
.common
import
get_dataset_dir
__all__
=
[
'Mnist'
]
...
...
@@ -103,7 +104,7 @@ class Mnist(DataFlow):
train_or_test: string either 'train' or 'test'
"""
if
dir
is
None
:
dir
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'mnist_data'
)
dir
=
get_dataset_dir
(
'mnist_data'
)
assert
train_or_test
in
[
'train'
,
'test'
]
self
.
train_or_test
=
train_or_test
self
.
shuffle
=
shuffle
...
...
tensorpack/dataflow/dataset/svhn.py
View file @
4a59173c
...
...
@@ -12,6 +12,7 @@ from six.moves import range
from
...utils
import
logger
,
get_rng
from
..base
import
DataFlow
from
.common
import
get_dataset_dir
__all__
=
[
'SVHNDigit'
]
...
...
@@ -36,9 +37,7 @@ class SVHNDigit(DataFlow):
self
.
X
,
self
.
Y
=
SVHNDigit
.
Cache
[
name
]
return
if
data_dir
is
None
:
data_dir
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'svhn_data'
)
data_dir
=
get_dataset_dir
(
'svhn_data'
)
assert
name
in
[
'train'
,
'test'
,
'extra'
],
name
filename
=
os
.
path
.
join
(
data_dir
,
name
+
'_32x32.mat'
)
assert
os
.
path
.
isfile
(
filename
),
\
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment