Commit 9ca10d35 authored by Yuxin Wu's avatar Yuxin Wu

remove hard dependency on opencv

parent c723c5a4
......@@ -59,7 +59,7 @@ Dependencies:
+ Python 2 or 3
+ TensorFlow >= 1.0.0 (>=1.1.0 for Multi-GPU)
+ Python bindings for OpenCV
+ Python bindings for OpenCV (Optional, but required by a lot of features)
```
pip install -U git+https://github.com/ppwwyyxx/tensorpack.git
# or add `--user` to avoid system-wide installation.
......
......@@ -77,7 +77,6 @@ def get_feature(f):
class RawTIMIT(DataFlow):
def __init__(self, dirname, label='phoneme'):
self.dirname = dirname
assert os.path.isdir(dirname), dirname
......@@ -103,6 +102,7 @@ class RawTIMIT(DataFlow):
def compute_mean_std(db, fname):
ds = LMDBDataPoint(db, shuffle=False)
ds.reset_state()
o = OnlineMoments()
with get_tqdm(total=ds.size()) as bar:
for dp in ds.get_data():
......
......@@ -3,7 +3,6 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import os
import cv2
import numpy as np
from .base import Callback
......@@ -65,3 +64,10 @@ class DumpParamAsImage(Callback):
res = im * self.scale
res = np.clip(res, 0, 255)
cv2.imwrite(fname, res.astype('uint8'))
try:
import cv2
except ImportError:
from ..utils.develop import create_dummy_class
DumpParamAsImage = create_dummy_class('DumpParamAsImage', 'cv2') # noqa
......@@ -5,7 +5,6 @@
import os
import glob
import cv2
import numpy as np
from ...utils.fs import download, get_dataset_path
......@@ -90,9 +89,10 @@ class BSDS500(RNGDataFlow):
try:
from scipy.io import loadmat
import cv2
except ImportError:
from ...utils.develop import create_dummy_class
BSDS500 = create_dummy_class('BSDS500', 'scipy.io') # noqa
BSDS500 = create_dummy_class('BSDS500', ['scipy.io', 'cv2']) # noqa
if __name__ == '__main__':
a = BSDS500('val')
......
......@@ -4,7 +4,6 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import os
import tarfile
import cv2
import six
import numpy as np
import xml.etree.ElementTree as ET
......@@ -227,6 +226,12 @@ class ILSVRC12(RNGDataFlow):
return ret
try:
import cv2
except ImportError:
from ...utils.develop import create_dummy_class
ILSVRC12 = create_dummy_class('ILSVRC12', 'cv2') # noqa
if __name__ == '__main__':
meta = ILSVRCMeta()
# print(meta.get_synset_words_1000())
......
......@@ -5,7 +5,6 @@
import sys
import os
import multiprocessing as mp
import cv2
from six.moves import range
from .base import DataFlow
......@@ -161,3 +160,9 @@ try:
except ImportError:
dump_dataflow_to_tfrecord = create_dummy_func( # noqa
'dump_dataflow_to_tfrecord', 'tensorflow')
try:
import cv2
except ImportError:
dump_dataflow_images = create_dummy_func( # noqa
'dump_dataflow_images', 'cv2')
......@@ -3,11 +3,9 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import numpy as np
import cv2
import copy as copy_mod
from .base import RNGDataFlow
from .common import MapDataComponent, MapData
from .imgaug import AugmentorList
from ..utils import logger
from ..utils.argtools import shape2d
......@@ -139,3 +137,13 @@ class AugmentImageComponents(MapData):
def reset_state(self):
self.ds.reset_state()
self.augs.reset_state()
try:
import cv2
from .imgaug import AugmentorList
except ImportError:
from ..utils.develop import create_dummy_class
ImageFromFile = create_dummy_class('ImageFromFile', 'cv2') # noqa
AugmentImageComponent = create_dummy_class('AugmentImageComponent', 'cv2') # noqa
AugmentImageComponents = create_dummy_class('AugmentImageComponents', 'cv2') # noqa
......@@ -17,7 +17,13 @@ def global_import(name):
__all__.append(k)
for _, module_name, _ in iter_modules(
try:
import cv2 # noqa
except ImportError:
from ...utils import logger
logger.warn("Cannot import 'cv2', therefore image augmentation is not available.")
else:
for _, module_name, _ in iter_modules(
[os.path.dirname(__file__)]):
if not module_name.startswith('_'):
global_import(module_name)
# issue#1924 may happen on old systems
import cv2 # noqa
try:
# issue#1924 may happen on old systems
import cv2 # noqa
except ImportError:
pass
import os
# issue#7378 may happen with custom opencv. It doesn't hurt to disable opencl
......
......@@ -4,7 +4,6 @@
import six
import tensorflow as tf
import cv2
import re
from six.moves import range
......@@ -169,3 +168,10 @@ def add_moving_summary(v, *args, **kwargs):
tf.summary.scalar(name + '-summary', averager.average(c))
tf.add_to_collection(coll, avg_maintain_op)
try:
import cv2
except ImportError:
from ..utils.develop import create_dummy_func
create_image_summary = create_dummy_func('create_image_summary', 'cv2') # noqa
......@@ -36,11 +36,14 @@ def create_dummy_func(func, dependency):
Args:
func (str): name of the function.
dependency (str): name of the dependency.
dependency (str or list[str]): name(s) of the dependency.
Returns:
function: a function object
"""
if isinstance(dependency, (list, str)):
dependency = ','.join(dependency)
def _dummy(*args, **kwargs):
raise ImportError("Cannot import '{}', therefore '{}' is not available".format(dependency, func))
return _dummy
......
......@@ -7,10 +7,14 @@ import numpy as np
import os
import sys
import io
import cv2
from .fs import mkdir_p
from .argtools import shape2d
try:
import cv2
except ImportError:
pass
__all__ = ['pyplot2img', 'interactive_imshow',
'stack_patches', 'gen_stack_patches',
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment