Commit 88af1f1d authored by Yuxin Wu's avatar Yuxin Wu

a better handling of optional import.

parent 6e24b953
......@@ -12,6 +12,7 @@ from collections import defaultdict
import six
from six.moves import queue
import zmq
from tensorpack.models.common import disable_layer_logging
from tensorpack.callbacks import Callback
......@@ -25,12 +26,6 @@ __all__ = ['SimulatorProcess', 'SimulatorMaster',
'SimulatorProcessStateExchange', 'SimulatorProcessSharedWeight',
'TransitionExperience', 'WeightSync']
try:
import zmq
except ImportError:
logger.warn_dependency('Simulator', 'zmq')
__all__ = []
class TransitionExperience(object):
""" A transition of state, or experience"""
......
......@@ -5,18 +5,6 @@
import time
from ..utils import logger
try:
import gym
# TODO
# gym.undo_logger_setup()
# https://github.com/openai/gym/pull/199
# not sure does it cause other problems
__all__ = ['GymEnv']
except ImportError:
logger.warn_dependency('GymEnv', 'gym')
__all__ = []
import threading
from ..utils.fs import mkdir_p
......@@ -24,6 +12,7 @@ from ..utils.stats import StatCounter
from .envbase import RLEnvironment, DiscreteActionSpace
__all__ = ['GymEnv']
_ENV_LOCK = threading.Lock()
......@@ -84,6 +73,17 @@ class GymEnv(RLEnvironment):
return DiscreteActionSpace(spc.n)
try:
import gym
# TODO
# gym.undo_logger_setup()
# https://github.com/openai/gym/pull/199
# not sure does it cause other problems
except ImportError:
from ..utils.dependency import create_dummy_class
GymEnv = create_dummy_class('GymEnv', 'gym') # noqa
if __name__ == '__main__':
env = GymEnv('Breakout-v0', viz=0.1)
num = env.get_action_space().num_actions()
......
......@@ -8,17 +8,11 @@ import glob
import cv2
import numpy as np
from ...utils import logger, get_dataset_path
from ...utils import get_dataset_path
from ...utils.fs import download
from ..base import RNGDataFlow
try:
from scipy.io import loadmat
__all__ = ['BSDS500']
except ImportError:
logger.warn_dependency('BSDS500', 'scipy.io')
__all__ = []
__all__ = ['BSDS500']
DATA_URL = "http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/BSR_bsds500.tgz"
IMG_W, IMG_H = 481, 321
......@@ -95,6 +89,12 @@ class BSDS500(RNGDataFlow):
yield [self.data[k], self.label[k]]
try:
from scipy.io import loadmat
except ImportError:
from ...utils.dependency import create_dummy_class
BSDS500 = create_dummy_class('BSDS500', 'scipy.io') # noqa
if __name__ == '__main__':
a = BSDS500('val')
for k in a.get_data():
......
......@@ -9,12 +9,7 @@ import numpy as np
from ...utils import logger, get_dataset_path
from ..base import RNGDataFlow
try:
import scipy.io
__all__ = ['SVHNDigit']
except ImportError:
logger.warn_dependency('SVHNDigit', 'scipy.io')
__all__ = []
__all__ = ['SVHNDigit']
SVHN_URL = "http://ufldl.stanford.edu/housenumbers/"
......@@ -73,6 +68,12 @@ class SVHNDigit(RNGDataFlow):
return np.concatenate((a.X, b.X, c.X)).mean(axis=0)
try:
import scipy.io
except ImportError:
from ...utils.dependency import create_dummy_class
SVHNDigit = create_dummy_class('SVHNDigit', 'scipy.io') # noqa
if __name__ == '__main__':
a = SVHNDigit('train')
b = SVHNDigit.get_per_pixel_mean()
......@@ -15,13 +15,8 @@ from ..utils.concurrency import DIE
from ..utils.serialize import dumps
from ..utils.fs import mkdir_p
__all__ = ['dump_dataset_images', 'dataflow_to_process_queue']
try:
import lmdb
except ImportError:
logger.warn_dependency("dump_dataflow_to_lmdb", 'lmdb')
else:
__all__.extend(['dump_dataflow_to_lmdb'])
__all__ = ['dump_dataset_images', 'dataflow_to_process_queue',
'dump_dataflow_to_lmdb']
def dump_dataset_images(ds, dirname, max_count=None, index=0):
......@@ -84,6 +79,13 @@ def dump_dataflow_to_lmdb(ds, lmdb_path):
db.sync()
try:
import lmdb
except ImportError:
from ..utils.dependency import create_dummy_func
dump_dataflow_to_lmdb = create_dummy_func('dump_dataflow_to_lmdb', 'lmdb') # noqa
def dataflow_to_process_queue(ds, size, nr_consumer):
"""
Convert a DataFlow to a :class:`multiprocessing.Queue`.
......
......@@ -13,28 +13,8 @@ from ..utils.serialize import loads
from ..utils.argtools import log_once
from .base import RNGDataFlow
try:
import h5py
except ImportError:
logger.warn_dependency("HDF5Data", 'h5py')
__all__ = []
else:
__all__ = ['HDF5Data']
try:
import lmdb
except ImportError:
logger.warn_dependency("LMDBData", 'lmdb')
else:
__all__.extend(['LMDBData', 'LMDBDataDecoder', 'LMDBDataPoint', 'CaffeLMDB'])
try:
import sklearn.datasets
except ImportError:
logger.warn_dependency('SVMLightData', 'sklearn')
else:
__all__.extend(['SVMLightData'])
__all__ = ['HDF5Data', 'LMDBData', 'LMDBDataDecoder', 'LMDBDataPoint',
'CaffeLMDB', 'SVMLightData']
"""
Adapters for different data format.
......@@ -214,3 +194,21 @@ class SVMLightData(RNGDataFlow):
self.rng.shuffle(idxs)
for id in idxs:
yield [self.X[id, :], self.y[id]]
from ..utils.dependency import create_dummy_class # noqa
try:
import h5py
except ImportError:
HDF5Data = create_dummy_class('HDF5Data', 'h5py') # noqa
try:
import lmdb
except ImportError:
for klass in ['LMDBData', 'LMDBDataDecoder', 'LMDBDataPoint', 'CaffeLMDB']:
globals()[klass] = create_dummy_class(klass, 'lmdb')
try:
import sklearn.datasets
except ImportError:
SVMLightData = create_dummy_class('SVMLightData', 'sklearn') # noqa
......@@ -8,6 +8,7 @@ import itertools
from six.moves import range, zip
import uuid
import os
import zmq
from .base import ProxyDataFlow
from ..utils.concurrency import (ensure_proc_terminate,
......@@ -16,13 +17,7 @@ from ..utils.serialize import loads, dumps
from ..utils import logger
from ..utils.gpu import change_gpu
__all__ = ['PrefetchData', 'BlockParallel']
try:
import zmq
except ImportError:
logger.warn_dependency('PrefetchDataZMQ', 'zmq')
else:
__all__.extend(['PrefetchDataZMQ', 'PrefetchOnGPUs'])
__all__ = ['PrefetchData', 'BlockParallel', 'PrefetchDataZMQ', 'PrefetchOnGPUs']
class PrefetchProcess(mp.Process):
......
......@@ -3,22 +3,14 @@
# File: tf_func.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import tensorflow as tf
from .base import ProxyDataFlow
from ..utils import logger
try:
import tensorflow as tf
except ImportError:
logger.warn_dependency('TFFuncMapper', 'tensorflow')
__all__ = []
else:
__all__ = []
""" This file was deprecated """
class TFFuncMapper(ProxyDataFlow):
def __init__(self, ds,
get_placeholders, symbf, apply_symbf_on_dp, device='/cpu:0'):
"""
......
......@@ -9,20 +9,9 @@ from six.moves import queue, range
from ..utils.concurrency import DIE, StoppableThread
from ..tfutils.modelutils import describe_model
from ..utils import logger
from .base import OfflinePredictor, AsyncPredictorBase
try:
if six.PY2:
from tornado.concurrent import Future
else:
from concurrent.futures import Future
except ImportError:
logger.warn_dependency('MultiThreadAsyncPredictor', 'tornado.concurrent')
__all__ = ['MultiProcessPredictWorker', 'MultiProcessQueuePredictWorker']
else:
__all__ = ['MultiProcessPredictWorker', 'MultiProcessQueuePredictWorker',
__all__ = ['MultiProcessPredictWorker', 'MultiProcessQueuePredictWorker',
'MultiThreadAsyncPredictor']
......@@ -171,3 +160,13 @@ class MultiThreadAsyncPredictor(AsyncPredictorBase):
f.add_done_callback(callback)
self.input_queue.put((dp, f))
return f
try:
if six.PY2:
from tornado.concurrent import Future
else:
from concurrent.futures import Future
except ImportError:
from ..utils.dependency import create_dummy_class
MultiThreadAsyncPredictor = create_dummy_class('MultiThreadAsyncPredictor', 'tornado.concurrent') # noqa
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: dependency.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
""" Utilities to handle dependency """
__all__ = ['create_dummy_func', 'create_dummy_class']
def create_dummy_class(klass, dependency):
    """
    When a dependency of a class is not available, create a dummy class which throws ImportError when used.

    Args:
        klass (str): name of the class.
        dependency (str): name of the dependency.

    Returns:
        class: a class object
    """
    # Pre-build the error text; it only depends on the two captured strings.
    message = "Cannot import '{}', therefore '{}' is not available".format(dependency, klass)

    class _DummyClass(object):
        # Any attempt to instantiate the placeholder surfaces the missing dependency.
        def __init__(self, *args, **kwargs):
            raise ImportError(message)

    return _DummyClass
def create_dummy_func(func, dependency):
    """
    When a dependency of a function is not available, create a dummy function which throws ImportError when used.

    Args:
        func (str): name of the function.
        dependency (str): name of the dependency.

    Returns:
        function: a function object
    """
    # Pre-build the error text; it only depends on the two captured strings.
    message = "Cannot import '{}', therefore '{}' is not available".format(dependency, func)

    def _dummy_func(*args, **kwargs):
        # Calling the placeholder surfaces the missing dependency.
        raise ImportError(message)

    return _dummy_func
......@@ -11,8 +11,7 @@ from datetime import datetime
from six.moves import input
import sys
__all__ = ['set_logger_dir', 'disable_logger', 'auto_set_dir',
'warn_dependency']
__all__ = ['set_logger_dir', 'disable_logger', 'auto_set_dir']
class _MyFormatter(logging.Formatter):
......@@ -128,8 +127,3 @@ def auto_set_dir(action=None, overwrite=False):
os.path.join('train_log',
basename[:basename.rfind('.')]),
action=action)
def warn_dependency(name, dependencies):
    """Print a warning that `name` is unavailable because `dependencies` failed to import.

    Args:
        name (str): the feature/class that becomes unavailable.
        dependencies (str): the missing dependency (module) name.
    """
    # Fix: the original message had a stray trailing apostrophe after "available".
    # 'warn' is bound elsewhere in this module, hence the F821 suppression.
    warn("Failed to import '{}', {} won't be available".format(dependencies, name))  # noqa: F821
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment