Commit 7bdaf8ec authored by Yuxin Wu's avatar Yuxin Wu

docs cleanup

parent f17d16da
...@@ -399,11 +399,11 @@ _DEPRECATED_NAMES = set([ ...@@ -399,11 +399,11 @@ _DEPRECATED_NAMES = set([
'l2_regularizer', 'l1_regularizer', 'l2_regularizer', 'l1_regularizer',
# internal only # internal only
'execute_only_once',
'humanize_time_delta',
'SessionUpdate', 'SessionUpdate',
'average_grads', 'get_checkpoint_path',
'aggregate_grads', 'IterSpeedCounter'
'allreduce_grads',
'get_checkpoint_path'
]) ])
def autodoc_skip_member(app, what, name, obj, skip, options): def autodoc_skip_member(app, what, name, obj, skip, options):
......
...@@ -56,15 +56,6 @@ tensorpack.utils.serialize module ...@@ -56,15 +56,6 @@ tensorpack.utils.serialize module
:undoc-members: :undoc-members:
:show-inheritance: :show-inheritance:
tensorpack.utils.compatible_serialize module
--------------------------------------------
.. automodule:: tensorpack.utils.compatible_serialize
:members:
:undoc-members:
:show-inheritance:
tensorpack.utils.stats module tensorpack.utils.stats module
----------------------------- -----------------------------
......
...@@ -10,10 +10,9 @@ import bob.ap ...@@ -10,10 +10,9 @@ import bob.ap
import scipy.io.wavfile as wavfile import scipy.io.wavfile as wavfile
from tensorpack.dataflow import DataFlow, LMDBSerializer from tensorpack.dataflow import DataFlow, LMDBSerializer
from tensorpack.utils import fs, logger, serialize from tensorpack.utils import fs, logger, serialize, get_tqdm
from tensorpack.utils.argtools import memoized from tensorpack.utils.argtools import memoized
from tensorpack.utils.stats import OnlineMoments from tensorpack.utils.stats import OnlineMoments
from tensorpack.utils.utils import get_tqdm
CHARSET = set(string.ascii_lowercase + ' ') CHARSET = set(string.ascii_lowercase + ' ')
PHONEME_LIST = [ PHONEME_LIST = [
......
...@@ -13,9 +13,8 @@ from gym import spaces ...@@ -13,9 +13,8 @@ from gym import spaces
from gym.envs.atari.atari_env import ACTION_MEANING from gym.envs.atari.atari_env import ACTION_MEANING
from six.moves import range from six.moves import range
from tensorpack.utils import logger from tensorpack.utils import logger, execute_only_once, get_rng
from tensorpack.utils.fs import get_dataset_path from tensorpack.utils.fs import get_dataset_path
from tensorpack.utils.utils import execute_only_once, get_rng
__all__ = ['AtariPlayer'] __all__ = ['AtariPlayer']
......
...@@ -7,13 +7,11 @@ import numpy as np ...@@ -7,13 +7,11 @@ import numpy as np
import random import random
import time import time
from six.moves import queue from six.moves import queue
from tqdm import tqdm
from tensorpack.callbacks import Callback from tensorpack.callbacks import Callback
from tensorpack.utils import logger from tensorpack.utils import logger, get_tqdm
from tensorpack.utils.concurrency import ShareSessionThread, StoppableThread from tensorpack.utils.concurrency import ShareSessionThread, StoppableThread
from tensorpack.utils.stats import StatCounter from tensorpack.utils.stats import StatCounter
from tensorpack.utils.utils import get_tqdm_kwargs
def play_one_episode(env, func, render=False): def play_one_episode(env, func, render=False):
...@@ -87,7 +85,7 @@ def eval_with_funcs(predictors, nr_eval, get_player_fn, verbose=False): ...@@ -87,7 +85,7 @@ def eval_with_funcs(predictors, nr_eval, get_player_fn, verbose=False):
if verbose: if verbose:
logger.info("Score: {}".format(r)) logger.info("Score: {}".format(r))
for _ in tqdm(range(nr_eval), **get_tqdm_kwargs()): for _ in get_tqdm(range(nr_eval)):
fetch() fetch()
# waiting is necessary, otherwise the estimated mean score is biased # waiting is necessary, otherwise the estimated mean score is biased
logger.info("Waiting for all the workers to finish the last run...") logger.info("Waiting for all the workers to finish the last run...")
......
...@@ -12,9 +12,8 @@ from six.moves import queue, range ...@@ -12,9 +12,8 @@ from six.moves import queue, range
from tensorpack.utils.concurrency import LoopThread, ShareSessionThread from tensorpack.utils.concurrency import LoopThread, ShareSessionThread
from tensorpack.callbacks.base import Callback from tensorpack.callbacks.base import Callback
from tensorpack.dataflow import DataFlow from tensorpack.dataflow import DataFlow
from tensorpack.utils import logger from tensorpack.utils import logger, get_rng, get_tqdm
from tensorpack.utils.stats import StatCounter from tensorpack.utils.stats import StatCounter
from tensorpack.utils.utils import get_rng, get_tqdm
__all__ = ['ExpReplay'] __all__ = ['ExpReplay']
......
...@@ -17,8 +17,7 @@ from scipy import interpolate ...@@ -17,8 +17,7 @@ from scipy import interpolate
from tensorpack.callbacks import Callback from tensorpack.callbacks import Callback
from tensorpack.tfutils.common import get_tf_version_tuple from tensorpack.tfutils.common import get_tf_version_tuple
from tensorpack.utils import logger from tensorpack.utils import logger, get_tqdm
from tensorpack.utils.utils import get_tqdm
from common import CustomResize, clip_boxes from common import CustomResize, clip_boxes
from config import config as cfg from config import config as cfg
......
...@@ -47,6 +47,8 @@ if __name__ == '__main__': ...@@ -47,6 +47,8 @@ if __name__ == '__main__':
# Setup logging ... # Setup logging ...
is_horovod = cfg.TRAINER == 'horovod' is_horovod = cfg.TRAINER == 'horovod'
if is_horovod:
hvd.init()
if not is_horovod or hvd.rank() == 0: if not is_horovod or hvd.rank() == 0:
logger.set_logger_dir(args.logdir, 'd') logger.set_logger_dir(args.logdir, 'd')
logger.info("Environment Information:\n" + collect_env_info()) logger.info("Environment Information:\n" + collect_env_info())
......
...@@ -129,7 +129,7 @@ class GANTrainer(TowerTrainer): ...@@ -129,7 +129,7 @@ class GANTrainer(TowerTrainer):
self.tower_func = TowerFunc(get_cost, model.get_input_signature()) self.tower_func = TowerFunc(get_cost, model.get_input_signature())
devices = [LeastLoadedDeviceSetter(d, raw_devices) for d in raw_devices] devices = [LeastLoadedDeviceSetter(d, raw_devices) for d in raw_devices]
cost_list = DataParallelBuilder.build_on_towers( cost_list = DataParallelBuilder.call_for_each_tower(
list(range(num_gpu)), list(range(num_gpu)),
lambda: self.tower_func(*input.get_input_tensors()), lambda: self.tower_func(*input.get_input_tensors()),
devices) devices)
......
...@@ -11,7 +11,7 @@ from ..utils.argtools import memoized ...@@ -11,7 +11,7 @@ from ..utils.argtools import memoized
from .training import DataParallelBuilder, GraphBuilder from .training import DataParallelBuilder, GraphBuilder
from .utils import OverrideCachingDevice, aggregate_grads, override_to_local_variable from .utils import OverrideCachingDevice, aggregate_grads, override_to_local_variable
__all__ = ['DistributedParameterServerBuilder', 'DistributedReplicatedBuilder'] __all__ = []
class DistributedBuilderBase(GraphBuilder): class DistributedBuilderBase(GraphBuilder):
......
...@@ -15,13 +15,12 @@ from ..tfutils.common import get_tf_version_tuple ...@@ -15,13 +15,12 @@ from ..tfutils.common import get_tf_version_tuple
from ..tfutils.gradproc import ScaleGradient from ..tfutils.gradproc import ScaleGradient
from ..tfutils.tower import TrainTowerContext from ..tfutils.tower import TrainTowerContext
from ..utils import logger from ..utils import logger
from ..utils.develop import HIDE_DOC
from .utils import ( from .utils import (
GradientPacker, LeastLoadedDeviceSetter, aggregate_grads, allreduce_grads, allreduce_grads_hierarchical, GradientPacker, LeastLoadedDeviceSetter, aggregate_grads, allreduce_grads, allreduce_grads_hierarchical,
merge_grad_list, override_to_local_variable, split_grad_list) merge_grad_list, override_to_local_variable, split_grad_list)
__all__ = ['GraphBuilder', __all__ = ["DataParallelBuilder"]
'SyncMultiGPUParameterServerBuilder', 'DataParallelBuilder',
'SyncMultiGPUReplicatedBuilder', 'AsyncMultiGPUBuilder']
@six.add_metaclass(ABCMeta) @six.add_metaclass(ABCMeta)
...@@ -117,6 +116,7 @@ class DataParallelBuilder(GraphBuilder): ...@@ -117,6 +116,7 @@ class DataParallelBuilder(GraphBuilder):
ret.append(func()) ret.append(func())
return ret return ret
@HIDE_DOC
@staticmethod @staticmethod
def build_on_towers(*args, **kwargs): def build_on_towers(*args, **kwargs):
return DataParallelBuilder.call_for_each_tower(*args, **kwargs) return DataParallelBuilder.call_for_each_tower(*args, **kwargs)
......
...@@ -13,13 +13,7 @@ from ..tfutils.varreplace import custom_getter_scope ...@@ -13,13 +13,7 @@ from ..tfutils.varreplace import custom_getter_scope
from ..utils import logger from ..utils import logger
from ..utils.argtools import call_only_once from ..utils.argtools import call_only_once
__all__ = ['LeastLoadedDeviceSetter', __all__ = ["LeastLoadedDeviceSetter"]
'OverrideCachingDevice',
'override_to_local_variable',
'allreduce_grads',
'average_grads',
'aggregate_grads'
]
""" """
......
...@@ -10,9 +10,8 @@ from ..input_source import PlaceholderInput ...@@ -10,9 +10,8 @@ from ..input_source import PlaceholderInput
from ..tfutils.common import get_tensors_by_names from ..tfutils.common import get_tensors_by_names
from ..tfutils.tower import PredictTowerContext from ..tfutils.tower import PredictTowerContext
__all__ = ['PredictorBase', 'AsyncPredictorBase', __all__ = ['PredictorBase',
'OnlinePredictor', 'OfflinePredictor', 'OnlinePredictor', 'OfflinePredictor']
]
@six.add_metaclass(ABCMeta) @six.add_metaclass(ABCMeta)
...@@ -62,7 +61,7 @@ class AsyncPredictorBase(PredictorBase): ...@@ -62,7 +61,7 @@ class AsyncPredictorBase(PredictorBase):
dp (list): A datapoint as inputs. It could be either batched or not dp (list): A datapoint as inputs. It could be either batched or not
batched depending on the predictor implementation). batched depending on the predictor implementation).
callback: a thread-safe callback to get called with callback: a thread-safe callback to get called with
either outputs or (inputs, outputs). either outputs or (inputs, outputs), if `return_input` is True.
Returns: Returns:
concurrent.futures.Future: a Future of results concurrent.futures.Future: a Future of results
""" """
......
...@@ -14,8 +14,7 @@ from ..utils import logger ...@@ -14,8 +14,7 @@ from ..utils import logger
from ..utils.concurrency import DIE, ShareSessionThread, StoppableThread from ..utils.concurrency import DIE, ShareSessionThread, StoppableThread
from .base import AsyncPredictorBase, OfflinePredictor, OnlinePredictor from .base import AsyncPredictorBase, OfflinePredictor, OnlinePredictor
__all__ = ['MultiProcessPredictWorker', 'MultiProcessQueuePredictWorker', __all__ = ['MultiThreadAsyncPredictor']
'MultiThreadAsyncPredictor']
class MultiProcessPredictWorker(multiprocessing.Process): class MultiProcessPredictWorker(multiprocessing.Process):
...@@ -171,7 +170,13 @@ class MultiThreadAsyncPredictor(AsyncPredictorBase): ...@@ -171,7 +170,13 @@ class MultiThreadAsyncPredictor(AsyncPredictorBase):
def put_task(self, dp, callback=None): def put_task(self, dp, callback=None):
""" """
Same as in :meth:`AsyncPredictorBase.put_task`. Args:
dp (list): A datapoint as inputs. It could be either batched or not
batched depending on the predictor implementation).
callback: a thread-safe callback. When the results are ready, it will be called
with the "future" object.
Returns:
concurrent.futures.Future: a Future of results.
""" """
f = Future() f = Future()
if callback is not None: if callback is not None:
......
...@@ -11,6 +11,7 @@ from six.moves import range, zip ...@@ -11,6 +11,7 @@ from six.moves import range, zip
from ..dataflow import DataFlow from ..dataflow import DataFlow
from ..dataflow.remote import dump_dataflow_to_process_queue from ..dataflow.remote import dump_dataflow_to_process_queue
from ..utils import logger from ..utils import logger
from ..utils.develop import HIDE_DOC
from ..utils.concurrency import DIE, OrderedResultGatherProc, ensure_proc_terminate from ..utils.concurrency import DIE, OrderedResultGatherProc, ensure_proc_terminate
from ..utils.gpu import change_gpu, get_num_gpu from ..utils.gpu import change_gpu, get_num_gpu
from ..utils.utils import get_tqdm from ..utils.utils import get_tqdm
...@@ -63,6 +64,7 @@ class SimpleDatasetPredictor(DatasetPredictorBase): ...@@ -63,6 +64,7 @@ class SimpleDatasetPredictor(DatasetPredictorBase):
super(SimpleDatasetPredictor, self).__init__(config, dataset) super(SimpleDatasetPredictor, self).__init__(config, dataset)
self.predictor = OfflinePredictor(config) self.predictor = OfflinePredictor(config)
@HIDE_DOC
def get_result(self): def get_result(self):
self.dataset.reset_state() self.dataset.reset_state()
try: try:
...@@ -142,6 +144,7 @@ class MultiProcessDatasetPredictor(DatasetPredictorBase): ...@@ -142,6 +144,7 @@ class MultiProcessDatasetPredictor(DatasetPredictorBase):
self.result_queue = self.outqueue self.result_queue = self.outqueue
ensure_proc_terminate(self.workers + [self.inqueue_proc]) ensure_proc_terminate(self.workers + [self.inqueue_proc])
@HIDE_DOC
def get_result(self): def get_result(self):
try: try:
sz = len(self.dataset) sz = len(self.dataset)
......
...@@ -184,47 +184,47 @@ def save_chkpt_vars(dic, path): ...@@ -184,47 +184,47 @@ def save_chkpt_vars(dic, path):
saver.save(sess, path, write_meta_graph=False) saver.save(sess, path, write_meta_graph=False)
def get_checkpoint_path(model_path): def get_checkpoint_path(path):
""" """
Work around TF problems in checkpoint path handling. Work around TF problems in checkpoint path handling.
Args: Args:
model_path: a user-input path path: a user-input path
Returns: Returns:
str: the argument that can be passed to NewCheckpointReader str: the argument that can be passed to NewCheckpointReader
""" """
if os.path.basename(model_path) == model_path: if os.path.basename(path) == path:
model_path = os.path.join('.', model_path) # avoid #4921 and #6142 path = os.path.join('.', path) # avoid #4921 and #6142
if os.path.basename(model_path) == 'checkpoint': if os.path.basename(path) == 'checkpoint':
assert tfv1.gfile.Exists(model_path), model_path assert tfv1.gfile.Exists(path), path
model_path = tf.train.latest_checkpoint(os.path.dirname(model_path)) path = tf.train.latest_checkpoint(os.path.dirname(path))
# to be consistent with either v1 or v2 # to be consistent with either v1 or v2
# fix paths if provided a wrong one # fix paths if provided a wrong one
new_path = model_path new_path = path
if '00000-of-00001' in model_path: if '00000-of-00001' in path:
new_path = model_path.split('.data')[0] new_path = path.split('.data')[0]
elif model_path.endswith('.index'): elif path.endswith('.index'):
new_path = model_path.split('.index')[0] new_path = path.split('.index')[0]
if new_path != model_path: if new_path != path:
logger.info( logger.info(
"Checkpoint path {} is auto-corrected to {}.".format(model_path, new_path)) "Checkpoint path {} is auto-corrected to {}.".format(path, new_path))
model_path = new_path path = new_path
assert tfv1.gfile.Exists(model_path) or tfv1.gfile.Exists(model_path + '.index'), model_path assert tfv1.gfile.Exists(path) or tfv1.gfile.Exists(path + '.index'), path
return model_path return path
def load_chkpt_vars(model_path): def load_chkpt_vars(path):
""" Load all variables from a checkpoint to a dict. """ Load all variables from a checkpoint to a dict.
Args: Args:
model_path(str): path to a checkpoint. path(str): path to a checkpoint.
Returns: Returns:
dict: a name:value dict dict: a name:value dict
""" """
model_path = get_checkpoint_path(model_path) path = get_checkpoint_path(path)
reader = tfv1.train.NewCheckpointReader(model_path) reader = tfv1.train.NewCheckpointReader(path)
var_names = reader.get_variable_to_shape_map().keys() var_names = reader.get_variable_to_shape_map().keys()
result = {} result = {}
for n in var_names: for n in var_names:
......
...@@ -13,7 +13,7 @@ else: ...@@ -13,7 +13,7 @@ else:
import functools import functools
__all__ = ['map_arg', 'memoized', 'memoized_method', 'graph_memoized', 'shape2d', 'shape4d', __all__ = ['map_arg', 'memoized', 'memoized_method', 'graph_memoized', 'shape2d', 'shape4d',
'memoized_ignoreargs', 'log_once', 'call_only_once'] 'memoized_ignoreargs', 'log_once']
def map_arg(**maps): def map_arg(**maps):
......
...@@ -26,8 +26,7 @@ else: ...@@ -26,8 +26,7 @@ else:
__all__ = ['StoppableThread', 'LoopThread', 'ShareSessionThread', __all__ = ['StoppableThread', 'LoopThread', 'ShareSessionThread',
'ensure_proc_terminate', 'ensure_proc_terminate',
'OrderedResultGatherProc', 'OrderedContainer', 'DIE', 'start_proc_mask_signal']
'mask_sigint', 'start_proc_mask_signal']
class StoppableThread(threading.Thread): class StoppableThread(threading.Thread):
......
...@@ -97,6 +97,7 @@ def load_caffe(model_desc, model_file): ...@@ -97,6 +97,7 @@ def load_caffe(model_desc, model_file):
""" """
Load a caffe model. You must be able to ``import caffe`` to use this Load a caffe model. You must be able to ``import caffe`` to use this
function. function.
Args: Args:
model_desc (str): path to caffe model description file (.prototxt). model_desc (str): path to caffe model description file (.prototxt).
model_file (str): path to caffe model parameter file (.caffemodel). model_file (str): path to caffe model parameter file (.caffemodel).
...@@ -116,6 +117,7 @@ def load_caffe(model_desc, model_file): ...@@ -116,6 +117,7 @@ def load_caffe(model_desc, model_file):
def get_caffe_pb(): def get_caffe_pb():
""" """
Get caffe protobuf. Get caffe protobuf.
Returns: Returns:
The imported caffe protobuf module. The imported caffe protobuf module.
""" """
......
...@@ -13,7 +13,7 @@ from .develop import create_dummy_func ...@@ -13,7 +13,7 @@ from .develop import create_dummy_func
msgpack_numpy.patch() msgpack_numpy.patch()
assert msgpack.version >= (0, 5, 2) assert msgpack.version >= (0, 5, 2)
__all__ = ['loads', 'dumps', 'NonPicklableWrapper'] __all__ = ['loads', 'dumps']
MAX_MSGPACK_LEN = 1000000000 MAX_MSGPACK_LEN = 1000000000
......
...@@ -15,8 +15,7 @@ if six.PY3: ...@@ -15,8 +15,7 @@ if six.PY3:
from time import perf_counter as timer # noqa from time import perf_counter as timer # noqa
__all__ = ['total_timer', 'timed_operation', __all__ = ['timed_operation', 'IterSpeedCounter', 'Timer']
'print_total_timer', 'IterSpeedCounter', 'Timer']
@contextmanager @contextmanager
...@@ -55,7 +54,7 @@ _TOTAL_TIMER_DATA = defaultdict(StatCounter) ...@@ -55,7 +54,7 @@ _TOTAL_TIMER_DATA = defaultdict(StatCounter)
@contextmanager @contextmanager
def total_timer(msg): def total_timer(msg):
""" A context which add the time spent inside to TotalTimer. """ """ A context which add the time spent inside to the global TotalTimer. """
start = timer() start = timer()
yield yield
t = timer() - start t = timer() - start
...@@ -64,7 +63,7 @@ def total_timer(msg): ...@@ -64,7 +63,7 @@ def total_timer(msg):
def print_total_timer(): def print_total_timer():
""" """
Print the content of the TotalTimer, if it's not empty. This function will automatically get Print the content of the global TotalTimer, if it's not empty. This function will automatically get
called when program exits. called when program exits.
""" """
if len(_TOTAL_TIMER_DATA) == 0: if len(_TOTAL_TIMER_DATA) == 0:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment