Commit 540cdf7c authored by Yuxin Wu's avatar Yuxin Wu

misc import fix

parent 59dd1aa5
......@@ -11,7 +11,7 @@ import threading
import six
from six.moves import range
from tensorpack.utils import (get_rng, logger, get_dataset_path, execute_only_once)
from tensorpack.utils.stat import StatCounter
from tensorpack.utils.stats import StatCounter
from tensorpack.RL.envbase import RLEnvironment, DiscreteActionSpace
......
......@@ -11,8 +11,7 @@ from six.moves import queue
from tensorpack import *
from tensorpack.predict import get_predict_func
from tensorpack.utils.concurrency import *
from tensorpack.utils.stat import *
from tensorpack.callbacks import *
from tensorpack.utils.stats import *
global get_player
get_player = None
......
......@@ -17,8 +17,7 @@ from six.moves import queue
from tensorpack import *
from tensorpack.utils.concurrency import *
from tensorpack.utils.serialize import *
from tensorpack.utils.timer import *
from tensorpack.utils.stat import *
from tensorpack.utils.stats import *
from tensorpack.tfutils import symbolic_functions as symbf
from tensorpack.RL import *
......
......@@ -13,7 +13,7 @@ import multiprocessing
import tensorflow as tf
from tensorflow.contrib.layers import variance_scaling_initializer
from tensorpack import *
from tensorpack.utils.stat import RatioCounter
from tensorpack.utils.stats import RatioCounter
from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import *
......
......@@ -15,7 +15,7 @@ from tensorflow.contrib.layers import variance_scaling_initializer
from tensorpack import *
from tensorpack.utils import logger
from tensorpack.utils.stat import RatioCounter
from tensorpack.utils.stats import RatioCounter
from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import *
from tensorpack.dataflow.dataset import ILSVRCMeta
......
......@@ -6,12 +6,11 @@
import numpy as np
from collections import deque, namedtuple
import threading
from tqdm import tqdm
import six
from six.moves import queue
from ..dataflow import DataFlow
from ..utils import *
from ..utils import logger, get_tqdm
from ..utils.concurrency import LoopThread
from ..callbacks.base import Callback
......@@ -71,7 +70,7 @@ class ExpReplay(DataFlow, Callback):
self._populate_exp()
self.exploration = old_exploration
with tqdm(total=self.init_memory_size) as pbar:
with get_tqdm(total=self.init_memory_size) as pbar:
while len(self.mem) < self.init_memory_size:
self._populate_exp()
pbar.update()
......
......@@ -19,8 +19,8 @@ except ImportError:
import threading
from ..utils.fs import *
from ..utils.stat import *
from ..utils.fs import mkdir_p
from ..utils.stats import StatCounter
from .envbase import RLEnvironment, DiscreteActionSpace
......@@ -77,7 +77,7 @@ if __name__ == '__main__':
env = GymEnv('Breakout-v0', viz=0.1)
num = env.get_action_space().num_actions()
from ..utils import *
from ..utils import get_rng
rng = get_rng(num)
while True:
act = rng.choice(range(num))
......
......@@ -19,9 +19,9 @@ from ..callbacks import Callback
from ..tfutils.varmanip import SessionUpdate
from ..predict import OfflinePredictor
from ..utils import logger
from ..utils.timer import *
from ..utils.serialize import *
from ..utils.concurrency import *
#from ..utils.timer import *
from ..utils.serialize import loads, dumps
from ..utils.concurrency import LoopThread, ensure_proc_terminate
__all__ = ['SimulatorProcess', 'SimulatorMaster',
'SimulatorProcessStateExchange', 'SimulatorProcessSharedWeight',
......
......@@ -7,7 +7,7 @@ from contextlib import contextmanager
import time
from .base import Callback
from .stat import StatPrinter
from .stats import StatPrinter
from ..utils import logger
__all__ = ['Callbacks']
......
......@@ -11,7 +11,7 @@ from six.moves import zip, map
from ..dataflow import DataFlow
from ..utils import get_tqdm, logger, execute_only_once
from ..utils.stat import RatioCounter, BinaryStatistics
from ..utils.stats import RatioCounter, BinaryStatistics
from ..tfutils import get_op_tensor_name, get_op_var_name
from .base import Callback
from .dispatcher import OutputTensorDispatcer
......
......@@ -12,7 +12,8 @@ import uuid
import os
from .base import ProxyDataFlow
from ..utils.concurrency import *
from ..utils.concurrency import (ensure_proc_terminate,
mask_sigint, start_proc_mask_signal)
from ..utils.serialize import loads, dumps
from ..utils import logger
from ..utils.gpu import change_gpu
......@@ -82,6 +83,7 @@ class PrefetchData(ProxyDataFlow):
pass
def BlockParallel(ds, queue_size):
# TODO more doc
"""
Insert `BlockParallel` in dataflow pipeline to block parallelism on ds
......@@ -170,7 +172,8 @@ class PrefetchDataZMQ(ProxyDataFlow):
pass
class PrefetchOnGPUs(PrefetchDataZMQ):
""" Prefetch with each process having a specific CUDA_VISIBLE_DEVICES"""
""" Prefetch with each process having a specific CUDA_VISIBLE_DEVICES
variable"""
def __init__(self, ds, gpus, pipedir=None):
self.gpus = gpus
super(PrefetchOnGPUs, self).__init__(ds, len(gpus), pipedir)
......
......@@ -13,7 +13,7 @@ from ..utils.concurrency import DIE
from ..tfutils.modelutils import describe_model
from ..utils import logger
from .base import *
from .base import OfflinePredictor, AsyncPredictorBase
try:
if six.PY2:
......
......@@ -10,7 +10,7 @@ from collections import defaultdict
import six
import atexit
from .stat import StatCounter
from .stats import StatCounter
from . import logger
__all__ = ['total_timer', 'timed_operation',
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment