Commit eecb5803 authored by Yuxin Wu

hide utils.utils.* from automatic import

parent 94a01039
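
With this change, generic helpers defined in tensorpack/utils/utils.py are no longer re-exported by the tensorpack.utils package; call sites must import them from the submodule explicitly, while names such as logger stay at the package level. A minimal sketch of the resulting import pattern, using only the names touched in the diff below (illustrative, not part of the commit):

    # Old style, no longer auto-imported (get_rng keeps a warning-only shim,
    # see the utils/__init__.py hunk below):
    #     from tensorpack.utils import logger, get_rng, get_tqdm, execute_only_once
    # New style enforced by this commit:
    from tensorpack.utils import logger                       # still package-level
    from tensorpack.utils.utils import get_rng, get_tqdm, execute_only_once

    rng = get_rng()                      # np.random.RandomState seeded per object/process
    if execute_only_once():
        logger.info("printed only once per call site")
    with get_tqdm(total=10) as pbar:     # tqdm bar with tensorpack's default options
        for _ in range(10):
            pbar.update()
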
@@ -12,6 +12,7 @@ import argparse
 from tensorpack import *
 from tensorpack.utils.argtools import memoized
 from tensorpack.utils.stats import OnlineMoments
+from tensorpack.utils.utils import get_tqdm
 import bob.ap
 CHARSET = set(string.ascii_lowercase + ' ')
......
@@ -35,7 +35,6 @@ class CharRNNData(RNGDataFlow):
     def __init__(self, input_file, size):
         self.seq_length = param.seq_len
         self._size = size
-        self.rng = get_rng(self)
         logger.info("Loading corpus...")
         # preprocess data
......
@@ -11,7 +11,8 @@ from collections import deque
 import threading
 import six
 from six.moves import range
-from tensorpack.utils import (get_rng, logger, execute_only_once)
+from tensorpack.utils import logger
+from tensorpack.utils.utils import get_rng, execute_only_once
 from tensorpack.utils.fs import get_dataset_path
 from tensorpack.utils.stats import StatCounter
......
@@ -11,7 +11,8 @@ import six
 from six.moves import queue, range
 from tensorpack.dataflow import DataFlow
-from tensorpack.utils import logger, get_tqdm, get_rng
+from tensorpack.utils import logger
+from tensorpack.utils.utils import get_tqdm, get_rng
 from tensorpack.utils.concurrency import LoopThread, ShareSessionThread
 from tensorpack.callbacks.base import Callback
......
@@ -3,17 +3,16 @@
 # File: disturb.py
 # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
-from tensorpack import ProxyDataFlow, get_rng
+from tensorpack.dataflow import ProxyDataFlow, RNGDataFlow
-class DisturbLabel(ProxyDataFlow):
+class DisturbLabel(ProxyDataFlow, RNGDataFlow):
     def __init__(self, ds, prob):
         super(DisturbLabel, self).__init__(ds)
         self.prob = prob
     def reset_state(self):
         super(DisturbLabel, self).reset_state()
-        self.rng = get_rng(self)
     def get_data(self):
         for dp in self.ds.get_data():
......
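
The DisturbLabel hunk above shows the direction the examples are taking: rather than importing get_rng directly, a dataflow can inherit RNGDataFlow, whose reset_state() provides self.rng. A minimal sketch of a dataflow written in that style (the RandomNoiseData class below is hypothetical, for illustration only):

    from tensorpack.dataflow import RNGDataFlow

    class RandomNoiseData(RNGDataFlow):
        """Yield `size` random vectors; self.rng is set by RNGDataFlow.reset_state()."""
        def __init__(self, size):
            self._size = size

        def size(self):
            return self._size

        def get_data(self):
            for _ in range(self._size):
                yield [self.rng.uniform(size=3)]

    df = RandomNoiseData(5)
    df.reset_state()            # sets df.rng; must be called before get_data()
    for dp in df.get_data():
        print(dp)
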
@@ -7,7 +7,7 @@
 from abc import abstractmethod, ABCMeta
 from collections import defaultdict
 import six
-from ..utils import get_rng
+from ..utils.utils import get_rng
 __all__ = ['RLEnvironment', 'ProxyPlayer',
            'DiscreteActionSpace']
......
@@ -90,7 +90,7 @@ if __name__ == '__main__':
     env = GymEnv('FlappyBird-v0', viz=0.1)
     num = env.get_action_space().num_actions()
-    from ..utils import get_rng
+    from ..utils.utils import get_rng
     rng = get_rng(num)
     while True:
         act = rng.choice(range(num))
......
@@ -6,7 +6,7 @@
 from abc import abstractmethod, ABCMeta
 import six
-from ..utils import get_rng
+from ..utils.utils import get_rng
 __all__ = ['DataFlow', 'ProxyDataFlow', 'RNGDataFlow', 'DataFlowTerminated']
......
@@ -10,8 +10,8 @@ from termcolor import colored
 from collections import deque, defaultdict
 from six.moves import range, map
 from .base import DataFlow, ProxyDataFlow, RNGDataFlow
-from ..utils import logger, get_rng
-from ..utils.utils import get_tqdm
+from ..utils import logger
+from ..utils.utils import get_tqdm, get_rng
 from ..utils.develop import log_deprecated
 __all__ = ['TestDataSpeed', 'PrintData', 'BatchData', 'BatchDataByShape', 'FixedSizeData', 'MapData',
......
@@ -3,7 +3,7 @@
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>
 from abc import abstractmethod, ABCMeta
-from ...utils import get_rng
+from ...utils.utils import get_rng
 import six
 from six.moves import zip
......
@@ -6,7 +6,8 @@
 import time
 from collections import deque
 from .base import DataFlow
-from ..utils import logger, get_tqdm
+from ..utils import logger
+from ..utils.utils import get_tqdm
 from ..utils.serialize import dumps, loads, dumps_for_tfop
 try:
     import zmq
......
@@ -12,7 +12,8 @@ import six
 from ..dataflow import DataFlow
 from ..dataflow.dftools import dump_dataflow_to_process_queue
 from ..utils.concurrency import ensure_proc_terminate, OrderedResultGatherProc, DIE
-from ..utils import logger, get_tqdm
+from ..utils import logger
+from ..utils.utils import get_tqdm
 from ..utils.gpu import change_gpu
 from .concurrency import MultiProcessQueuePredictWorker
......
@@ -13,26 +13,13 @@ These utils should be irrelevant to tensorflow.
 __all__ = []
-def _global_import(name):
-    p = __import__(name, globals(), None, level=1)
-    lst = p.__all__ if '__all__' in dir(p) else dir(p)
-    for k in lst:
-        globals()[k] = p.__dict__[k]
-        __all__.append(k)
-_TO_IMPORT = set([
-    'utils',
-])
 # this two functions for back-compat only
 def get_nr_gpu():
-    from .gpu import get_nr_gpu
+    from .gpu import get_nr_gpu as gg
     logger.warn( # noqa
         "get_nr_gpu will not be automatically imported any more! "
         "Please do `from tensorpack.utils.gpu import get_nr_gpu`")
-    return get_nr_gpu()
+    return gg()
 def change_gpu(val):
@@ -43,6 +30,14 @@ def change_gpu(val):
     return cg(val)
+def get_rng(obj=None):
+    from .utils import get_rng as gr
+    logger.warn( # noqa
+        "get_rng will not be automatically imported any more! "
+        "Please do `from tensorpack.utils.utils import get_rng`")
+    return gr(obj)
 _CURR_DIR = os.path.dirname(__file__)
 for _, module_name, _ in iter_modules(
         [_CURR_DIR]):
@@ -51,8 +46,6 @@ for _, module_name, _ in iter_modules(
         continue
     if module_name.startswith('_'):
         continue
-    if module_name in _TO_IMPORT:
-        _global_import(module_name)
 __all__.extend([
     'logger',
-    'get_nr_gpu', 'change_gpu'])
+    'get_nr_gpu', 'change_gpu', 'get_rng'])
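
The get_rng shim added to tensorpack/utils/__init__.py above keeps old call sites alive while pointing users at the new location: the package-level name still returns a RandomState but logs the deprecation warning, and wildcard imports of the package no longer pull in utils.utils.* at all. A rough usage sketch of this transitional behaviour (not part of the commit):

    from tensorpack import utils                 # package with the back-compat shims
    from tensorpack.utils.utils import get_rng   # new, recommended import

    rng_old = utils.get_rng()    # still works, but warns:
                                 #   "get_rng will not be automatically imported any more! ..."
    rng_new = get_rng()          # preferred path, no warning
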
@@ -14,7 +14,6 @@ import numpy as np
 __all__ = ['change_env',
            'get_rng',
            'fix_rng_seed',
-           # 'get_tqdm_kwargs',
            'get_tqdm',
            'execute_only_once',
            ]
......