Commit 7bdaf8ec authored by Yuxin Wu's avatar Yuxin Wu

docs cleanup

parent f17d16da
......@@ -399,11 +399,11 @@ _DEPRECATED_NAMES = set([
'l2_regularizer', 'l1_regularizer',
# internal only
'execute_only_once',
'humanize_time_delta',
'SessionUpdate',
'average_grads',
'aggregate_grads',
'allreduce_grads',
'get_checkpoint_path'
'get_checkpoint_path',
'IterSpeedCounter'
])
def autodoc_skip_member(app, what, name, obj, skip, options):
......
......@@ -56,15 +56,6 @@ tensorpack.utils.serialize module
:undoc-members:
:show-inheritance:
tensorpack.utils.compatible_serialize module
--------------------------------------------
.. automodule:: tensorpack.utils.compatible_serialize
:members:
:undoc-members:
:show-inheritance:
tensorpack.utils.stats module
-----------------------------
......
......@@ -10,10 +10,9 @@ import bob.ap
import scipy.io.wavfile as wavfile
from tensorpack.dataflow import DataFlow, LMDBSerializer
from tensorpack.utils import fs, logger, serialize
from tensorpack.utils import fs, logger, serialize, get_tqdm
from tensorpack.utils.argtools import memoized
from tensorpack.utils.stats import OnlineMoments
from tensorpack.utils.utils import get_tqdm
CHARSET = set(string.ascii_lowercase + ' ')
PHONEME_LIST = [
......
......@@ -13,9 +13,8 @@ from gym import spaces
from gym.envs.atari.atari_env import ACTION_MEANING
from six.moves import range
from tensorpack.utils import logger
from tensorpack.utils import logger, execute_only_once, get_rng
from tensorpack.utils.fs import get_dataset_path
from tensorpack.utils.utils import execute_only_once, get_rng
__all__ = ['AtariPlayer']
......
......@@ -7,13 +7,11 @@ import numpy as np
import random
import time
from six.moves import queue
from tqdm import tqdm
from tensorpack.callbacks import Callback
from tensorpack.utils import logger
from tensorpack.utils import logger, get_tqdm
from tensorpack.utils.concurrency import ShareSessionThread, StoppableThread
from tensorpack.utils.stats import StatCounter
from tensorpack.utils.utils import get_tqdm_kwargs
def play_one_episode(env, func, render=False):
......@@ -87,7 +85,7 @@ def eval_with_funcs(predictors, nr_eval, get_player_fn, verbose=False):
if verbose:
logger.info("Score: {}".format(r))
for _ in tqdm(range(nr_eval), **get_tqdm_kwargs()):
for _ in get_tqdm(range(nr_eval)):
fetch()
# waiting is necessary, otherwise the estimated mean score is biased
logger.info("Waiting for all the workers to finish the last run...")
......
......@@ -12,9 +12,8 @@ from six.moves import queue, range
from tensorpack.utils.concurrency import LoopThread, ShareSessionThread
from tensorpack.callbacks.base import Callback
from tensorpack.dataflow import DataFlow
from tensorpack.utils import logger
from tensorpack.utils import logger, get_rng, get_tqdm
from tensorpack.utils.stats import StatCounter
from tensorpack.utils.utils import get_rng, get_tqdm
__all__ = ['ExpReplay']
......
......@@ -17,8 +17,7 @@ from scipy import interpolate
from tensorpack.callbacks import Callback
from tensorpack.tfutils.common import get_tf_version_tuple
from tensorpack.utils import logger
from tensorpack.utils.utils import get_tqdm
from tensorpack.utils import logger, get_tqdm
from common import CustomResize, clip_boxes
from config import config as cfg
......
......@@ -47,6 +47,8 @@ if __name__ == '__main__':
# Setup logging ...
is_horovod = cfg.TRAINER == 'horovod'
if is_horovod:
hvd.init()
if not is_horovod or hvd.rank() == 0:
logger.set_logger_dir(args.logdir, 'd')
logger.info("Environment Information:\n" + collect_env_info())
......
......@@ -129,7 +129,7 @@ class GANTrainer(TowerTrainer):
self.tower_func = TowerFunc(get_cost, model.get_input_signature())
devices = [LeastLoadedDeviceSetter(d, raw_devices) for d in raw_devices]
cost_list = DataParallelBuilder.build_on_towers(
cost_list = DataParallelBuilder.call_for_each_tower(
list(range(num_gpu)),
lambda: self.tower_func(*input.get_input_tensors()),
devices)
......
......@@ -11,7 +11,7 @@ from ..utils.argtools import memoized
from .training import DataParallelBuilder, GraphBuilder
from .utils import OverrideCachingDevice, aggregate_grads, override_to_local_variable
__all__ = ['DistributedParameterServerBuilder', 'DistributedReplicatedBuilder']
__all__ = []
class DistributedBuilderBase(GraphBuilder):
......
......@@ -15,13 +15,12 @@ from ..tfutils.common import get_tf_version_tuple
from ..tfutils.gradproc import ScaleGradient
from ..tfutils.tower import TrainTowerContext
from ..utils import logger
from ..utils.develop import HIDE_DOC
from .utils import (
GradientPacker, LeastLoadedDeviceSetter, aggregate_grads, allreduce_grads, allreduce_grads_hierarchical,
merge_grad_list, override_to_local_variable, split_grad_list)
__all__ = ['GraphBuilder',
'SyncMultiGPUParameterServerBuilder', 'DataParallelBuilder',
'SyncMultiGPUReplicatedBuilder', 'AsyncMultiGPUBuilder']
__all__ = ["DataParallelBuilder"]
@six.add_metaclass(ABCMeta)
......@@ -117,6 +116,7 @@ class DataParallelBuilder(GraphBuilder):
ret.append(func())
return ret
@HIDE_DOC
@staticmethod
def build_on_towers(*args, **kwargs):
return DataParallelBuilder.call_for_each_tower(*args, **kwargs)
......
......@@ -13,13 +13,7 @@ from ..tfutils.varreplace import custom_getter_scope
from ..utils import logger
from ..utils.argtools import call_only_once
__all__ = ['LeastLoadedDeviceSetter',
'OverrideCachingDevice',
'override_to_local_variable',
'allreduce_grads',
'average_grads',
'aggregate_grads'
]
__all__ = ["LeastLoadedDeviceSetter"]
"""
......
......@@ -10,9 +10,8 @@ from ..input_source import PlaceholderInput
from ..tfutils.common import get_tensors_by_names
from ..tfutils.tower import PredictTowerContext
__all__ = ['PredictorBase', 'AsyncPredictorBase',
'OnlinePredictor', 'OfflinePredictor',
]
__all__ = ['PredictorBase',
'OnlinePredictor', 'OfflinePredictor']
@six.add_metaclass(ABCMeta)
......@@ -62,7 +61,7 @@ class AsyncPredictorBase(PredictorBase):
dp (list): A datapoint as inputs. It could be either batched or not
batched depending on the predictor implementation).
callback: a thread-safe callback to get called with
either outputs or (inputs, outputs).
either outputs or (inputs, outputs), if `return_input` is True.
Returns:
concurrent.futures.Future: a Future of results
"""
......
......@@ -14,8 +14,7 @@ from ..utils import logger
from ..utils.concurrency import DIE, ShareSessionThread, StoppableThread
from .base import AsyncPredictorBase, OfflinePredictor, OnlinePredictor
__all__ = ['MultiProcessPredictWorker', 'MultiProcessQueuePredictWorker',
'MultiThreadAsyncPredictor']
__all__ = ['MultiThreadAsyncPredictor']
class MultiProcessPredictWorker(multiprocessing.Process):
......@@ -171,7 +170,13 @@ class MultiThreadAsyncPredictor(AsyncPredictorBase):
def put_task(self, dp, callback=None):
"""
Same as in :meth:`AsyncPredictorBase.put_task`.
Args:
dp (list): A datapoint as inputs. It could be either batched or not
batched depending on the predictor implementation).
callback: a thread-safe callback. When the results are ready, it will be called
with the "future" object.
Returns:
concurrent.futures.Future: a Future of results.
"""
f = Future()
if callback is not None:
......
......@@ -11,6 +11,7 @@ from six.moves import range, zip
from ..dataflow import DataFlow
from ..dataflow.remote import dump_dataflow_to_process_queue
from ..utils import logger
from ..utils.develop import HIDE_DOC
from ..utils.concurrency import DIE, OrderedResultGatherProc, ensure_proc_terminate
from ..utils.gpu import change_gpu, get_num_gpu
from ..utils.utils import get_tqdm
......@@ -63,6 +64,7 @@ class SimpleDatasetPredictor(DatasetPredictorBase):
super(SimpleDatasetPredictor, self).__init__(config, dataset)
self.predictor = OfflinePredictor(config)
@HIDE_DOC
def get_result(self):
self.dataset.reset_state()
try:
......@@ -142,6 +144,7 @@ class MultiProcessDatasetPredictor(DatasetPredictorBase):
self.result_queue = self.outqueue
ensure_proc_terminate(self.workers + [self.inqueue_proc])
@HIDE_DOC
def get_result(self):
try:
sz = len(self.dataset)
......
......@@ -184,47 +184,47 @@ def save_chkpt_vars(dic, path):
saver.save(sess, path, write_meta_graph=False)
def get_checkpoint_path(model_path):
def get_checkpoint_path(path):
"""
Work around TF problems in checkpoint path handling.
Args:
model_path: a user-input path
path: a user-input path
Returns:
str: the argument that can be passed to NewCheckpointReader
"""
if os.path.basename(model_path) == model_path:
model_path = os.path.join('.', model_path) # avoid #4921 and #6142
if os.path.basename(model_path) == 'checkpoint':
assert tfv1.gfile.Exists(model_path), model_path
model_path = tf.train.latest_checkpoint(os.path.dirname(model_path))
if os.path.basename(path) == path:
path = os.path.join('.', path) # avoid #4921 and #6142
if os.path.basename(path) == 'checkpoint':
assert tfv1.gfile.Exists(path), path
path = tf.train.latest_checkpoint(os.path.dirname(path))
# to be consistent with either v1 or v2
# fix paths if provided a wrong one
new_path = model_path
if '00000-of-00001' in model_path:
new_path = model_path.split('.data')[0]
elif model_path.endswith('.index'):
new_path = model_path.split('.index')[0]
if new_path != model_path:
new_path = path
if '00000-of-00001' in path:
new_path = path.split('.data')[0]
elif path.endswith('.index'):
new_path = path.split('.index')[0]
if new_path != path:
logger.info(
"Checkpoint path {} is auto-corrected to {}.".format(model_path, new_path))
model_path = new_path
assert tfv1.gfile.Exists(model_path) or tfv1.gfile.Exists(model_path + '.index'), model_path
return model_path
"Checkpoint path {} is auto-corrected to {}.".format(path, new_path))
path = new_path
assert tfv1.gfile.Exists(path) or tfv1.gfile.Exists(path + '.index'), path
return path
def load_chkpt_vars(model_path):
def load_chkpt_vars(path):
""" Load all variables from a checkpoint to a dict.
Args:
model_path(str): path to a checkpoint.
path(str): path to a checkpoint.
Returns:
dict: a name:value dict
"""
model_path = get_checkpoint_path(model_path)
reader = tfv1.train.NewCheckpointReader(model_path)
path = get_checkpoint_path(path)
reader = tfv1.train.NewCheckpointReader(path)
var_names = reader.get_variable_to_shape_map().keys()
result = {}
for n in var_names:
......
......@@ -13,7 +13,7 @@ else:
import functools
__all__ = ['map_arg', 'memoized', 'memoized_method', 'graph_memoized', 'shape2d', 'shape4d',
'memoized_ignoreargs', 'log_once', 'call_only_once']
'memoized_ignoreargs', 'log_once']
def map_arg(**maps):
......
......@@ -26,8 +26,7 @@ else:
__all__ = ['StoppableThread', 'LoopThread', 'ShareSessionThread',
'ensure_proc_terminate',
'OrderedResultGatherProc', 'OrderedContainer', 'DIE',
'mask_sigint', 'start_proc_mask_signal']
'start_proc_mask_signal']
class StoppableThread(threading.Thread):
......
......@@ -97,6 +97,7 @@ def load_caffe(model_desc, model_file):
"""
Load a caffe model. You must be able to ``import caffe`` to use this
function.
Args:
model_desc (str): path to caffe model description file (.prototxt).
model_file (str): path to caffe model parameter file (.caffemodel).
......@@ -116,6 +117,7 @@ def load_caffe(model_desc, model_file):
def get_caffe_pb():
"""
Get caffe protobuf.
Returns:
The imported caffe protobuf module.
"""
......
......@@ -13,7 +13,7 @@ from .develop import create_dummy_func
msgpack_numpy.patch()
assert msgpack.version >= (0, 5, 2)
__all__ = ['loads', 'dumps', 'NonPicklableWrapper']
__all__ = ['loads', 'dumps']
MAX_MSGPACK_LEN = 1000000000
......
......@@ -15,8 +15,7 @@ if six.PY3:
from time import perf_counter as timer # noqa
__all__ = ['total_timer', 'timed_operation',
'print_total_timer', 'IterSpeedCounter', 'Timer']
__all__ = ['timed_operation', 'IterSpeedCounter', 'Timer']
@contextmanager
......@@ -55,7 +54,7 @@ _TOTAL_TIMER_DATA = defaultdict(StatCounter)
@contextmanager
def total_timer(msg):
""" A context which add the time spent inside to TotalTimer. """
""" A context which add the time spent inside to the global TotalTimer. """
start = timer()
yield
t = timer() - start
......@@ -64,7 +63,7 @@ def total_timer(msg):
def print_total_timer():
"""
Print the content of the TotalTimer, if it's not empty. This function will automatically get
Print the content of the global TotalTimer, if it's not empty. This function will automatically get
called when program exits.
"""
if len(_TOTAL_TIMER_DATA) == 0:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment