Commit 2c9c35d6 authored by Yuxin Wu's avatar Yuxin Wu

make scipy a opt-dependency

parent e7ede3eb
...@@ -3,8 +3,7 @@ ...@@ -3,8 +3,7 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
import os import os
import scipy.misc import cv2
from scipy.misc import imsave
import numpy as np import numpy as np
from .base import Callback from .base import Callback
...@@ -62,5 +61,5 @@ class DumpParamAsImage(Callback): ...@@ -62,5 +61,5 @@ class DumpParamAsImage(Callback):
res = im * self.scale res = im * self.scale
if self.clip: if self.clip:
res = np.clip(res, 0, 255) res = np.clip(res, 0, 255)
imsave(fname, res.astype('uint8')) cv2.imwrite(fname, res.astype('uint8'))
...@@ -6,14 +6,17 @@ ...@@ -6,14 +6,17 @@
import os, glob import os, glob
import cv2 import cv2
import numpy as np import numpy as np
from scipy.io import loadmat
from ...utils import logger, get_rng, get_dataset_dir from ...utils import logger, get_rng, get_dataset_dir
from ...utils.fs import download from ...utils.fs import download
from ..base import DataFlow from ..base import DataFlow
__all__ = ['BSDS500'] try:
from scipy.io import loadmat
__all__ = ['BSDS500']
except ImportError:
logger.error("Cannot import scipy. BSDS500 dataset won't be available!")
__all__ = []
DATA_URL = "http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/BSR_bsds500.tgz" DATA_URL = "http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/BSR_bsds500.tgz"
IMG_W, IMG_H = 481, 321 IMG_W, IMG_H = 481, 321
......
...@@ -6,14 +6,17 @@ ...@@ -6,14 +6,17 @@
import os import os
import random import random
import numpy as np import numpy as np
import scipy
import scipy.io
from six.moves import range from six.moves import range
from ...utils import logger, get_rng, get_dataset_dir from ...utils import logger, get_rng, get_dataset_dir
from ..base import DataFlow from ..base import DataFlow
__all__ = ['SVHNDigit'] try:
import scipy.io
__all__ = ['SVHNDigit']
except ImportError:
logger.error("Cannot import scipy. SVHNDigit dataset won't be available!")
__all__ = []
SVHN_URL = "http://ufldl.stanford.edu/housenumbers/" SVHN_URL = "http://ufldl.stanford.edu/housenumbers/"
......
...@@ -3,8 +3,8 @@ ...@@ -3,8 +3,8 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
import sys, os import sys, os
import cv2
import multiprocessing import multiprocessing
from scipy.misc import imsave
from ..utils.fs import mkdir_p from ..utils.fs import mkdir_p
...@@ -28,8 +28,7 @@ def dump_dataset_images(ds, dirname, max_count=None, index=0): ...@@ -28,8 +28,7 @@ def dump_dataset_images(ds, dirname, max_count=None, index=0):
if i > max_count: if i > max_count:
return return
img = dp[index] img = dp[index]
imsave(os.path.join(dirname, "{}.jpg".format(i)), img) cv2.imwrite(os.path.join(dirname, "{}.jpg".format(i)), img)
def dataflow_to_process_queue(ds, size, nr_consumer): def dataflow_to_process_queue(ds, size, nr_consumer):
""" """
......
...@@ -90,6 +90,7 @@ def get_dataset_dir(name): ...@@ -90,6 +90,7 @@ def get_dataset_dir(name):
if d: if d:
assert os.path.isdir(d) assert os.path.isdir(d)
else: else:
d = os.path.dirname(__file__) d = os.path.join(os.path.dirname(__file__), '..', 'dataflow', 'dataset')
logger.info("TENSORPACK_DATASET not set, using {} to keep dataset.".format(d))
return os.path.join(d, name) return os.path.join(d, name)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment