Commit a8b72a87
authored Jul 03, 2018 by Yuxin Wu

Improve dataset download logic

parent 5854c7de

Showing 6 changed files with 40 additions and 24 deletions
tensorpack/dataflow/dataset/bsds500.py   +4  -1
tensorpack/dataflow/dataset/cifar.py     +6  -7
tensorpack/dataflow/dataset/ilsvrc.py    +2  -2
tensorpack/graph_builder/utils.py        +8  -4
tensorpack/libinfo.py                    +11 -8
tensorpack/utils/fs.py                   +9  -2
tensorpack/dataflow/dataset/bsds500.py
@@ -10,7 +10,10 @@ from ...utils.fs import download, get_dataset_path
 from ..base import RNGDataFlow

 __all__ = ['BSDS500']

 DATA_URL = "http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/BSR_bsds500.tgz"
+DATA_SIZE = 70763455
 IMG_W, IMG_H = 481, 321
@@ -35,7 +38,7 @@ class BSDS500(RNGDataFlow):
         if data_dir is None:
             data_dir = get_dataset_path('bsds500_data')
         if not os.path.isdir(os.path.join(data_dir, 'BSR')):
-            download(DATA_URL, data_dir)
+            download(DATA_URL, data_dir, expect_size=DATA_SIZE)
             filename = DATA_URL.split('/')[-1]
             filepath = os.path.join(data_dir, filename)
             import tarfile
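The change above threads the new expect_size argument through the BSDS500 loader. Condensed into a standalone sketch (same constants as the diff; download is tensorpack's helper from tensorpack/utils/fs.py, shown further below; the helper name maybe_fetch_bsds500 is illustrative, not from the commit), the download-and-extract pattern the dataset loaders now follow looks like this:

    import os
    import tarfile
    from tensorpack.utils.fs import download

    DATA_URL = "http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/BSR_bsds500.tgz"
    DATA_SIZE = 70763455  # expected archive size in bytes, from this commit

    def maybe_fetch_bsds500(data_dir):
        # skip everything if the extracted 'BSR' tree is already present
        if not os.path.isdir(os.path.join(data_dir, 'BSR')):
            # download() verifies the on-disk size when expect_size is given
            download(DATA_URL, data_dir, expect_size=DATA_SIZE)
            fpath = os.path.join(data_dir, DATA_URL.split('/')[-1])
            tarfile.open(fpath, 'r:gz').extractall(data_dir)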
tensorpack/dataflow/dataset/cifar.py
@@ -6,6 +6,7 @@
 import os
 import pickle
 import numpy as np
+import tarfile
 import six
 from six.moves import range
@@ -16,13 +17,12 @@ from ..base import RNGDataFlow
 __all__ = ['Cifar10', 'Cifar100']

-DATA_URL_CIFAR_10 = 'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
-DATA_URL_CIFAR_100 = 'http://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz'
+DATA_URL_CIFAR_10 = ('http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz', 170498071)
+DATA_URL_CIFAR_100 = ('http://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz', 169001437)

 def maybe_download_and_extract(dest_directory, cifar_classnum):
-    """Download and extract the tarball from Alex's website.
-       copied from tensorflow example """
+    """Download and extract the tarball from Alex's website. Copied from tensorflow example """
     assert cifar_classnum == 10 or cifar_classnum == 100
     if cifar_classnum == 10:
         cifar_foldername = 'cifar-10-batches-py'
@@ -33,10 +33,9 @@ def maybe_download_and_extract(dest_directory, cifar_classnum):
         return
     else:
         DATA_URL = DATA_URL_CIFAR_10 if cifar_classnum == 10 else DATA_URL_CIFAR_100
-        download(DATA_URL, dest_directory)
-        filename = DATA_URL.split('/')[-1]
+        filename = DATA_URL[0].split('/')[-1]
         filepath = os.path.join(dest_directory, filename)
-        import tarfile
+        download(DATA_URL[0], dest_directory, expect_size=DATA_URL[1])
         tarfile.open(filepath, 'r:gz').extractall(dest_directory)
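A note on the pattern: the CIFAR (and, below, ILSVRC) URL constants are now (url, expected_size) tuples, so the size travels with the URL instead of being hard-coded at each call site. A minimal sketch of consuming such a tuple (the unpacking style is a suggestion, not what the commit itself does):

    DATA_URL_CIFAR_10 = ('http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz', 170498071)

    url, expect_size = DATA_URL_CIFAR_10  # unpacking reads better than DATA_URL[0] / DATA_URL[1]
    filename = url.split('/')[-1]         # 'cifar-10-python.tar.gz'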
tensorpack/dataflow/dataset/ilsvrc.py
@@ -14,7 +14,7 @@ from ..base import RNGDataFlow
 __all__ = ['ILSVRCMeta', 'ILSVRC12', 'ILSVRC12Files']

-CAFFE_ILSVRC12_URL = "http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz"
+CAFFE_ILSVRC12_URL = ("http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz", 17858008)

 class ILSVRCMeta(object):
@@ -53,7 +53,7 @@ class ILSVRCMeta(object):
         return dict(enumerate(lines))

     def _download_caffe_meta(self):
-        fpath = download(CAFFE_ILSVRC12_URL, self.dir, expect_size=17858008)
+        fpath = download(CAFFE_ILSVRC12_URL[0], self.dir, expect_size=CAFFE_ILSVRC12_URL[1])
         tarfile.open(fpath, 'r:gz').extractall(self.dir)

     def get_image_list(self, name, dir_structure='original'):
tensorpack/graph_builder/utils.py
@@ -8,6 +8,7 @@ import tensorflow as tf
 from ..tfutils.varreplace import custom_getter_scope
 from ..tfutils.scope_utils import under_name_scope, cached_name_scope
+from ..tfutils.common import get_tf_version_number
 from ..utils.argtools import call_only_once
 from ..utils import logger
@@ -66,13 +67,16 @@ class LeastLoadedDeviceSetter(object):
         self.ps_sizes = [0] * len(self.ps_devices)

     def __call__(self, op):
-        def sanitize_name(name):   # tensorflow/tensorflow#11484
-            return tf.DeviceSpec.from_string(name).to_string()
+        if get_tf_version_number() >= 1.8:
+            from tensorflow.python.training.device_util import canonicalize
+        else:
+            def canonicalize(name):   # tensorflow/tensorflow#11484
+                return tf.DeviceSpec.from_string(name).to_string()

         if op.device:
             return op.device
         if op.type not in ['Variable', 'VariableV2']:
-            return sanitize_name(self.worker_device)
+            return canonicalize(self.worker_device)

         device_index, _ = min(enumerate(
             self.ps_sizes), key=operator.itemgetter(1))
@@ -84,7 +88,7 @@ class LeastLoadedDeviceSetter(object):
         self.ps_sizes[device_index] += var_size
-        return sanitize_name(device_name)
+        return canonicalize(device_name)

     def __str__(self):
         return "LeastLoadedDeviceSetter-{}".format(self.worker_device)
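For context: on TF >= 1.8 the device setter reuses TensorFlow's own canonicalize from tensorflow.python.training.device_util, while the else branch keeps an inline fallback for older TF. Both normalize device strings so that differently spelled aliases compare equal. A small sketch against the TF 1.x DeviceSpec API (the exact canonical spelling is TF's choice):

    import tensorflow as tf

    # '/gpu:0' and '/device:GPU:0' name the same device; after
    # round-tripping through DeviceSpec they compare equal as strings
    a = tf.DeviceSpec.from_string('/gpu:0').to_string()
    b = tf.DeviceSpec.from_string('/device:GPU:0').to_string()
    assert a == b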
tensorpack/libinfo.py
@@ -8,16 +8,21 @@ try:
     import cv2  # noqa
     if int(cv2.__version__.split('.')[0]) == 3:
         cv2.ocl.setUseOpenCL(False)
-    # check if cv is built with cuda
+    # check if cv is built with cuda or openmp
     info = cv2.getBuildInformation().split('\n')
     for line in info:
-        if 'use cuda' in line.lower():
-            answer = line.split()[-1].lower()
-            if answer == 'yes':
+        splits = line.split()
+        if not len(splits):
+            continue
+        answer = splits[-1].lower()
+        if answer in ['yes', 'no']:
+            if 'cuda' in line.lower() and answer == 'yes':
                 # issue#1197
                 print("OpenCV is built with CUDA support. "
                       "This may cause slow initialization or sometimes segfault with TensorFlow.")
-            break
+        if answer == 'openmp':
+            print("OpenCV is built with OpenMP support. This usually results in poor performance. For details, see "
+                  "https://github.com/tensorpack/benchmarks/blob/master/ImageNet/benchmark-opencv-resize.py")
 except (ImportError, TypeError):
     pass
@@ -41,9 +46,7 @@ os.environ['TF_GPU_THREAD_COUNT'] = '2'
 try:
     import tensorflow as tf  # noqa
     _version = tf.__version__.split('.')
-    assert int(_version[0]) >= 1, "TF>=1.0 is required!"
-    if int(_version[1]) < 3:
-        print("TF<1.3 support will be removed after 2018-03-15! Actually many examples already require TF>=1.3.")
+    assert int(_version[0]) >= 1 and int(_version[1]) >= 3, "TF>=1.3 is required!"
     _HAS_TF = True
 except ImportError:
     _HAS_TF = False
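The rewritten loop inspects every line of cv2.getBuildInformation() rather than bailing out after the 'use cuda' line, keying off the last token of each line; that is what lets a single pass flag both CUDA and OpenMP builds. A self-contained sketch of the same parsing on made-up input (the sample strings are illustrative, not real OpenCV output):

    sample = [
        "  Use CUDA:  YES",                # hypothetical build-info line
        "  Parallel framework:  OpenMP",   # hypothetical build-info line
        "",
    ]
    for line in sample:
        splits = line.split()
        if not len(splits):
            continue  # guard: on a blank line, splits[-1] would raise IndexError
        answer = splits[-1].lower()
        if 'cuda' in line.lower() and answer == 'yes':
            print("built with CUDA")
        if answer == 'openmp':
            print("built with OpenMP")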
tensorpack/utils/fs.py
@@ -13,7 +13,7 @@ __all__ = ['mkdir_p', 'download', 'recursive_walk', 'get_dataset_path']
 def mkdir_p(dirname):
-    """ Make a dir recursively, but do nothing if the dir exists
+    """ Like "mkdir -p", make a dir recursively, but do nothing if the dir exists

     Args:
         dirname(str):
@@ -38,6 +38,13 @@ def download(url, dir, filename=None, expect_size=None):
         filename = url.split('/')[-1]
     fpath = os.path.join(dir, filename)

+    if os.path.isfile(fpath):
+        if expect_size is not None and os.stat(fpath).st_size == expect_size:
+            logger.info("File {} exists! Skip download.".format(filename))
+            return fpath
+        else:
+            logger.warn("File {} exists. Will overwrite with a new download!".format(filename))
+
     def hook(t):
         last_b = [0]
@@ -62,7 +69,7 @@ def download(url, dir, filename=None, expect_size=None):
         logger.error("You may have downloaded a broken file, or the upstream may have modified the file.")
     # TODO human-readable size
-    print('Succesfully downloaded ' + filename + ". " + str(size) + ' bytes.')
+    logger.info('Succesfully downloaded ' + filename + ". " + str(size) + ' bytes.')
     return fpath
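With the new existence check, calling download twice is now cheap: the second call returns the cached file when its on-disk size matches expect_size, and deliberately re-downloads when it does not (say, a partial file left by an interrupted run). A usage sketch reusing this commit's ILSVRC constants (the destination directory is arbitrary):

    from tensorpack.utils.fs import download

    URL = "http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz"
    SIZE = 17858008  # expected size in bytes, from this commit

    fpath = download(URL, '/tmp/ilsvrc_meta', expect_size=SIZE)  # fetches the file
    fpath = download(URL, '/tmp/ilsvrc_meta', expect_size=SIZE)  # logs "exists! Skip download." and returns early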