Commit eecb5803 authored by Yuxin Wu's avatar Yuxin Wu

hide utils.utils.* from automatic import

parent 94a01039
...@@ -12,6 +12,7 @@ import argparse ...@@ -12,6 +12,7 @@ import argparse
from tensorpack import * from tensorpack import *
from tensorpack.utils.argtools import memoized from tensorpack.utils.argtools import memoized
from tensorpack.utils.stats import OnlineMoments from tensorpack.utils.stats import OnlineMoments
from tensorpack.utils.utils import get_tqdm
import bob.ap import bob.ap
CHARSET = set(string.ascii_lowercase + ' ') CHARSET = set(string.ascii_lowercase + ' ')
......
...@@ -35,7 +35,6 @@ class CharRNNData(RNGDataFlow): ...@@ -35,7 +35,6 @@ class CharRNNData(RNGDataFlow):
def __init__(self, input_file, size): def __init__(self, input_file, size):
self.seq_length = param.seq_len self.seq_length = param.seq_len
self._size = size self._size = size
self.rng = get_rng(self)
logger.info("Loading corpus...") logger.info("Loading corpus...")
# preprocess data # preprocess data
......
...@@ -11,7 +11,8 @@ from collections import deque ...@@ -11,7 +11,8 @@ from collections import deque
import threading import threading
import six import six
from six.moves import range from six.moves import range
from tensorpack.utils import (get_rng, logger, execute_only_once) from tensorpack.utils import logger
from tensorpack.utils.utils import get_rng, execute_only_once
from tensorpack.utils.fs import get_dataset_path from tensorpack.utils.fs import get_dataset_path
from tensorpack.utils.stats import StatCounter from tensorpack.utils.stats import StatCounter
......
...@@ -11,7 +11,8 @@ import six ...@@ -11,7 +11,8 @@ import six
from six.moves import queue, range from six.moves import queue, range
from tensorpack.dataflow import DataFlow from tensorpack.dataflow import DataFlow
from tensorpack.utils import logger, get_tqdm, get_rng from tensorpack.utils import logger
from tensorpack.utils.utils import get_tqdm, get_rng
from tensorpack.utils.concurrency import LoopThread, ShareSessionThread from tensorpack.utils.concurrency import LoopThread, ShareSessionThread
from tensorpack.callbacks.base import Callback from tensorpack.callbacks.base import Callback
......
...@@ -3,17 +3,16 @@ ...@@ -3,17 +3,16 @@
# File: disturb.py # File: disturb.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from tensorpack import ProxyDataFlow, get_rng from tensorpack.dataflow import ProxyDataFlow, RNGDataFlow
class DisturbLabel(ProxyDataFlow): class DisturbLabel(ProxyDataFlow, RNGDataFlow):
def __init__(self, ds, prob): def __init__(self, ds, prob):
super(DisturbLabel, self).__init__(ds) super(DisturbLabel, self).__init__(ds)
self.prob = prob self.prob = prob
def reset_state(self): def reset_state(self):
super(DisturbLabel, self).reset_state() super(DisturbLabel, self).reset_state()
self.rng = get_rng(self)
def get_data(self): def get_data(self):
for dp in self.ds.get_data(): for dp in self.ds.get_data():
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
from abc import abstractmethod, ABCMeta from abc import abstractmethod, ABCMeta
from collections import defaultdict from collections import defaultdict
import six import six
from ..utils import get_rng from ..utils.utils import get_rng
__all__ = ['RLEnvironment', 'ProxyPlayer', __all__ = ['RLEnvironment', 'ProxyPlayer',
'DiscreteActionSpace'] 'DiscreteActionSpace']
......
...@@ -90,7 +90,7 @@ if __name__ == '__main__': ...@@ -90,7 +90,7 @@ if __name__ == '__main__':
env = GymEnv('FlappyBird-v0', viz=0.1) env = GymEnv('FlappyBird-v0', viz=0.1)
num = env.get_action_space().num_actions() num = env.get_action_space().num_actions()
from ..utils import get_rng from ..utils.utils import get_rng
rng = get_rng(num) rng = get_rng(num)
while True: while True:
act = rng.choice(range(num)) act = rng.choice(range(num))
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
from abc import abstractmethod, ABCMeta from abc import abstractmethod, ABCMeta
import six import six
from ..utils import get_rng from ..utils.utils import get_rng
__all__ = ['DataFlow', 'ProxyDataFlow', 'RNGDataFlow', 'DataFlowTerminated'] __all__ = ['DataFlow', 'ProxyDataFlow', 'RNGDataFlow', 'DataFlowTerminated']
......
...@@ -10,8 +10,8 @@ from termcolor import colored ...@@ -10,8 +10,8 @@ from termcolor import colored
from collections import deque, defaultdict from collections import deque, defaultdict
from six.moves import range, map from six.moves import range, map
from .base import DataFlow, ProxyDataFlow, RNGDataFlow from .base import DataFlow, ProxyDataFlow, RNGDataFlow
from ..utils import logger, get_rng from ..utils import logger
from ..utils.utils import get_tqdm from ..utils.utils import get_tqdm, get_rng
from ..utils.develop import log_deprecated from ..utils.develop import log_deprecated
__all__ = ['TestDataSpeed', 'PrintData', 'BatchData', 'BatchDataByShape', 'FixedSizeData', 'MapData', __all__ = ['TestDataSpeed', 'PrintData', 'BatchData', 'BatchDataByShape', 'FixedSizeData', 'MapData',
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
from abc import abstractmethod, ABCMeta from abc import abstractmethod, ABCMeta
from ...utils import get_rng from ...utils.utils import get_rng
import six import six
from six.moves import zip from six.moves import zip
......
...@@ -6,7 +6,8 @@ ...@@ -6,7 +6,8 @@
import time import time
from collections import deque from collections import deque
from .base import DataFlow from .base import DataFlow
from ..utils import logger, get_tqdm from ..utils import logger
from ..utils.utils import get_tqdm
from ..utils.serialize import dumps, loads, dumps_for_tfop from ..utils.serialize import dumps, loads, dumps_for_tfop
try: try:
import zmq import zmq
......
...@@ -12,7 +12,8 @@ import six ...@@ -12,7 +12,8 @@ import six
from ..dataflow import DataFlow from ..dataflow import DataFlow
from ..dataflow.dftools import dump_dataflow_to_process_queue from ..dataflow.dftools import dump_dataflow_to_process_queue
from ..utils.concurrency import ensure_proc_terminate, OrderedResultGatherProc, DIE from ..utils.concurrency import ensure_proc_terminate, OrderedResultGatherProc, DIE
from ..utils import logger, get_tqdm from ..utils import logger
from ..utils.utils import get_tqdm
from ..utils.gpu import change_gpu from ..utils.gpu import change_gpu
from .concurrency import MultiProcessQueuePredictWorker from .concurrency import MultiProcessQueuePredictWorker
......
...@@ -13,26 +13,13 @@ These utils should be irrelevant to tensorflow. ...@@ -13,26 +13,13 @@ These utils should be irrelevant to tensorflow.
__all__ = [] __all__ = []
def _global_import(name):
p = __import__(name, globals(), None, level=1)
lst = p.__all__ if '__all__' in dir(p) else dir(p)
for k in lst:
globals()[k] = p.__dict__[k]
__all__.append(k)
_TO_IMPORT = set([
'utils',
])
# this two functions for back-compat only # this two functions for back-compat only
def get_nr_gpu(): def get_nr_gpu():
from .gpu import get_nr_gpu from .gpu import get_nr_gpu as gg
logger.warn( # noqa logger.warn( # noqa
"get_nr_gpu will not be automatically imported any more! " "get_nr_gpu will not be automatically imported any more! "
"Please do `from tensorpack.utils.gpu import get_nr_gpu`") "Please do `from tensorpack.utils.gpu import get_nr_gpu`")
return get_nr_gpu() return gg()
def change_gpu(val): def change_gpu(val):
...@@ -43,6 +30,14 @@ def change_gpu(val): ...@@ -43,6 +30,14 @@ def change_gpu(val):
return cg(val) return cg(val)
def get_rng(obj=None):
from .utils import get_rng as gr
logger.warn( # noqa
"get_rng will not be automatically imported any more! "
"Please do `from tensorpack.utils.utils import get_rng`")
return gr(obj)
_CURR_DIR = os.path.dirname(__file__) _CURR_DIR = os.path.dirname(__file__)
for _, module_name, _ in iter_modules( for _, module_name, _ in iter_modules(
[_CURR_DIR]): [_CURR_DIR]):
...@@ -51,8 +46,6 @@ for _, module_name, _ in iter_modules( ...@@ -51,8 +46,6 @@ for _, module_name, _ in iter_modules(
continue continue
if module_name.startswith('_'): if module_name.startswith('_'):
continue continue
if module_name in _TO_IMPORT:
_global_import(module_name)
__all__.extend([ __all__.extend([
'logger', 'logger',
'get_nr_gpu', 'change_gpu']) 'get_nr_gpu', 'change_gpu', 'get_rng'])
...@@ -14,7 +14,6 @@ import numpy as np ...@@ -14,7 +14,6 @@ import numpy as np
__all__ = ['change_env', __all__ = ['change_env',
'get_rng', 'get_rng',
'fix_rng_seed', 'fix_rng_seed',
# 'get_tqdm_kwargs',
'get_tqdm', 'get_tqdm',
'execute_only_once', 'execute_only_once',
] ]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment