Commit 88af1f1d authored by Yuxin Wu's avatar Yuxin Wu

a better handling of optional import.

parent 6e24b953
...@@ -12,6 +12,7 @@ from collections import defaultdict ...@@ -12,6 +12,7 @@ from collections import defaultdict
import six import six
from six.moves import queue from six.moves import queue
import zmq
from tensorpack.models.common import disable_layer_logging from tensorpack.models.common import disable_layer_logging
from tensorpack.callbacks import Callback from tensorpack.callbacks import Callback
...@@ -25,12 +26,6 @@ __all__ = ['SimulatorProcess', 'SimulatorMaster', ...@@ -25,12 +26,6 @@ __all__ = ['SimulatorProcess', 'SimulatorMaster',
'SimulatorProcessStateExchange', 'SimulatorProcessSharedWeight', 'SimulatorProcessStateExchange', 'SimulatorProcessSharedWeight',
'TransitionExperience', 'WeightSync'] 'TransitionExperience', 'WeightSync']
try:
import zmq
except ImportError:
logger.warn_dependency('Simulator', 'zmq')
__all__ = []
class TransitionExperience(object): class TransitionExperience(object):
""" A transition of state, or experience""" """ A transition of state, or experience"""
......
...@@ -5,18 +5,6 @@ ...@@ -5,18 +5,6 @@
import time import time
from ..utils import logger
try:
import gym
# TODO
# gym.undo_logger_setup()
# https://github.com/openai/gym/pull/199
# not sure does it cause other problems
__all__ = ['GymEnv']
except ImportError:
logger.warn_dependency('GymEnv', 'gym')
__all__ = []
import threading import threading
from ..utils.fs import mkdir_p from ..utils.fs import mkdir_p
...@@ -24,6 +12,7 @@ from ..utils.stats import StatCounter ...@@ -24,6 +12,7 @@ from ..utils.stats import StatCounter
from .envbase import RLEnvironment, DiscreteActionSpace from .envbase import RLEnvironment, DiscreteActionSpace
__all__ = ['GymEnv']
_ENV_LOCK = threading.Lock() _ENV_LOCK = threading.Lock()
...@@ -84,6 +73,17 @@ class GymEnv(RLEnvironment): ...@@ -84,6 +73,17 @@ class GymEnv(RLEnvironment):
return DiscreteActionSpace(spc.n) return DiscreteActionSpace(spc.n)
try:
import gym
# TODO
# gym.undo_logger_setup()
# https://github.com/openai/gym/pull/199
# not sure does it cause other problems
except ImportError:
from ..utils.dependency import create_dummy_class
GymEnv = create_dummy_class('GymEnv', 'gym') # noqa
if __name__ == '__main__': if __name__ == '__main__':
env = GymEnv('Breakout-v0', viz=0.1) env = GymEnv('Breakout-v0', viz=0.1)
num = env.get_action_space().num_actions() num = env.get_action_space().num_actions()
......
...@@ -8,17 +8,11 @@ import glob ...@@ -8,17 +8,11 @@ import glob
import cv2 import cv2
import numpy as np import numpy as np
from ...utils import logger, get_dataset_path from ...utils import get_dataset_path
from ...utils.fs import download from ...utils.fs import download
from ..base import RNGDataFlow from ..base import RNGDataFlow
try: __all__ = ['BSDS500']
from scipy.io import loadmat
__all__ = ['BSDS500']
except ImportError:
logger.warn_dependency('BSDS500', 'scipy.io')
__all__ = []
DATA_URL = "http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/BSR_bsds500.tgz" DATA_URL = "http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/BSR_bsds500.tgz"
IMG_W, IMG_H = 481, 321 IMG_W, IMG_H = 481, 321
...@@ -95,6 +89,12 @@ class BSDS500(RNGDataFlow): ...@@ -95,6 +89,12 @@ class BSDS500(RNGDataFlow):
yield [self.data[k], self.label[k]] yield [self.data[k], self.label[k]]
try:
from scipy.io import loadmat
except ImportError:
from ...utils.dependency import create_dummy_class
BSDS500 = create_dummy_class('BSDS500', 'scipy.io') # noqa
if __name__ == '__main__': if __name__ == '__main__':
a = BSDS500('val') a = BSDS500('val')
for k in a.get_data(): for k in a.get_data():
......
...@@ -9,12 +9,7 @@ import numpy as np ...@@ -9,12 +9,7 @@ import numpy as np
from ...utils import logger, get_dataset_path from ...utils import logger, get_dataset_path
from ..base import RNGDataFlow from ..base import RNGDataFlow
try: __all__ = ['SVHNDigit']
import scipy.io
__all__ = ['SVHNDigit']
except ImportError:
logger.warn_dependency('SVHNDigit', 'scipy.io')
__all__ = []
SVHN_URL = "http://ufldl.stanford.edu/housenumbers/" SVHN_URL = "http://ufldl.stanford.edu/housenumbers/"
...@@ -73,6 +68,12 @@ class SVHNDigit(RNGDataFlow): ...@@ -73,6 +68,12 @@ class SVHNDigit(RNGDataFlow):
return np.concatenate((a.X, b.X, c.X)).mean(axis=0) return np.concatenate((a.X, b.X, c.X)).mean(axis=0)
try:
import scipy.io
except ImportError:
from ...utils.dependency import create_dummy_class
SVHNDigit = create_dummy_class('SVHNDigit', 'scipy.io') # noqa
if __name__ == '__main__': if __name__ == '__main__':
a = SVHNDigit('train') a = SVHNDigit('train')
b = SVHNDigit.get_per_pixel_mean() b = SVHNDigit.get_per_pixel_mean()
...@@ -15,13 +15,8 @@ from ..utils.concurrency import DIE ...@@ -15,13 +15,8 @@ from ..utils.concurrency import DIE
from ..utils.serialize import dumps from ..utils.serialize import dumps
from ..utils.fs import mkdir_p from ..utils.fs import mkdir_p
__all__ = ['dump_dataset_images', 'dataflow_to_process_queue'] __all__ = ['dump_dataset_images', 'dataflow_to_process_queue',
try: 'dump_dataflow_to_lmdb']
import lmdb
except ImportError:
logger.warn_dependency("dump_dataflow_to_lmdb", 'lmdb')
else:
__all__.extend(['dump_dataflow_to_lmdb'])
def dump_dataset_images(ds, dirname, max_count=None, index=0): def dump_dataset_images(ds, dirname, max_count=None, index=0):
...@@ -84,6 +79,13 @@ def dump_dataflow_to_lmdb(ds, lmdb_path): ...@@ -84,6 +79,13 @@ def dump_dataflow_to_lmdb(ds, lmdb_path):
db.sync() db.sync()
try:
import lmdb
except ImportError:
from ..utils.dependency import create_dummy_func
dump_dataflow_to_lmdb = create_dummy_func('dump_dataflow_to_lmdb', 'lmdb') # noqa
def dataflow_to_process_queue(ds, size, nr_consumer): def dataflow_to_process_queue(ds, size, nr_consumer):
""" """
Convert a DataFlow to a :class:`multiprocessing.Queue`. Convert a DataFlow to a :class:`multiprocessing.Queue`.
......
...@@ -13,28 +13,8 @@ from ..utils.serialize import loads ...@@ -13,28 +13,8 @@ from ..utils.serialize import loads
from ..utils.argtools import log_once from ..utils.argtools import log_once
from .base import RNGDataFlow from .base import RNGDataFlow
try: __all__ = ['HDF5Data', 'LMDBData', 'LMDBDataDecoder', 'LMDBDataPoint',
import h5py 'CaffeLMDB', 'SVMLightData']
except ImportError:
logger.warn_dependency("HDF5Data", 'h5py')
__all__ = []
else:
__all__ = ['HDF5Data']
try:
import lmdb
except ImportError:
logger.warn_dependency("LMDBData", 'lmdb')
else:
__all__.extend(['LMDBData', 'LMDBDataDecoder', 'LMDBDataPoint', 'CaffeLMDB'])
try:
import sklearn.datasets
except ImportError:
logger.warn_dependency('SVMLightData', 'sklearn')
else:
__all__.extend(['SVMLightData'])
""" """
Adapters for different data format. Adapters for different data format.
...@@ -214,3 +194,21 @@ class SVMLightData(RNGDataFlow): ...@@ -214,3 +194,21 @@ class SVMLightData(RNGDataFlow):
self.rng.shuffle(idxs) self.rng.shuffle(idxs)
for id in idxs: for id in idxs:
yield [self.X[id, :], self.y[id]] yield [self.X[id, :], self.y[id]]
from ..utils.dependency import create_dummy_class # noqa
try:
import h5py
except ImportError:
HDF5Data = create_dummy_class('HDF5Data', 'h5py') # noqa
try:
import lmdb
except ImportError:
for klass in ['LMDBData', 'LMDBDataDecoder', 'LMDBDataPoint', 'CaffeLMDB']:
globals()[klass] = create_dummy_class(klass, 'lmdb')
try:
import sklearn.datasets
except ImportError:
SVMLightData = create_dummy_class('SVMLightData', 'sklearn') # noqa
...@@ -8,6 +8,7 @@ import itertools ...@@ -8,6 +8,7 @@ import itertools
from six.moves import range, zip from six.moves import range, zip
import uuid import uuid
import os import os
import zmq
from .base import ProxyDataFlow from .base import ProxyDataFlow
from ..utils.concurrency import (ensure_proc_terminate, from ..utils.concurrency import (ensure_proc_terminate,
...@@ -16,13 +17,7 @@ from ..utils.serialize import loads, dumps ...@@ -16,13 +17,7 @@ from ..utils.serialize import loads, dumps
from ..utils import logger from ..utils import logger
from ..utils.gpu import change_gpu from ..utils.gpu import change_gpu
__all__ = ['PrefetchData', 'BlockParallel'] __all__ = ['PrefetchData', 'BlockParallel', 'PrefetchDataZMQ', 'PrefetchOnGPUs']
try:
import zmq
except ImportError:
logger.warn_dependency('PrefetchDataZMQ', 'zmq')
else:
__all__.extend(['PrefetchDataZMQ', 'PrefetchOnGPUs'])
class PrefetchProcess(mp.Process): class PrefetchProcess(mp.Process):
......
...@@ -3,22 +3,14 @@ ...@@ -3,22 +3,14 @@
# File: tf_func.py # File: tf_func.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import tensorflow as tf
from .base import ProxyDataFlow from .base import ProxyDataFlow
from ..utils import logger
try:
import tensorflow as tf
except ImportError:
logger.warn_dependency('TFFuncMapper', 'tensorflow')
__all__ = []
else:
__all__ = []
""" This file was deprecated """ """ This file was deprecated """
class TFFuncMapper(ProxyDataFlow): class TFFuncMapper(ProxyDataFlow):
def __init__(self, ds, def __init__(self, ds,
get_placeholders, symbf, apply_symbf_on_dp, device='/cpu:0'): get_placeholders, symbf, apply_symbf_on_dp, device='/cpu:0'):
""" """
......
...@@ -9,21 +9,10 @@ from six.moves import queue, range ...@@ -9,21 +9,10 @@ from six.moves import queue, range
from ..utils.concurrency import DIE, StoppableThread from ..utils.concurrency import DIE, StoppableThread
from ..tfutils.modelutils import describe_model from ..tfutils.modelutils import describe_model
from ..utils import logger
from .base import OfflinePredictor, AsyncPredictorBase from .base import OfflinePredictor, AsyncPredictorBase
try: __all__ = ['MultiProcessPredictWorker', 'MultiProcessQueuePredictWorker',
if six.PY2: 'MultiThreadAsyncPredictor']
from tornado.concurrent import Future
else:
from concurrent.futures import Future
except ImportError:
logger.warn_dependency('MultiThreadAsyncPredictor', 'tornado.concurrent')
__all__ = ['MultiProcessPredictWorker', 'MultiProcessQueuePredictWorker']
else:
__all__ = ['MultiProcessPredictWorker', 'MultiProcessQueuePredictWorker',
'MultiThreadAsyncPredictor']
class MultiProcessPredictWorker(multiprocessing.Process): class MultiProcessPredictWorker(multiprocessing.Process):
...@@ -171,3 +160,13 @@ class MultiThreadAsyncPredictor(AsyncPredictorBase): ...@@ -171,3 +160,13 @@ class MultiThreadAsyncPredictor(AsyncPredictorBase):
f.add_done_callback(callback) f.add_done_callback(callback)
self.input_queue.put((dp, f)) self.input_queue.put((dp, f))
return f return f
try:
if six.PY2:
from tornado.concurrent import Future
else:
from concurrent.futures import Future
except ImportError:
from ..utils.dependency import create_dummy_class
MultiThreadAsyncPredictor = create_dummy_class('MultiThreadAsyncPredictor', 'tornado.concurrent') # noqa
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: dependency.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
""" Utilities to handle dependency """
__all__ = ['create_dummy_func', 'create_dummy_class']
def create_dummy_class(klass, dependency):
    """
    When a dependency of a class is not available, create a dummy class which throws ImportError when used.

    Args:
        klass (str): name of the class.
        dependency (str): name of the dependency.

    Returns:
        class: a class object
    """
    # Build the message eagerly; the error is raised only at instantiation
    # time, so merely importing the module that uses the dummy stays cheap.
    message = "Cannot import '{}', therefore '{}' is not available".format(dependency, klass)

    class _DummyClass(object):
        # Accept any signature so the dummy can stand in for the real class.
        def __init__(self, *args, **kwargs):
            raise ImportError(message)

    return _DummyClass
def create_dummy_func(func, dependency):
    """
    When a dependency of a function is not available, create a dummy function which throws ImportError when used.

    Args:
        func (str): name of the function.
        dependency (str): name of the dependency.

    Returns:
        function: a function object
    """
    # Precompute the message; the ImportError fires only when the dummy is
    # actually called, never at module-import time.
    message = "Cannot import '{}', therefore '{}' is not available".format(dependency, func)

    def _raise_on_call(*args, **kwargs):
        raise ImportError(message)

    return _raise_on_call
...@@ -11,8 +11,7 @@ from datetime import datetime ...@@ -11,8 +11,7 @@ from datetime import datetime
from six.moves import input from six.moves import input
import sys import sys
__all__ = ['set_logger_dir', 'disable_logger', 'auto_set_dir', __all__ = ['set_logger_dir', 'disable_logger', 'auto_set_dir']
'warn_dependency']
class _MyFormatter(logging.Formatter): class _MyFormatter(logging.Formatter):
...@@ -128,8 +127,3 @@ def auto_set_dir(action=None, overwrite=False): ...@@ -128,8 +127,3 @@ def auto_set_dir(action=None, overwrite=False):
os.path.join('train_log', os.path.join('train_log',
basename[:basename.rfind('.')]), basename[:basename.rfind('.')]),
action=action) action=action)
def warn_dependency(name, dependencies):
    """
    Print a warning about an import failure due to missing dependencies.

    Args:
        name (str): name of the feature that becomes unavailable.
        dependencies (str): name of the missing dependency/module.
    """
    # NOTE(review): 'warn' is not defined in this function's scope; it is
    # presumably bound dynamically on this logger module at import time
    # (hence the original noqa F821) -- confirm before relying on it.
    # Fix: removed the stray trailing apostrophe after "available".
    warn("Failed to import '{}', {} won't be available".format(dependencies, name))  # noqa: F821
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment