Commit 9ca10d35 authored by Yuxin Wu's avatar Yuxin Wu

remove hard dependency on opencv

parent c723c5a4
...@@ -59,7 +59,7 @@ Dependencies: ...@@ -59,7 +59,7 @@ Dependencies:
+ Python 2 or 3 + Python 2 or 3
+ TensorFlow >= 1.0.0 (>=1.1.0 for Multi-GPU) + TensorFlow >= 1.0.0 (>=1.1.0 for Multi-GPU)
+ Python bindings for OpenCV + Python bindings for OpenCV (Optional, but required by a lot of features)
``` ```
pip install -U git+https://github.com/ppwwyyxx/tensorpack.git pip install -U git+https://github.com/ppwwyyxx/tensorpack.git
# or add `--user` to avoid system-wide installation. # or add `--user` to avoid system-wide installation.
......
...@@ -77,7 +77,6 @@ def get_feature(f): ...@@ -77,7 +77,6 @@ def get_feature(f):
class RawTIMIT(DataFlow): class RawTIMIT(DataFlow):
def __init__(self, dirname, label='phoneme'): def __init__(self, dirname, label='phoneme'):
self.dirname = dirname self.dirname = dirname
assert os.path.isdir(dirname), dirname assert os.path.isdir(dirname), dirname
...@@ -103,6 +102,7 @@ class RawTIMIT(DataFlow): ...@@ -103,6 +102,7 @@ class RawTIMIT(DataFlow):
def compute_mean_std(db, fname): def compute_mean_std(db, fname):
ds = LMDBDataPoint(db, shuffle=False) ds = LMDBDataPoint(db, shuffle=False)
ds.reset_state()
o = OnlineMoments() o = OnlineMoments()
with get_tqdm(total=ds.size()) as bar: with get_tqdm(total=ds.size()) as bar:
for dp in ds.get_data(): for dp in ds.get_data():
......
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
import os import os
import cv2
import numpy as np import numpy as np
from .base import Callback from .base import Callback
...@@ -65,3 +64,10 @@ class DumpParamAsImage(Callback): ...@@ -65,3 +64,10 @@ class DumpParamAsImage(Callback):
res = im * self.scale res = im * self.scale
res = np.clip(res, 0, 255) res = np.clip(res, 0, 255)
cv2.imwrite(fname, res.astype('uint8')) cv2.imwrite(fname, res.astype('uint8'))
try:
import cv2
except ImportError:
from ..utils.develop import create_dummy_class
DumpParamAsImage = create_dummy_class('DumpParamAsImage', 'cv2') # noqa
...@@ -5,7 +5,6 @@ ...@@ -5,7 +5,6 @@
import os import os
import glob import glob
import cv2
import numpy as np import numpy as np
from ...utils.fs import download, get_dataset_path from ...utils.fs import download, get_dataset_path
...@@ -90,9 +89,10 @@ class BSDS500(RNGDataFlow): ...@@ -90,9 +89,10 @@ class BSDS500(RNGDataFlow):
try: try:
from scipy.io import loadmat from scipy.io import loadmat
import cv2
except ImportError: except ImportError:
from ...utils.develop import create_dummy_class from ...utils.develop import create_dummy_class
BSDS500 = create_dummy_class('BSDS500', 'scipy.io') # noqa BSDS500 = create_dummy_class('BSDS500', ['scipy.io', 'cv2']) # noqa
if __name__ == '__main__': if __name__ == '__main__':
a = BSDS500('val') a = BSDS500('val')
......
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import os import os
import tarfile import tarfile
import cv2
import six import six
import numpy as np import numpy as np
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
...@@ -227,6 +226,12 @@ class ILSVRC12(RNGDataFlow): ...@@ -227,6 +226,12 @@ class ILSVRC12(RNGDataFlow):
return ret return ret
try:
import cv2
except ImportError:
from ...utils.develop import create_dummy_class
ILSVRC12 = create_dummy_class('ILSVRC12', 'cv2') # noqa
if __name__ == '__main__': if __name__ == '__main__':
meta = ILSVRCMeta() meta = ILSVRCMeta()
# print(meta.get_synset_words_1000()) # print(meta.get_synset_words_1000())
......
...@@ -5,7 +5,6 @@ ...@@ -5,7 +5,6 @@
import sys import sys
import os import os
import multiprocessing as mp import multiprocessing as mp
import cv2
from six.moves import range from six.moves import range
from .base import DataFlow from .base import DataFlow
...@@ -161,3 +160,9 @@ try: ...@@ -161,3 +160,9 @@ try:
except ImportError: except ImportError:
dump_dataflow_to_tfrecord = create_dummy_func( # noqa dump_dataflow_to_tfrecord = create_dummy_func( # noqa
'dump_dataflow_to_tfrecord', 'tensorflow') 'dump_dataflow_to_tfrecord', 'tensorflow')
try:
import cv2
except ImportError:
dump_dataflow_images = create_dummy_func( # noqa
'dump_dataflow_images', 'cv2')
...@@ -3,11 +3,9 @@ ...@@ -3,11 +3,9 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
import numpy as np import numpy as np
import cv2
import copy as copy_mod import copy as copy_mod
from .base import RNGDataFlow from .base import RNGDataFlow
from .common import MapDataComponent, MapData from .common import MapDataComponent, MapData
from .imgaug import AugmentorList
from ..utils import logger from ..utils import logger
from ..utils.argtools import shape2d from ..utils.argtools import shape2d
...@@ -139,3 +137,13 @@ class AugmentImageComponents(MapData): ...@@ -139,3 +137,13 @@ class AugmentImageComponents(MapData):
def reset_state(self): def reset_state(self):
self.ds.reset_state() self.ds.reset_state()
self.augs.reset_state() self.augs.reset_state()
try:
import cv2
from .imgaug import AugmentorList
except ImportError:
from ..utils.develop import create_dummy_class
ImageFromFile = create_dummy_class('ImageFromFile', 'cv2') # noqa
AugmentImageComponent = create_dummy_class('AugmentImageComponent', 'cv2') # noqa
AugmentImageComponents = create_dummy_class('AugmentImageComponents', 'cv2') # noqa
...@@ -17,7 +17,13 @@ def global_import(name): ...@@ -17,7 +17,13 @@ def global_import(name):
__all__.append(k) __all__.append(k)
for _, module_name, _ in iter_modules( try:
[os.path.dirname(__file__)]): import cv2 # noqa
if not module_name.startswith('_'): except ImportError:
global_import(module_name) from ...utils import logger
logger.warn("Cannot import 'cv2', therefore image augmentation is not available.")
else:
for _, module_name, _ in iter_modules(
[os.path.dirname(__file__)]):
if not module_name.startswith('_'):
global_import(module_name)
# issue#1924 may happen on old systems try:
import cv2 # noqa # issue#1924 may happen on old systems
import cv2 # noqa
except ImportError:
pass
import os import os
# issue#7378 may happen with custom opencv. It doesn't hurt to disable opencl # issue#7378 may happen with custom opencv. It doesn't hurt to disable opencl
......
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
import six import six
import tensorflow as tf import tensorflow as tf
import cv2
import re import re
from six.moves import range from six.moves import range
...@@ -169,3 +168,10 @@ def add_moving_summary(v, *args, **kwargs): ...@@ -169,3 +168,10 @@ def add_moving_summary(v, *args, **kwargs):
tf.summary.scalar(name + '-summary', averager.average(c)) tf.summary.scalar(name + '-summary', averager.average(c))
tf.add_to_collection(coll, avg_maintain_op) tf.add_to_collection(coll, avg_maintain_op)
try:
import cv2
except ImportError:
from ..utils.develop import create_dummy_func
create_image_summary = create_dummy_func('create_image_summary', 'cv2') # noqa
...@@ -36,11 +36,14 @@ def create_dummy_func(func, dependency): ...@@ -36,11 +36,14 @@ def create_dummy_func(func, dependency):
Args: Args:
func (str): name of the function. func (str): name of the function.
dependency (str): name of the dependency. dependency (str or list[str]): name(s) of the dependency.
Returns: Returns:
function: a function object function: a function object
""" """
if isinstance(dependency, (list, str)):
dependency = ','.join(dependency)
def _dummy(*args, **kwargs): def _dummy(*args, **kwargs):
raise ImportError("Cannot import '{}', therefore '{}' is not available".format(dependency, func)) raise ImportError("Cannot import '{}', therefore '{}' is not available".format(dependency, func))
return _dummy return _dummy
......
...@@ -7,10 +7,14 @@ import numpy as np ...@@ -7,10 +7,14 @@ import numpy as np
import os import os
import sys import sys
import io import io
import cv2
from .fs import mkdir_p from .fs import mkdir_p
from .argtools import shape2d from .argtools import shape2d
try:
import cv2
except ImportError:
pass
__all__ = ['pyplot2img', 'interactive_imshow', __all__ = ['pyplot2img', 'interactive_imshow',
'stack_patches', 'gen_stack_patches', 'stack_patches', 'gen_stack_patches',
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment