Commit a8b72a87 authored by Yuxin Wu

Improve dataset download logic

parent 5854c7de
...@@ -10,7 +10,10 @@ from ...utils.fs import download, get_dataset_path ...@@ -10,7 +10,10 @@ from ...utils.fs import download, get_dataset_path
from ..base import RNGDataFlow from ..base import RNGDataFlow
__all__ = ['BSDS500'] __all__ = ['BSDS500']
DATA_URL = "http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/BSR_bsds500.tgz" DATA_URL = "http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/BSR_bsds500.tgz"
DATA_SIZE = 70763455
IMG_W, IMG_H = 481, 321 IMG_W, IMG_H = 481, 321
...@@ -35,7 +38,7 @@ class BSDS500(RNGDataFlow): ...@@ -35,7 +38,7 @@ class BSDS500(RNGDataFlow):
if data_dir is None: if data_dir is None:
data_dir = get_dataset_path('bsds500_data') data_dir = get_dataset_path('bsds500_data')
if not os.path.isdir(os.path.join(data_dir, 'BSR')): if not os.path.isdir(os.path.join(data_dir, 'BSR')):
download(DATA_URL, data_dir) download(DATA_URL, data_dir, expect_size=DATA_SIZE)
filename = DATA_URL.split('/')[-1] filename = DATA_URL.split('/')[-1]
filepath = os.path.join(data_dir, filename) filepath = os.path.join(data_dir, filename)
import tarfile import tarfile
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
import os import os
import pickle import pickle
import numpy as np import numpy as np
import tarfile
import six import six
from six.moves import range from six.moves import range
...@@ -16,13 +17,12 @@ from ..base import RNGDataFlow ...@@ -16,13 +17,12 @@ from ..base import RNGDataFlow
__all__ = ['Cifar10', 'Cifar100'] __all__ = ['Cifar10', 'Cifar100']
DATA_URL_CIFAR_10 = 'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz' DATA_URL_CIFAR_10 = ('http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz', 170498071)
DATA_URL_CIFAR_100 = 'http://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz' DATA_URL_CIFAR_100 = ('http://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz', 169001437)
def maybe_download_and_extract(dest_directory, cifar_classnum): def maybe_download_and_extract(dest_directory, cifar_classnum):
"""Download and extract the tarball from Alex's website. """Download and extract the tarball from Alex's website. Copied from tensorflow example """
copied from tensorflow example """
assert cifar_classnum == 10 or cifar_classnum == 100 assert cifar_classnum == 10 or cifar_classnum == 100
if cifar_classnum == 10: if cifar_classnum == 10:
cifar_foldername = 'cifar-10-batches-py' cifar_foldername = 'cifar-10-batches-py'
...@@ -33,10 +33,9 @@ def maybe_download_and_extract(dest_directory, cifar_classnum): ...@@ -33,10 +33,9 @@ def maybe_download_and_extract(dest_directory, cifar_classnum):
return return
else: else:
DATA_URL = DATA_URL_CIFAR_10 if cifar_classnum == 10 else DATA_URL_CIFAR_100 DATA_URL = DATA_URL_CIFAR_10 if cifar_classnum == 10 else DATA_URL_CIFAR_100
download(DATA_URL, dest_directory) filename = DATA_URL[0].split('/')[-1]
filename = DATA_URL.split('/')[-1]
filepath = os.path.join(dest_directory, filename) filepath = os.path.join(dest_directory, filename)
import tarfile download(DATA_URL[0], dest_directory, expect_size=DATA_URL[1])
tarfile.open(filepath, 'r:gz').extractall(dest_directory) tarfile.open(filepath, 'r:gz').extractall(dest_directory)
......
...@@ -14,7 +14,7 @@ from ..base import RNGDataFlow ...@@ -14,7 +14,7 @@ from ..base import RNGDataFlow
__all__ = ['ILSVRCMeta', 'ILSVRC12', 'ILSVRC12Files'] __all__ = ['ILSVRCMeta', 'ILSVRC12', 'ILSVRC12Files']
CAFFE_ILSVRC12_URL = "http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz" CAFFE_ILSVRC12_URL = ("http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz", 17858008)
class ILSVRCMeta(object): class ILSVRCMeta(object):
...@@ -53,7 +53,7 @@ class ILSVRCMeta(object): ...@@ -53,7 +53,7 @@ class ILSVRCMeta(object):
return dict(enumerate(lines)) return dict(enumerate(lines))
def _download_caffe_meta(self): def _download_caffe_meta(self):
fpath = download(CAFFE_ILSVRC12_URL, self.dir, expect_size=17858008) fpath = download(CAFFE_ILSVRC12_URL[0], self.dir, expect_size=CAFFE_ILSVRC12_URL[1])
tarfile.open(fpath, 'r:gz').extractall(self.dir) tarfile.open(fpath, 'r:gz').extractall(self.dir)
def get_image_list(self, name, dir_structure='original'): def get_image_list(self, name, dir_structure='original'):
......
...@@ -8,6 +8,7 @@ import tensorflow as tf ...@@ -8,6 +8,7 @@ import tensorflow as tf
from ..tfutils.varreplace import custom_getter_scope from ..tfutils.varreplace import custom_getter_scope
from ..tfutils.scope_utils import under_name_scope, cached_name_scope from ..tfutils.scope_utils import under_name_scope, cached_name_scope
from ..tfutils.common import get_tf_version_number
from ..utils.argtools import call_only_once from ..utils.argtools import call_only_once
from ..utils import logger from ..utils import logger
...@@ -66,13 +67,16 @@ class LeastLoadedDeviceSetter(object): ...@@ -66,13 +67,16 @@ class LeastLoadedDeviceSetter(object):
self.ps_sizes = [0] * len(self.ps_devices) self.ps_sizes = [0] * len(self.ps_devices)
def __call__(self, op): def __call__(self, op):
def sanitize_name(name): # tensorflow/tensorflow#11484 if get_tf_version_number() >= 1.8:
return tf.DeviceSpec.from_string(name).to_string() from tensorflow.python.training.device_util import canonicalize
else:
def canonicalize(name): # tensorflow/tensorflow#11484
return tf.DeviceSpec.from_string(name).to_string()
if op.device: if op.device:
return op.device return op.device
if op.type not in ['Variable', 'VariableV2']: if op.type not in ['Variable', 'VariableV2']:
return sanitize_name(self.worker_device) return canonicalize(self.worker_device)
device_index, _ = min(enumerate( device_index, _ = min(enumerate(
self.ps_sizes), key=operator.itemgetter(1)) self.ps_sizes), key=operator.itemgetter(1))
...@@ -84,7 +88,7 @@ class LeastLoadedDeviceSetter(object): ...@@ -84,7 +88,7 @@ class LeastLoadedDeviceSetter(object):
self.ps_sizes[device_index] += var_size self.ps_sizes[device_index] += var_size
return sanitize_name(device_name) return canonicalize(device_name)
def __str__(self): def __str__(self):
return "LeastLoadedDeviceSetter-{}".format(self.worker_device) return "LeastLoadedDeviceSetter-{}".format(self.worker_device)
......
...@@ -8,16 +8,21 @@ try: ...@@ -8,16 +8,21 @@ try:
import cv2 # noqa import cv2 # noqa
if int(cv2.__version__.split('.')[0]) == 3: if int(cv2.__version__.split('.')[0]) == 3:
cv2.ocl.setUseOpenCL(False) cv2.ocl.setUseOpenCL(False)
# check if cv is built with cuda # check if cv is built with cuda or openmp
info = cv2.getBuildInformation().split('\n') info = cv2.getBuildInformation().split('\n')
for line in info: for line in info:
if 'use cuda' in line.lower(): splits = line.split()
answer = line.split()[-1].lower() if not len(splits):
if answer == 'yes': continue
answer = splits[-1].lower()
if answer in ['yes', 'no']:
if 'cuda' in line.lower() and answer == 'yes':
# issue#1197 # issue#1197
print("OpenCV is built with CUDA support. " print("OpenCV is built with CUDA support. "
"This may cause slow initialization or sometimes segfault with TensorFlow.") "This may cause slow initialization or sometimes segfault with TensorFlow.")
break if answer == 'openmp':
print("OpenCV is built with OpenMP support. This usually results in poor performance. For details, see "
"https://github.com/tensorpack/benchmarks/blob/master/ImageNet/benchmark-opencv-resize.py")
except (ImportError, TypeError): except (ImportError, TypeError):
pass pass
...@@ -41,9 +46,7 @@ os.environ['TF_GPU_THREAD_COUNT'] = '2' ...@@ -41,9 +46,7 @@ os.environ['TF_GPU_THREAD_COUNT'] = '2'
try: try:
import tensorflow as tf # noqa import tensorflow as tf # noqa
_version = tf.__version__.split('.') _version = tf.__version__.split('.')
assert int(_version[0]) >= 1, "TF>=1.0 is required!" assert int(_version[0]) >= 1 and int(_version[1]) >= 3, "TF>=1.3 is required!"
if int(_version[1]) < 3:
print("TF<1.3 support will be removed after 2018-03-15! Actually many examples already require TF>=1.3.")
_HAS_TF = True _HAS_TF = True
except ImportError: except ImportError:
_HAS_TF = False _HAS_TF = False
......
...@@ -13,7 +13,7 @@ __all__ = ['mkdir_p', 'download', 'recursive_walk', 'get_dataset_path'] ...@@ -13,7 +13,7 @@ __all__ = ['mkdir_p', 'download', 'recursive_walk', 'get_dataset_path']
def mkdir_p(dirname): def mkdir_p(dirname):
""" Make a dir recursively, but do nothing if the dir exists """ Like "mkdir -p", make a dir recursively, but do nothing if the dir exists
Args: Args:
dirname(str): dirname(str):
...@@ -38,6 +38,13 @@ def download(url, dir, filename=None, expect_size=None): ...@@ -38,6 +38,13 @@ def download(url, dir, filename=None, expect_size=None):
filename = url.split('/')[-1] filename = url.split('/')[-1]
fpath = os.path.join(dir, filename) fpath = os.path.join(dir, filename)
if os.path.isfile(fpath):
if expect_size is not None and os.stat(fpath).st_size == expect_size:
logger.info("File {} exists! Skip download.".format(filename))
return fpath
else:
logger.warn("File {} exists. Will overwrite with a new download!".format(filename))
def hook(t): def hook(t):
last_b = [0] last_b = [0]
...@@ -62,7 +69,7 @@ def download(url, dir, filename=None, expect_size=None): ...@@ -62,7 +69,7 @@ def download(url, dir, filename=None, expect_size=None):
logger.error("You may have downloaded a broken file, or the upstream may have modified the file.") logger.error("You may have downloaded a broken file, or the upstream may have modified the file.")
# TODO human-readable size # TODO human-readable size
print('Succesfully downloaded ' + filename + ". " + str(size) + ' bytes.') logger.info('Succesfully downloaded ' + filename + ". " + str(size) + ' bytes.')
return fpath return fpath
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment