Commit 540cdf7c authored by Yuxin Wu's avatar Yuxin Wu

misc import fix

parent 59dd1aa5
...@@ -11,7 +11,7 @@ import threading ...@@ -11,7 +11,7 @@ import threading
import six import six
from six.moves import range from six.moves import range
from tensorpack.utils import (get_rng, logger, get_dataset_path, execute_only_once) from tensorpack.utils import (get_rng, logger, get_dataset_path, execute_only_once)
from tensorpack.utils.stat import StatCounter from tensorpack.utils.stats import StatCounter
from tensorpack.RL.envbase import RLEnvironment, DiscreteActionSpace from tensorpack.RL.envbase import RLEnvironment, DiscreteActionSpace
......
...@@ -11,8 +11,7 @@ from six.moves import queue ...@@ -11,8 +11,7 @@ from six.moves import queue
from tensorpack import * from tensorpack import *
from tensorpack.predict import get_predict_func from tensorpack.predict import get_predict_func
from tensorpack.utils.concurrency import * from tensorpack.utils.concurrency import *
from tensorpack.utils.stat import * from tensorpack.utils.stats import *
from tensorpack.callbacks import *
global get_player global get_player
get_player = None get_player = None
......
...@@ -17,8 +17,7 @@ from six.moves import queue ...@@ -17,8 +17,7 @@ from six.moves import queue
from tensorpack import * from tensorpack import *
from tensorpack.utils.concurrency import * from tensorpack.utils.concurrency import *
from tensorpack.utils.serialize import * from tensorpack.utils.serialize import *
from tensorpack.utils.timer import * from tensorpack.utils.stats import *
from tensorpack.utils.stat import *
from tensorpack.tfutils import symbolic_functions as symbf from tensorpack.tfutils import symbolic_functions as symbf
from tensorpack.RL import * from tensorpack.RL import *
......
...@@ -13,7 +13,7 @@ import multiprocessing ...@@ -13,7 +13,7 @@ import multiprocessing
import tensorflow as tf import tensorflow as tf
from tensorflow.contrib.layers import variance_scaling_initializer from tensorflow.contrib.layers import variance_scaling_initializer
from tensorpack import * from tensorpack import *
from tensorpack.utils.stat import RatioCounter from tensorpack.utils.stats import RatioCounter
from tensorpack.tfutils.symbolic_functions import * from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import * from tensorpack.tfutils.summary import *
......
...@@ -15,7 +15,7 @@ from tensorflow.contrib.layers import variance_scaling_initializer ...@@ -15,7 +15,7 @@ from tensorflow.contrib.layers import variance_scaling_initializer
from tensorpack import * from tensorpack import *
from tensorpack.utils import logger from tensorpack.utils import logger
from tensorpack.utils.stat import RatioCounter from tensorpack.utils.stats import RatioCounter
from tensorpack.tfutils.symbolic_functions import * from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import * from tensorpack.tfutils.summary import *
from tensorpack.dataflow.dataset import ILSVRCMeta from tensorpack.dataflow.dataset import ILSVRCMeta
......
...@@ -6,12 +6,11 @@ ...@@ -6,12 +6,11 @@
import numpy as np import numpy as np
from collections import deque, namedtuple from collections import deque, namedtuple
import threading import threading
from tqdm import tqdm
import six import six
from six.moves import queue from six.moves import queue
from ..dataflow import DataFlow from ..dataflow import DataFlow
from ..utils import * from ..utils import logger, get_tqdm
from ..utils.concurrency import LoopThread from ..utils.concurrency import LoopThread
from ..callbacks.base import Callback from ..callbacks.base import Callback
...@@ -71,7 +70,7 @@ class ExpReplay(DataFlow, Callback): ...@@ -71,7 +70,7 @@ class ExpReplay(DataFlow, Callback):
self._populate_exp() self._populate_exp()
self.exploration = old_exploration self.exploration = old_exploration
with tqdm(total=self.init_memory_size) as pbar: with get_tqdm(total=self.init_memory_size) as pbar:
while len(self.mem) < self.init_memory_size: while len(self.mem) < self.init_memory_size:
self._populate_exp() self._populate_exp()
pbar.update() pbar.update()
......
...@@ -19,8 +19,8 @@ except ImportError: ...@@ -19,8 +19,8 @@ except ImportError:
import threading import threading
from ..utils.fs import * from ..utils.fs import mkdir_p
from ..utils.stat import * from ..utils.stats import StatCounter
from .envbase import RLEnvironment, DiscreteActionSpace from .envbase import RLEnvironment, DiscreteActionSpace
...@@ -77,7 +77,7 @@ if __name__ == '__main__': ...@@ -77,7 +77,7 @@ if __name__ == '__main__':
env = GymEnv('Breakout-v0', viz=0.1) env = GymEnv('Breakout-v0', viz=0.1)
num = env.get_action_space().num_actions() num = env.get_action_space().num_actions()
from ..utils import * from ..utils import get_rng
rng = get_rng(num) rng = get_rng(num)
while True: while True:
act = rng.choice(range(num)) act = rng.choice(range(num))
......
...@@ -19,9 +19,9 @@ from ..callbacks import Callback ...@@ -19,9 +19,9 @@ from ..callbacks import Callback
from ..tfutils.varmanip import SessionUpdate from ..tfutils.varmanip import SessionUpdate
from ..predict import OfflinePredictor from ..predict import OfflinePredictor
from ..utils import logger from ..utils import logger
from ..utils.timer import * #from ..utils.timer import *
from ..utils.serialize import * from ..utils.serialize import loads, dumps
from ..utils.concurrency import * from ..utils.concurrency import LoopThread, ensure_proc_terminate
__all__ = ['SimulatorProcess', 'SimulatorMaster', __all__ = ['SimulatorProcess', 'SimulatorMaster',
'SimulatorProcessStateExchange', 'SimulatorProcessSharedWeight', 'SimulatorProcessStateExchange', 'SimulatorProcessSharedWeight',
......
...@@ -7,7 +7,7 @@ from contextlib import contextmanager ...@@ -7,7 +7,7 @@ from contextlib import contextmanager
import time import time
from .base import Callback from .base import Callback
from .stat import StatPrinter from .stats import StatPrinter
from ..utils import logger from ..utils import logger
__all__ = ['Callbacks'] __all__ = ['Callbacks']
......
...@@ -11,7 +11,7 @@ from six.moves import zip, map ...@@ -11,7 +11,7 @@ from six.moves import zip, map
from ..dataflow import DataFlow from ..dataflow import DataFlow
from ..utils import get_tqdm, logger, execute_only_once from ..utils import get_tqdm, logger, execute_only_once
from ..utils.stat import RatioCounter, BinaryStatistics from ..utils.stats import RatioCounter, BinaryStatistics
from ..tfutils import get_op_tensor_name, get_op_var_name from ..tfutils import get_op_tensor_name, get_op_var_name
from .base import Callback from .base import Callback
from .dispatcher import OutputTensorDispatcer from .dispatcher import OutputTensorDispatcer
......
...@@ -12,7 +12,8 @@ import uuid ...@@ -12,7 +12,8 @@ import uuid
import os import os
from .base import ProxyDataFlow from .base import ProxyDataFlow
from ..utils.concurrency import * from ..utils.concurrency import (ensure_proc_terminate,
mask_sigint, start_proc_mask_signal)
from ..utils.serialize import loads, dumps from ..utils.serialize import loads, dumps
from ..utils import logger from ..utils import logger
from ..utils.gpu import change_gpu from ..utils.gpu import change_gpu
...@@ -82,6 +83,7 @@ class PrefetchData(ProxyDataFlow): ...@@ -82,6 +83,7 @@ class PrefetchData(ProxyDataFlow):
pass pass
def BlockParallel(ds, queue_size): def BlockParallel(ds, queue_size):
# TODO more doc
""" """
Insert `BlockParallel` in dataflow pipeline to block parallelism on ds Insert `BlockParallel` in dataflow pipeline to block parallelism on ds
...@@ -170,7 +172,8 @@ class PrefetchDataZMQ(ProxyDataFlow): ...@@ -170,7 +172,8 @@ class PrefetchDataZMQ(ProxyDataFlow):
pass pass
class PrefetchOnGPUs(PrefetchDataZMQ): class PrefetchOnGPUs(PrefetchDataZMQ):
""" Prefetch with each process having a specific CUDA_VISIBLE_DEVICES""" """ Prefetch with each process having a specific CUDA_VISIBLE_DEVICES
variable"""
def __init__(self, ds, gpus, pipedir=None): def __init__(self, ds, gpus, pipedir=None):
self.gpus = gpus self.gpus = gpus
super(PrefetchOnGPUs, self).__init__(ds, len(gpus), pipedir) super(PrefetchOnGPUs, self).__init__(ds, len(gpus), pipedir)
......
...@@ -13,7 +13,7 @@ from ..utils.concurrency import DIE ...@@ -13,7 +13,7 @@ from ..utils.concurrency import DIE
from ..tfutils.modelutils import describe_model from ..tfutils.modelutils import describe_model
from ..utils import logger from ..utils import logger
from .base import * from .base import OfflinePredictor, AsyncPredictorBase
try: try:
if six.PY2: if six.PY2:
......
...@@ -10,7 +10,7 @@ from collections import defaultdict ...@@ -10,7 +10,7 @@ from collections import defaultdict
import six import six
import atexit import atexit
from .stat import StatCounter from .stats import StatCounter
from . import logger from . import logger
__all__ = ['total_timer', 'timed_operation', __all__ = ['total_timer', 'timed_operation',
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment