Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
a8b72a87
Commit
a8b72a87
authored
Jul 03, 2018
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
Improve dataset download logic
parent
5854c7de
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
40 additions
and
24 deletions
+40
-24
tensorpack/dataflow/dataset/bsds500.py
tensorpack/dataflow/dataset/bsds500.py
+4
-1
tensorpack/dataflow/dataset/cifar.py
tensorpack/dataflow/dataset/cifar.py
+6
-7
tensorpack/dataflow/dataset/ilsvrc.py
tensorpack/dataflow/dataset/ilsvrc.py
+2
-2
tensorpack/graph_builder/utils.py
tensorpack/graph_builder/utils.py
+8
-4
tensorpack/libinfo.py
tensorpack/libinfo.py
+11
-8
tensorpack/utils/fs.py
tensorpack/utils/fs.py
+9
-2
No files found.
tensorpack/dataflow/dataset/bsds500.py
View file @
a8b72a87
...
...
@@ -10,7 +10,10 @@ from ...utils.fs import download, get_dataset_path
from
..base
import
RNGDataFlow
__all__
=
[
'BSDS500'
]
DATA_URL
=
"http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/BSR_bsds500.tgz"
DATA_SIZE
=
70763455
IMG_W
,
IMG_H
=
481
,
321
...
...
@@ -35,7 +38,7 @@ class BSDS500(RNGDataFlow):
if
data_dir
is
None
:
data_dir
=
get_dataset_path
(
'bsds500_data'
)
if
not
os
.
path
.
isdir
(
os
.
path
.
join
(
data_dir
,
'BSR'
)):
download
(
DATA_URL
,
data_dir
)
download
(
DATA_URL
,
data_dir
,
expect_size
=
DATA_SIZE
)
filename
=
DATA_URL
.
split
(
'/'
)[
-
1
]
filepath
=
os
.
path
.
join
(
data_dir
,
filename
)
import
tarfile
...
...
tensorpack/dataflow/dataset/cifar.py
View file @
a8b72a87
...
...
@@ -6,6 +6,7 @@
import
os
import
pickle
import
numpy
as
np
import
tarfile
import
six
from
six.moves
import
range
...
...
@@ -16,13 +17,12 @@ from ..base import RNGDataFlow
__all__
=
[
'Cifar10'
,
'Cifar100'
]
DATA_URL_CIFAR_10
=
'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
DATA_URL_CIFAR_100
=
'http://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz'
DATA_URL_CIFAR_10
=
(
'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
,
170498071
)
DATA_URL_CIFAR_100
=
(
'http://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz'
,
169001437
)
def
maybe_download_and_extract
(
dest_directory
,
cifar_classnum
):
"""Download and extract the tarball from Alex's website.
copied from tensorflow example """
"""Download and extract the tarball from Alex's website. Copied from tensorflow example """
assert
cifar_classnum
==
10
or
cifar_classnum
==
100
if
cifar_classnum
==
10
:
cifar_foldername
=
'cifar-10-batches-py'
...
...
@@ -33,10 +33,9 @@ def maybe_download_and_extract(dest_directory, cifar_classnum):
return
else
:
DATA_URL
=
DATA_URL_CIFAR_10
if
cifar_classnum
==
10
else
DATA_URL_CIFAR_100
download
(
DATA_URL
,
dest_directory
)
filename
=
DATA_URL
.
split
(
'/'
)[
-
1
]
filename
=
DATA_URL
[
0
]
.
split
(
'/'
)[
-
1
]
filepath
=
os
.
path
.
join
(
dest_directory
,
filename
)
import
tarfile
download
(
DATA_URL
[
0
],
dest_directory
,
expect_size
=
DATA_URL
[
1
])
tarfile
.
open
(
filepath
,
'r:gz'
)
.
extractall
(
dest_directory
)
...
...
tensorpack/dataflow/dataset/ilsvrc.py
View file @
a8b72a87
...
...
@@ -14,7 +14,7 @@ from ..base import RNGDataFlow
__all__
=
[
'ILSVRCMeta'
,
'ILSVRC12'
,
'ILSVRC12Files'
]
CAFFE_ILSVRC12_URL
=
"http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz"
CAFFE_ILSVRC12_URL
=
(
"http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz"
,
17858008
)
class
ILSVRCMeta
(
object
):
...
...
@@ -53,7 +53,7 @@ class ILSVRCMeta(object):
return
dict
(
enumerate
(
lines
))
def
_download_caffe_meta
(
self
):
fpath
=
download
(
CAFFE_ILSVRC12_URL
,
self
.
dir
,
expect_size
=
17858008
)
fpath
=
download
(
CAFFE_ILSVRC12_URL
[
0
],
self
.
dir
,
expect_size
=
CAFFE_ILSVRC12_URL
[
1
]
)
tarfile
.
open
(
fpath
,
'r:gz'
)
.
extractall
(
self
.
dir
)
def
get_image_list
(
self
,
name
,
dir_structure
=
'original'
):
...
...
tensorpack/graph_builder/utils.py
View file @
a8b72a87
...
...
@@ -8,6 +8,7 @@ import tensorflow as tf
from
..tfutils.varreplace
import
custom_getter_scope
from
..tfutils.scope_utils
import
under_name_scope
,
cached_name_scope
from
..tfutils.common
import
get_tf_version_number
from
..utils.argtools
import
call_only_once
from
..utils
import
logger
...
...
@@ -66,13 +67,16 @@ class LeastLoadedDeviceSetter(object):
self
.
ps_sizes
=
[
0
]
*
len
(
self
.
ps_devices
)
def
__call__
(
self
,
op
):
def
sanitize_name
(
name
):
# tensorflow/tensorflow#11484
if
get_tf_version_number
()
>=
1.8
:
from
tensorflow.python.training.device_util
import
canonicalize
else
:
def
canonicalize
(
name
):
# tensorflow/tensorflow#11484
return
tf
.
DeviceSpec
.
from_string
(
name
)
.
to_string
()
if
op
.
device
:
return
op
.
device
if
op
.
type
not
in
[
'Variable'
,
'VariableV2'
]:
return
sanitize_nam
e
(
self
.
worker_device
)
return
canonicaliz
e
(
self
.
worker_device
)
device_index
,
_
=
min
(
enumerate
(
self
.
ps_sizes
),
key
=
operator
.
itemgetter
(
1
))
...
...
@@ -84,7 +88,7 @@ class LeastLoadedDeviceSetter(object):
self
.
ps_sizes
[
device_index
]
+=
var_size
return
sanitize_nam
e
(
device_name
)
return
canonicaliz
e
(
device_name
)
def
__str__
(
self
):
return
"LeastLoadedDeviceSetter-{}"
.
format
(
self
.
worker_device
)
...
...
tensorpack/libinfo.py
View file @
a8b72a87
...
...
@@ -8,16 +8,21 @@ try:
import
cv2
# noqa
if
int
(
cv2
.
__version__
.
split
(
'.'
)[
0
])
==
3
:
cv2
.
ocl
.
setUseOpenCL
(
False
)
# check if cv is built with cuda
# check if cv is built with cuda
or openmp
info
=
cv2
.
getBuildInformation
()
.
split
(
'
\n
'
)
for
line
in
info
:
if
'use cuda'
in
line
.
lower
():
answer
=
line
.
split
()[
-
1
]
.
lower
()
if
answer
==
'yes'
:
splits
=
line
.
split
()
if
not
len
(
splits
):
continue
answer
=
splits
[
-
1
]
.
lower
()
if
answer
in
[
'yes'
,
'no'
]:
if
'cuda'
in
line
.
lower
()
and
answer
==
'yes'
:
# issue#1197
print
(
"OpenCV is built with CUDA support. "
"This may cause slow initialization or sometimes segfault with TensorFlow."
)
break
if
answer
==
'openmp'
:
print
(
"OpenCV is built with OpenMP support. This usually results in poor performance. For details, see "
"https://github.com/tensorpack/benchmarks/blob/master/ImageNet/benchmark-opencv-resize.py"
)
except
(
ImportError
,
TypeError
):
pass
...
...
@@ -41,9 +46,7 @@ os.environ['TF_GPU_THREAD_COUNT'] = '2'
try
:
import
tensorflow
as
tf
# noqa
_version
=
tf
.
__version__
.
split
(
'.'
)
assert
int
(
_version
[
0
])
>=
1
,
"TF>=1.0 is required!"
if
int
(
_version
[
1
])
<
3
:
print
(
"TF<1.3 support will be removed after 2018-03-15! Actually many examples already require TF>=1.3."
)
assert
int
(
_version
[
0
])
>=
1
and
int
(
_version
[
1
])
>=
3
,
"TF>=1.3 is required!"
_HAS_TF
=
True
except
ImportError
:
_HAS_TF
=
False
...
...
tensorpack/utils/fs.py
View file @
a8b72a87
...
...
@@ -13,7 +13,7 @@ __all__ = ['mkdir_p', 'download', 'recursive_walk', 'get_dataset_path']
def
mkdir_p
(
dirname
):
"""
M
ake a dir recursively, but do nothing if the dir exists
"""
Like "mkdir -p", m
ake a dir recursively, but do nothing if the dir exists
Args:
dirname(str):
...
...
@@ -38,6 +38,13 @@ def download(url, dir, filename=None, expect_size=None):
filename
=
url
.
split
(
'/'
)[
-
1
]
fpath
=
os
.
path
.
join
(
dir
,
filename
)
if
os
.
path
.
isfile
(
fpath
):
if
expect_size
is
not
None
and
os
.
stat
(
fpath
)
.
st_size
==
expect_size
:
logger
.
info
(
"File {} exists! Skip download."
.
format
(
filename
))
return
fpath
else
:
logger
.
warn
(
"File {} exists. Will overwrite with a new download!"
.
format
(
filename
))
def
hook
(
t
):
last_b
=
[
0
]
...
...
@@ -62,7 +69,7 @@ def download(url, dir, filename=None, expect_size=None):
logger
.
error
(
"You may have downloaded a broken file, or the upstream may have modified the file."
)
# TODO human-readable size
print
(
'Succesfully downloaded '
+
filename
+
". "
+
str
(
size
)
+
' bytes.'
)
logger
.
info
(
'Succesfully downloaded '
+
filename
+
". "
+
str
(
size
)
+
' bytes.'
)
return
fpath
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment