Commit 8f797c63 authored by Yuxin Wu's avatar Yuxin Wu

api docs for utils/

parent 1b73d9cc
...@@ -10,3 +10,4 @@ cd "$PROG_DIR" ...@@ -10,3 +10,4 @@ cd "$PROG_DIR"
make clean make clean
#sphinx-apidoc -o modules ../tensorpack -f -d 10 #sphinx-apidoc -o modules ../tensorpack -f -d 10
make html make html
xdotool windowactivate --sync $(xdotool search --desktop 0 Chromium) key F5
...@@ -6,3 +6,4 @@ msgpack-python ...@@ -6,3 +6,4 @@ msgpack-python
msgpack-numpy msgpack-numpy
pyzmq pyzmq
subprocess32; python_version < '3.0' subprocess32; python_version < '3.0'
functools32; python_version < '3.0'
...@@ -4,20 +4,23 @@ ...@@ -4,20 +4,23 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import operator
import inspect import inspect
import six import six
import functools
import collections
from . import logger from . import logger
if six.PY2:
import functools32 as functools
else:
import functools
__all__ = ['map_arg', 'memoized', 'shape2d', 'memoized_ignoreargs', 'log_once'] __all__ = ['map_arg', 'memoized', 'shape2d', 'memoized_ignoreargs', 'log_once']
def map_arg(**maps): def map_arg(**maps):
""" """
Apply a mapping on certains argument before calling original function. Apply a mapping on certains argument before calling the original function.
maps: {key: map_func}
Args:
maps (dict): {key: map_func}
""" """
def deco(func): def deco(func):
@functools.wraps(func) @functools.wraps(func)
...@@ -31,45 +34,18 @@ def map_arg(**maps): ...@@ -31,45 +34,18 @@ def map_arg(**maps):
return deco return deco
class memoized(object): memoized = functools.lru_cache(maxsize=None)
'''Decorator. Caches a function's return value each time it is called. """ Equivalent to :func:`functools.lru_cache` """
If called later with the same arguments, the cached value is returned
(not reevaluated).
'''
def __init__(self, func):
self.func = func
self.cache = {}
def __call__(self, *args, **kwargs):
kwlist = tuple(sorted(list(kwargs), key=operator.itemgetter(0)))
if not isinstance(args, collections.Hashable) or \
not isinstance(kwlist, collections.Hashable):
logger.warn("Arguments to memoized call is unhashable!")
# uncacheable. a list, for instance.
# better to not cache than blow up.
return self.func(*args, **kwargs)
key = (args, kwlist)
if key in self.cache:
return self.cache[key]
else:
value = self.func(*args, **kwargs)
self.cache[key] = value
return value
def __repr__(self):
'''Return the function's docstring.'''
return self.func.__doc__
def __get__(self, obj, objtype):
'''Support instance methods.'''
return functools.partial(self.__call__, obj)
_MEMOIZED_NOARGS = {} _MEMOIZED_NOARGS = {}
def memoized_ignoreargs(func): def memoized_ignoreargs(func):
"""
A decorator. It performs memoization ignoring the arguments used to call
the function.
"""
hash(func) # make sure it is hashable. TODO is it necessary? hash(func) # make sure it is hashable. TODO is it necessary?
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
...@@ -93,7 +69,13 @@ def memoized_ignoreargs(func): ...@@ -93,7 +69,13 @@ def memoized_ignoreargs(func):
def shape2d(a): def shape2d(a):
""" """
Ensure a 2D shape.
Args:
a: a int or tuple/list of length 2 a: a int or tuple/list of length 2
Returns:
list: of length 2. if ``a`` is a int, return ``[a, a]``.
""" """
if type(a) == int: if type(a) == int:
return [a, a] return [a, a]
...@@ -105,4 +87,12 @@ def shape2d(a): ...@@ -105,4 +87,12 @@ def shape2d(a):
@memoized @memoized
def log_once(message, func): def log_once(message, func):
"""
Log certain message only once. Call this function more than one times with
the same message will result in no-op.
Args:
message(str): message to log
func(str): the name of the logger method. e.g. "info", "warn", "error".
"""
getattr(logger, func)(message) getattr(logger, func)(message)
...@@ -36,15 +36,18 @@ class StoppableThread(threading.Thread): ...@@ -36,15 +36,18 @@ class StoppableThread(threading.Thread):
self._stop_evt = threading.Event() self._stop_evt = threading.Event()
def stop(self): def stop(self):
""" stop the thread""" """ Stop the thread"""
self._stop_evt.set() self._stop_evt.set()
def stopped(self): def stopped(self):
""" check whether the thread is stopped or not""" """
Returns:
bool: whether the thread is stopped or not
"""
return self._stop_evt.isSet() return self._stop_evt.isSet()
def queue_put_stoppable(self, q, obj): def queue_put_stoppable(self, q, obj):
""" put obj to queue, but will give up if the thread is stopped""" """ Put obj to queue, but will give up when the thread is stopped"""
while not self.stopped(): while not self.stopped():
try: try:
q.put(obj, timeout=5) q.put(obj, timeout=5)
...@@ -53,7 +56,7 @@ class StoppableThread(threading.Thread): ...@@ -53,7 +56,7 @@ class StoppableThread(threading.Thread):
pass pass
def queue_get_stoppable(self, q): def queue_get_stoppable(self, q):
""" take obj from queue, but will give up if the thread is stopped""" """ Take obj from queue, but will give up when the thread is stopped"""
while not self.stopped(): while not self.stopped():
try: try:
return q.get(timeout=5) return q.get(timeout=5)
...@@ -66,7 +69,8 @@ class LoopThread(StoppableThread): ...@@ -66,7 +69,8 @@ class LoopThread(StoppableThread):
def __init__(self, func, pausable=True): def __init__(self, func, pausable=True):
""" """
:param func: the function to run Args:
func: the function to run
""" """
super(LoopThread, self).__init__() super(LoopThread, self).__init__()
self._func = func self._func = func
...@@ -83,10 +87,12 @@ class LoopThread(StoppableThread): ...@@ -83,10 +87,12 @@ class LoopThread(StoppableThread):
self._func() self._func()
def pause(self): def pause(self):
""" Pause the loop """
assert self._pausable assert self._pausable
self._lock.acquire() self._lock.acquire()
def resume(self): def resume(self):
""" Resume the loop """
assert self._pausable assert self._pausable
self._lock.release() self._lock.release()
...@@ -97,6 +103,12 @@ class DIE(object): ...@@ -97,6 +103,12 @@ class DIE(object):
def ensure_proc_terminate(proc): def ensure_proc_terminate(proc):
"""
Make sure processes terminate when main process exit.
Args:
proc (multiprocessing.Process or list)
"""
if isinstance(proc, list): if isinstance(proc, list):
for p in proc: for p in proc:
ensure_proc_terminate(p) ensure_proc_terminate(p)
...@@ -117,12 +129,21 @@ def ensure_proc_terminate(proc): ...@@ -117,12 +129,21 @@ def ensure_proc_terminate(proc):
@contextmanager @contextmanager
def mask_sigint(): def mask_sigint():
"""
Returns:
a context where ``SIGINT`` is ignored.
"""
sigint_handler = signal.signal(signal.SIGINT, signal.SIG_IGN) sigint_handler = signal.signal(signal.SIGINT, signal.SIG_IGN)
yield yield
signal.signal(signal.SIGINT, sigint_handler) signal.signal(signal.SIGINT, sigint_handler)
def start_proc_mask_signal(proc): def start_proc_mask_signal(proc):
""" Start process(es) with SIGINT ignored.
Args:
proc: (multiprocessing.Process or list)
"""
if not isinstance(proc, list): if not isinstance(proc, list):
proc = [proc] proc = [proc]
...@@ -132,6 +153,12 @@ def start_proc_mask_signal(proc): ...@@ -132,6 +153,12 @@ def start_proc_mask_signal(proc):
def subproc_call(cmd, timeout=None): def subproc_call(cmd, timeout=None):
"""
Execute a command with timeout, and return both STDOUT/STDERR.
Args:
cmd(str): the command to execute.
timeout(float): timeout in seconds.
"""
try: try:
output = subprocess.check_output( output = subprocess.check_output(
cmd, stderr=subprocess.STDOUT, cmd, stderr=subprocess.STDOUT,
...@@ -147,15 +174,28 @@ def subproc_call(cmd, timeout=None): ...@@ -147,15 +174,28 @@ def subproc_call(cmd, timeout=None):
class OrderedContainer(object): class OrderedContainer(object):
""" """
Like a priority queue, but will always wait for item with index (x+1) before producing (x+2). Like a queue, but will always wait to receive item with rank
(x+1) and produce (x+1) before producing (x+2).
Warning:
It is not thread-safe.
""" """
def __init__(self, start=0): def __init__(self, start=0):
"""
Args:
start(int): the starting rank.
"""
self.ranks = [] self.ranks = []
self.data = [] self.data = []
self.wait_for = start self.wait_for = start
def put(self, rank, val): def put(self, rank, val):
"""
Args:
rank(int): rank of th element. All elements must have different ranks.
val: an object
"""
idx = bisect.bisect(self.ranks, rank) idx = bisect.bisect(self.ranks, rank)
self.ranks.insert(idx, rank) self.ranks.insert(idx, rank)
self.data.insert(idx, val) self.data.insert(idx, val)
...@@ -183,9 +223,11 @@ class OrderedResultGatherProc(multiprocessing.Process): ...@@ -183,9 +223,11 @@ class OrderedResultGatherProc(multiprocessing.Process):
def __init__(self, data_queue, nr_producer, start=0): def __init__(self, data_queue, nr_producer, start=0):
""" """
:param data_queue: a multiprocessing.Queue to produce input dp Args:
:param nr_producer: number of producer processes. Will terminate after receiving this many of DIE sentinel. data_queue(multiprocessing.Queue): a queue which contains datapoints.
:param start: the first task index nr_producer(int): number of producer processes. This process will
terminate after receiving this many of :class:`DIE` sentinel.
start(int): the rank of the first object
""" """
super(OrderedResultGatherProc, self).__init__() super(OrderedResultGatherProc, self).__init__()
self.data_queue = data_queue self.data_queue = data_queue
......
...@@ -9,6 +9,7 @@ __all__ = ['enable_call_trace'] ...@@ -9,6 +9,7 @@ __all__ = ['enable_call_trace']
def enable_call_trace(): def enable_call_trace():
""" Enable trace for calls to any function. """
def tracer(frame, event, arg): def tracer(frame, event, arg):
if event == 'call': if event == 'call':
co = frame.f_code co = frame.f_code
......
...@@ -29,12 +29,12 @@ class Discretizer1D(Discretizer): ...@@ -29,12 +29,12 @@ class Discretizer1D(Discretizer):
class UniformDiscretizer1D(Discretizer1D): class UniformDiscretizer1D(Discretizer1D):
def __init__(self, minv, maxv, spacing): def __init__(self, minv, maxv, spacing):
""" """
:params minv: minimum value of the first bin Args:
:params maxv: maximum value of the last bin minv(float): minimum value of the first bin
:param spacing: width of a bin maxv(float): maximum value of the last bin
spacing(float): width of a bin
""" """
self.minv = float(minv) self.minv = float(minv)
self.maxv = float(maxv) self.maxv = float(maxv)
...@@ -42,9 +42,19 @@ class UniformDiscretizer1D(Discretizer1D): ...@@ -42,9 +42,19 @@ class UniformDiscretizer1D(Discretizer1D):
self.nr_bin = int(np.ceil((self.maxv - self.minv) / self.spacing)) self.nr_bin = int(np.ceil((self.maxv - self.minv) / self.spacing))
def get_nr_bin(self): def get_nr_bin(self):
"""
Returns:
int: number of bins
"""
return self.nr_bin return self.nr_bin
def get_bin(self, v): def get_bin(self, v):
"""
Args:
v(float): value
Returns:
int: the bin index for value ``v``.
"""
if v < self.minv: if v < self.minv:
log_once("UniformDiscretizer1D: value smaller than min!", 'warn') log_once("UniformDiscretizer1D: value smaller than min!", 'warn')
return 0 return 0
...@@ -56,10 +66,23 @@ class UniformDiscretizer1D(Discretizer1D): ...@@ -56,10 +66,23 @@ class UniformDiscretizer1D(Discretizer1D):
0, self.nr_bin - 1)) 0, self.nr_bin - 1))
def get_bin_center(self, bin_id): def get_bin_center(self, bin_id):
"""
Args:
bin_id(int)
Returns:
float: the center of this bin.
"""
return self.minv + self.spacing * (bin_id + 0.5) return self.minv + self.spacing * (bin_id + 0.5)
def get_distribution(self, v, smooth_factor=0.05, smooth_radius=2): def get_distribution(self, v, smooth_factor=0.05, smooth_radius=2):
""" return a smoothed one-hot distribution of the sample v. """
Args:
v(float): a sample
smooth_factor(float):
smooth_radius(int):
Returns:
numpy.ndarray: array of length ``self.nr_bin``, a smoothed one-hot
distribution centered at the bin of sample ``v``.
""" """
b = self.get_bin(v) b = self.get_bin(v)
ret = np.zeros((self.nr_bin, ), dtype='float32') ret = np.zeros((self.nr_bin, ), dtype='float32')
...@@ -78,10 +101,11 @@ class UniformDiscretizer1D(Discretizer1D): ...@@ -78,10 +101,11 @@ class UniformDiscretizer1D(Discretizer1D):
class UniformDiscretizerND(Discretizer): class UniformDiscretizerND(Discretizer):
""" A combination of several :class:`UniformDiscretizer1D`. """
def __init__(self, *min_max_spacing): def __init__(self, *min_max_spacing):
""" """
:params min_max_spacing: (minv, maxv, spacing) for each dimension Args:
min_max_spacing: (minv, maxv, spacing) for each dimension
""" """
self.n = len(min_max_spacing) self.n = len(min_max_spacing)
self.discretizers = [UniformDiscretizer1D(*k) for k in min_max_spacing] self.discretizers = [UniformDiscretizer1D(*k) for k in min_max_spacing]
......
...@@ -13,7 +13,11 @@ __all__ = ['mkdir_p', 'download', 'recursive_walk'] ...@@ -13,7 +13,11 @@ __all__ = ['mkdir_p', 'download', 'recursive_walk']
def mkdir_p(dirname): def mkdir_p(dirname):
""" make a dir recursively, but do nothing if the dir exists""" """ Make a dir recursively, but do nothing if the dir exists
Args:
dirname(str):
"""
assert dirname is not None assert dirname is not None
if dirname == '' or os.path.isdir(dirname): if dirname == '' or os.path.isdir(dirname):
return return
...@@ -25,6 +29,10 @@ def mkdir_p(dirname): ...@@ -25,6 +29,10 @@ def mkdir_p(dirname):
def download(url, dir): def download(url, dir):
"""
Download URL to a directory. Will figure out the filename automatically
from URL.
"""
mkdir_p(dir) mkdir_p(dir)
fname = url.split('/')[-1] fname = url.split('/')[-1]
fpath = os.path.join(dir, fname) fpath = os.path.join(dir, fname)
...@@ -50,6 +58,10 @@ def download(url, dir): ...@@ -50,6 +58,10 @@ def download(url, dir):
def recursive_walk(rootdir): def recursive_walk(rootdir):
"""
Yields:
str: All files in rootdir, recursively.
"""
for r, dirs, files in os.walk(rootdir): for r, dirs, files in os.walk(rootdir):
for f in files: for f in files:
yield os.path.join(r, f) yield os.path.join(r, f)
......
...@@ -20,8 +20,10 @@ globalns = NS() ...@@ -20,8 +20,10 @@ globalns = NS()
def use_global_argument(args): def use_global_argument(args):
""" """
Add the content of argparse.Namespace to globalns Add the content of :class:`argparse.Namespace` to globalns.
:param args: Argument
Args:
args (argparse.Namespace): arguments
""" """
assert isinstance(args, argparse.Namespace), type(args) assert isinstance(args, argparse.Namespace), type(args)
for k, v in six.iteritems(vars(args)): for k, v in six.iteritems(vars(args)):
......
...@@ -10,6 +10,10 @@ __all__ = ['change_gpu', 'get_nr_gpu', 'get_gpus'] ...@@ -10,6 +10,10 @@ __all__ = ['change_gpu', 'get_nr_gpu', 'get_gpus']
def change_gpu(val): def change_gpu(val):
"""
Returns:
a context where ``CUDA_VISIBLE_DEVICES=val``.
"""
val = str(val) val = str(val)
if val == '-1': if val == '-1':
val = '' val = ''
...@@ -17,13 +21,20 @@ def change_gpu(val): ...@@ -17,13 +21,20 @@ def change_gpu(val):
def get_nr_gpu(): def get_nr_gpu():
"""
Returns:
int: the number of GPU from ``CUDA_VISIBLE_DEVICES``.
"""
env = os.environ.get('CUDA_VISIBLE_DEVICES', None) env = os.environ.get('CUDA_VISIBLE_DEVICES', None)
assert env is not None, 'gpu not set!' # TODO assert env is not None, 'gpu not set!' # TODO
return len(env.split(',')) return len(env.split(','))
def get_gpus(): def get_gpus():
""" return a list of GPU physical id""" """
Returns:
list: a list of int of GPU id.
"""
env = os.environ.get('CUDA_VISIBLE_DEVICES', None) env = os.environ.get('CUDA_VISIBLE_DEVICES', None)
assert env is not None, 'gpu not set!' # TODO assert env is not None, 'gpu not set!' # TODO
return map(int, env.strip().split(',')) return map(int, env.strip().split(','))
...@@ -94,7 +94,13 @@ class CaffeLayerProcessor(object): ...@@ -94,7 +94,13 @@ class CaffeLayerProcessor(object):
def load_caffe(model_desc, model_file): def load_caffe(model_desc, model_file):
""" """
:return: a dict of params Load a caffe model. You must be able to ``import caffe`` to use this
function.
Args:
model_desc (str): path to caffe model description file (.prototxt).
model_file (str): path to caffe model parameter file (.caffemodel).
Returns:
dict: the parameters.
""" """
with change_env('GLOG_minloglevel', '2'): with change_env('GLOG_minloglevel', '2'):
import caffe import caffe
...@@ -107,6 +113,11 @@ def load_caffe(model_desc, model_file): ...@@ -107,6 +113,11 @@ def load_caffe(model_desc, model_file):
def get_caffe_pb(): def get_caffe_pb():
"""
Get caffe protobuf.
Returns:
The imported caffe protobuf module.
"""
dir = get_dataset_path('caffe') dir = get_dataset_path('caffe')
caffe_pb_file = os.path.join(dir, 'caffe_pb2.py') caffe_pb_file = os.path.join(dir, 'caffe_pb2.py')
if not os.path.isfile(caffe_pb_file): if not os.path.isfile(caffe_pb_file):
......
...@@ -73,8 +73,10 @@ def _set_file(path): ...@@ -73,8 +73,10 @@ def _set_file(path):
def set_logger_dir(dirname, action=None): def set_logger_dir(dirname, action=None):
""" """
Set the directory for global logging. Set the directory for global logging.
:param dirname: log directory
:param action: an action (k/b/d/n) to be performed. Will ask user by default. Args:
dirname(str): log directory
action(str): an action of ("k","b","d","n") to be performed. Will ask user by default.
""" """
global LOG_FILE, LOG_DIR global LOG_FILE, LOG_DIR
if os.path.isdir(dirname): if os.path.isdir(dirname):
...@@ -108,13 +110,14 @@ If you're resuming from a previous run you can choose to keep it.""") ...@@ -108,13 +110,14 @@ If you're resuming from a previous run you can choose to keep it.""")
def disable_logger(): def disable_logger():
""" disable all logging ability from this moment""" """ Disable all logging ability from this moment"""
for func in _LOGGING_METHOD: for func in _LOGGING_METHOD:
globals()[func] = lambda x: None globals()[func] = lambda x: None
def auto_set_dir(action=None, overwrite=False): def auto_set_dir(action=None, overwrite=False):
""" set log directory to a subdir inside 'train_log', with the name being """
Set log directory to a subdir inside "train_log", with the name being
the main python file currently running""" the main python file currently running"""
if LOG_DIR is not None and not overwrite: if LOG_DIR is not None and not overwrite:
# dir already set # dir already set
...@@ -128,4 +131,5 @@ def auto_set_dir(action=None, overwrite=False): ...@@ -128,4 +131,5 @@ def auto_set_dir(action=None, overwrite=False):
def warn_dependency(name, dependencies): def warn_dependency(name, dependencies):
""" Print warning about an import failure due to missing dependencies. """
warn("Failed to import '{}', {} won't be available'".format(dependencies, name)) # noqa: F821 warn("Failed to import '{}', {} won't be available'".format(dependencies, name)) # noqa: F821
...@@ -9,8 +9,13 @@ __all__ = ['LookUpTable'] ...@@ -9,8 +9,13 @@ __all__ = ['LookUpTable']
class LookUpTable(object): class LookUpTable(object):
""" Maintain mapping from index to objects. """
def __init__(self, objlist): def __init__(self, objlist):
"""
Args:
objlist(list): list of objects
"""
self.idx2obj = dict(enumerate(objlist)) self.idx2obj = dict(enumerate(objlist))
self.obj2idx = {v: k for k, v in six.iteritems(self.idx2obj)} self.obj2idx = {v: k for k, v in six.iteritems(self.idx2obj)}
......
...@@ -8,8 +8,9 @@ import numpy as np ...@@ -8,8 +8,9 @@ import numpy as np
class Rect(object): class Rect(object):
""" """
A Rectangle. A rectangle class.
Note that x1 = x+w, not x+w-1 or something
Note that x1 = x + w, not x+w-1 or something else.
""" """
__slots__ = ['x', 'y', 'w', 'h'] __slots__ = ['x', 'y', 'w', 'h']
...@@ -51,9 +52,11 @@ class Rect(object): ...@@ -51,9 +52,11 @@ class Rect(object):
def validate(self, shape=None): def validate(self, shape=None):
""" """
Is a valid bounding box within this shape Check that this rect is a valid bounding box within this shape.
:param shape: [h, w] Args:
:returns: boolean shape: [h, w]
Returns:
bool
""" """
if min(self.x, self.y) < 0: if min(self.x, self.y) < 0:
return False return False
......
...@@ -11,8 +11,18 @@ __all__ = ['loads', 'dumps'] ...@@ -11,8 +11,18 @@ __all__ = ['loads', 'dumps']
def dumps(obj): def dumps(obj):
"""
Serialize an object.
Returns:
str
"""
return msgpack.dumps(obj, use_bin_type=True) return msgpack.dumps(obj, use_bin_type=True)
def loads(buf): def loads(buf):
"""
Args:
buf (str): serialized object.
"""
return msgpack.loads(buf) return msgpack.loads(buf)
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
import numpy as np import numpy as np
__all__ = ['StatCounter', 'Accuracy', 'BinaryStatistics', 'RatioCounter', __all__ = ['StatCounter', 'BinaryStatistics', 'RatioCounter', 'Accuracy',
'OnlineMoments'] 'OnlineMoments']
...@@ -14,6 +14,10 @@ class StatCounter(object): ...@@ -14,6 +14,10 @@ class StatCounter(object):
self.reset() self.reset()
def feed(self, v): def feed(self, v):
"""
Args:
v(float or np.ndarray): has to be the same shape between calls.
"""
self._values.append(v) self._values.append(v)
def reset(self): def reset(self):
...@@ -40,7 +44,7 @@ class StatCounter(object): ...@@ -40,7 +44,7 @@ class StatCounter(object):
class RatioCounter(object): class RatioCounter(object):
""" A counter to count ratio of something""" """ A counter to count ratio of something. """
def __init__(self): def __init__(self):
self.reset() self.reset()
...@@ -50,6 +54,11 @@ class RatioCounter(object): ...@@ -50,6 +54,11 @@ class RatioCounter(object):
self._cnt = 0 self._cnt = 0
def feed(self, cnt, tot=1): def feed(self, cnt, tot=1):
"""
Args:
cnt(int): the count of some event of interest.
tot(int): the total number of events.
"""
self._tot += tot self._tot += tot
self._cnt += cnt self._cnt += cnt
...@@ -61,6 +70,10 @@ class RatioCounter(object): ...@@ -61,6 +70,10 @@ class RatioCounter(object):
@property @property
def count(self): def count(self):
"""
Returns:
int: the total
"""
return self._tot return self._tot
...@@ -90,8 +103,9 @@ class BinaryStatistics(object): ...@@ -90,8 +103,9 @@ class BinaryStatistics(object):
def feed(self, pred, label): def feed(self, pred, label):
""" """
:param pred: 0/1 np array Args:
:param label: 0/1 np array of the same size pred (np.ndarray): binary array.
label (np.ndarray): binary array of the same size.
""" """
assert pred.shape == label.shape, "{} != {}".format(pred.shape, label.shape) assert pred.shape == label.shape, "{} != {}".format(pred.shape, label.shape)
self.nr_pos += (label == 1).sum() self.nr_pos += (label == 1).sum()
...@@ -127,7 +141,8 @@ class BinaryStatistics(object): ...@@ -127,7 +141,8 @@ class BinaryStatistics(object):
class OnlineMoments(object): class OnlineMoments(object):
"""Compute 1st and 2nd moments online """Compute 1st and 2nd moments online (to avoid storing all elements).
See algorithm at: https://www.wikiwand.com/en/Algorithms_for_calculating_variance#/Online_algorithm See algorithm at: https://www.wikiwand.com/en/Algorithms_for_calculating_variance#/Online_algorithm
""" """
...@@ -137,6 +152,10 @@ class OnlineMoments(object): ...@@ -137,6 +152,10 @@ class OnlineMoments(object):
self._n = 0 self._n = 0
def feed(self, x): def feed(self, x):
"""
Args:
x (float or np.ndarray): must have the same shape.
"""
self._n += 1 self._n += 1
delta = x - self._mean delta = x - self._mean
self._mean += delta * (1.0 / self._n) self._mean += delta * (1.0 / self._n)
......
...@@ -17,30 +17,27 @@ __all__ = ['total_timer', 'timed_operation', ...@@ -17,30 +17,27 @@ __all__ = ['total_timer', 'timed_operation',
'print_total_timer', 'IterSpeedCounter'] 'print_total_timer', 'IterSpeedCounter']
class IterSpeedCounter(object): @contextmanager
""" To count how often some code gets reached""" def timed_operation(msg, log_start=False):
"""
Surround a context with a timer.
def __init__(self, print_every, name=None): Args:
self.cnt = 0 msg(str): the log to print.
self.print_every = int(print_every) log_start(bool): whether to print also at the beginning.
self.name = name if name else 'IterSpeed'
def reset(self): Example:
self.start = time.time() .. code-block:: python
def __call__(self): with timed_operation('Good Stuff'):
if self.cnt == 0: time.sleep(1)
self.reset()
self.cnt += 1
if self.cnt % self.print_every != 0:
return
t = time.time() - self.start
logger.info("{}: {:.2f} sec, {} times, {:.3g} sec/time".format(
self.name, t, self.cnt, t / self.cnt))
Will print:
@contextmanager .. code-block:: python
def timed_operation(msg, log_start=False):
Good stuff finished, time:1sec.
"""
if log_start: if log_start:
logger.info('Start {} ...'.format(msg)) logger.info('Start {} ...'.format(msg))
start = time.time() start = time.time()
...@@ -54,6 +51,7 @@ _TOTAL_TIMER_DATA = defaultdict(StatCounter) ...@@ -54,6 +51,7 @@ _TOTAL_TIMER_DATA = defaultdict(StatCounter)
@contextmanager @contextmanager
def total_timer(msg): def total_timer(msg):
""" A context which add the time spent inside to TotalTimer. """
start = time.time() start = time.time()
yield yield
t = time.time() - start t = time.time() - start
...@@ -61,6 +59,10 @@ def total_timer(msg): ...@@ -61,6 +59,10 @@ def total_timer(msg):
def print_total_timer(): def print_total_timer():
"""
Print the content of the TotalTimer, if it's not empty. This function will automatically get
called when program exits.
"""
if len(_TOTAL_TIMER_DATA) == 0: if len(_TOTAL_TIMER_DATA) == 0:
return return
for k, v in six.iteritems(_TOTAL_TIMER_DATA): for k, v in six.iteritems(_TOTAL_TIMER_DATA):
...@@ -69,3 +71,41 @@ def print_total_timer(): ...@@ -69,3 +71,41 @@ def print_total_timer():
atexit.register(print_total_timer) atexit.register(print_total_timer)
class IterSpeedCounter(object):
""" Test how often some code gets reached.
Example:
Print the speed of the iteration every 100 times.
.. code-block:: python
speed = IterSpeedCounter(100)
for k in range(1000):
# do something
speed()
"""
def __init__(self, print_every, name=None):
"""
Args:
print_every(int): interval to print.
name(str): name to used when print.
"""
self.cnt = 0
self.print_every = int(print_every)
self.name = name if name else 'IterSpeed'
def reset(self):
self.start = time.time()
def __call__(self):
if self.cnt == 0:
self.reset()
self.cnt += 1
if self.cnt % self.print_every != 0:
return
t = time.time() - self.start
logger.info("{}: {:.2f} sec, {} times, {:.3g} sec/time".format(
self.name, t, self.cnt, t / self.cnt))
...@@ -22,6 +22,14 @@ __all__ = ['change_env', ...@@ -22,6 +22,14 @@ __all__ = ['change_env',
@contextmanager @contextmanager
def change_env(name, val): def change_env(name, val):
"""
Args:
name(str), val(str):
Returns:
a context where the environment variable ``name`` being set to
``val``. It will be set back after the context exits.
"""
oldval = os.environ.get(name, None) oldval = os.environ.get(name, None)
os.environ[name] = val os.environ[name] = val
yield yield
...@@ -32,7 +40,14 @@ def change_env(name, val): ...@@ -32,7 +40,14 @@ def change_env(name, val):
def get_rng(obj=None): def get_rng(obj=None):
""" obj: some object to use to generate random seed""" """
Get a good RNG.
Args:
obj: some object to use to generate random seed.
Returns:
np.random.RandomState: the RNG.
"""
seed = (id(obj) + os.getpid() + seed = (id(obj) + os.getpid() +
int(datetime.now().strftime("%Y%m%d%H%M%S%f"))) % 4294967295 int(datetime.now().strftime("%Y%m%d%H%M%S%f"))) % 4294967295
return np.random.RandomState(seed) return np.random.RandomState(seed)
...@@ -43,10 +58,18 @@ _EXECUTE_HISTORY = set() ...@@ -43,10 +58,18 @@ _EXECUTE_HISTORY = set()
def execute_only_once(): def execute_only_once():
""" """
when called with: Each called in the code to this function is guranteed to return True the
first time and False afterwards.
Returns:
bool: whether this is the first time this function gets called from
this line of code.
Example:
.. code-block:: python
if execute_only_once(): if execute_only_once():
# do something # do something only once
The body is guranteed to be executed only the first time.
""" """
f = inspect.currentframe().f_back f = inspect.currentframe().f_back
ident = (f.f_code.co_filename, f.f_lineno) ident = (f.f_code.co_filename, f.f_lineno)
...@@ -57,6 +80,15 @@ def execute_only_once(): ...@@ -57,6 +80,15 @@ def execute_only_once():
def get_dataset_path(*args): def get_dataset_path(*args):
"""
Get the path to some dataset under ``$TENSORPACK_DATASET``.
Args:
args: strings to be joined to form path.
Returns:
str: path to the dataset.
"""
d = os.environ.get('TENSORPACK_DATASET', None) d = os.environ.get('TENSORPACK_DATASET', None)
if d is None: if d is None:
d = os.path.abspath(os.path.join( d = os.path.abspath(os.path.join(
...@@ -69,6 +101,14 @@ def get_dataset_path(*args): ...@@ -69,6 +101,14 @@ def get_dataset_path(*args):
def get_tqdm_kwargs(**kwargs): def get_tqdm_kwargs(**kwargs):
"""
Return default arguments to be used with tqdm.
Args:
kwargs: extra arguments to be used.
Returns:
dict:
"""
default = dict( default = dict(
smoothing=0.5, smoothing=0.5,
dynamic_ncols=True, dynamic_ncols=True,
...@@ -85,9 +125,15 @@ def get_tqdm_kwargs(**kwargs): ...@@ -85,9 +125,15 @@ def get_tqdm_kwargs(**kwargs):
def get_tqdm(**kwargs): def get_tqdm(**kwargs):
""" Similar to :func:`get_tqdm_kwargs`, but returns the tqdm object
directly. """
return tqdm(**get_tqdm_kwargs(**kwargs)) return tqdm(**get_tqdm_kwargs(**kwargs))
def building_rtfd(): def building_rtfd():
"""
Returns:
bool: if tensorpack is imported to generate docs now.
"""
return os.environ.get('READTHEDOCS') == 'True' \ return os.environ.get('READTHEDOCS') == 'True' \
or os.environ.get('TENSORPACK_DOC_BUILDING') or os.environ.get('TENSORPACK_DOC_BUILDING')
...@@ -16,11 +16,12 @@ try: ...@@ -16,11 +16,12 @@ try:
except ImportError: except ImportError:
pass pass
__all__ = ['pyplot2img', 'build_patch_list', 'pyplot_viz', __all__ = ['pyplot2img', 'interactive_imshow', 'build_patch_list',
'dump_dataflow_images', 'interactive_imshow'] 'pyplot_viz', 'dump_dataflow_images']
def pyplot2img(plt): def pyplot2img(plt):
""" Convert a pyplot instance to image """
buf = io.BytesIO() buf = io.BytesIO()
plt.axis('off') plt.axis('off')
plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0) plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
...@@ -32,8 +33,13 @@ def pyplot2img(plt): ...@@ -32,8 +33,13 @@ def pyplot2img(plt):
def pyplot_viz(img, shape=None): def pyplot_viz(img, shape=None):
""" use pyplot to visualize the image """ Use pyplot to visualize the image. e.g., when input is grayscale, the result
Note: this is quite slow. and the returned image will have a border will automatically have a colormap.
Returns:
np.ndarray: an image.
Note:
this is quite slow. and the returned image will have a border
""" """
plt.clf() plt.clf()
plt.axes([0, 0, 1, 1]) plt.axes([0, 0, 1, 1])
...@@ -54,8 +60,17 @@ def minnone(x, y): ...@@ -54,8 +60,17 @@ def minnone(x, y):
def interactive_imshow(img, lclick_cb=None, rclick_cb=None, **kwargs): def interactive_imshow(img, lclick_cb=None, rclick_cb=None, **kwargs):
""" """
:param lclick_cb: a callback(img, x, y) for left click Args:
:param kwargs: can be {key_cb_a ... key_cb_z: callback(img)} img (np.ndarray): an image to show.
lclick_cb: a callback func(img, x, y) for left click event.
kwargs: can be {key_cb_a: callback_img, key_cb_b: callback_img}, to
specify a callback func(img) for keypress.
Some existing keypress event handler:
* q: destroy the current window
* x: execute ``sys.exit()``
* s: save image to "out.png"
""" """
name = 'random_window_name' name = 'random_window_name'
cv2.imshow(name, img) cv2.imshow(name, img)
...@@ -84,15 +99,27 @@ def build_patch_list(patch_list, ...@@ -84,15 +99,27 @@ def build_patch_list(patch_list,
shuffle=False, bgcolor=255, shuffle=False, bgcolor=255,
viz=False, lclick_cb=None): viz=False, lclick_cb=None):
""" """
Generate patches. Stacked patches into grid, to produce visualizations like the following:
:param patch_list: bhw or bhwc images in [0,255]
:param border: defaults to 0.1 * min(image_width, image_height) .. image:: https://github.com/ppwwyyxx/tensorpack/raw/master/examples/GAN/demo/CelebA-samples.jpg
:param nr_row, nr_col: rows and cols of the grid
:parma max_width, max_height: if nr_row/col are not given, use this to infer the rows and cols Args:
:param shuffle: shuffle the images patch_list(np.ndarray): NHW or NHWC images in [0,255].
:param bgcolor: background color nr_row(int), nr_col(int): rows and cols of the grid.
:param viz: use interactive imshow to visualize the results border(int): border length between images.
:param lclick_cb: only useful when viz=True. a callback(patch, idx) Defaults to ``0.1 * min(image_w, image_h)``.
max_width(int), max_height(int): Maximum allowed size of the
visualization image. If ``nr_row/nr_col`` are not given, will use this to infer the rows and cols.
shuffle(bool): shuffle the images inside ``patch_list``.
bgcolor(int): background color in [0, 255].
viz(bool): whether to use :func:`interactive_imshow` to visualize the results.
lclick_cb: A callback function to get called when ``viz==True`` and an
image get clicked. It takes the image patch and its index in
``patch_list`` as arguments. (The index is invalid when
``shuffle==True``.)
Yields:
np.ndarray: the visualization image.
""" """
# setup parameters # setup parameters
patch_list = np.asarray(patch_list) patch_list = np.asarray(patch_list)
...@@ -156,16 +183,22 @@ def dump_dataflow_images(df, index=0, batched=True, ...@@ -156,16 +183,22 @@ def dump_dataflow_images(df, index=0, batched=True,
scale=1, resize=None, viz=None, scale=1, resize=None, viz=None,
flipRGB=False, exit_after=True): flipRGB=False, exit_after=True):
""" """
:param df: a DataFlow Dump or visualize images of a :class:`DataFlow`.
:param index: the index of the image component
:param batched: whether the component contains batched images or not Args:
:param number: how many datapoint to take from the DataFlow df (DataFlow): the DataFlow.
:param output_dir: output directory to save images, default to not save. index (int): the index of the image component.
:param scale: scale the value, usually either 1 or 255 batched (bool): whether the component contains batched images (NHW or
:param resize: (h, w) or Nne, resize the images NHWC) or not (HW or HWC).
:param viz: (h, w) or None, visualize the images in grid with imshow number (int): how many datapoint to take from the DataFlow.
:param flipRGB: apply a RGB<->BGR conversion or not output_dir (str): output directory to save images, default to not save.
:param exit_after: exit the process after this function scale (float): scale the value, usually either 1 or 255.
resize (tuple or None): tuple of (h, w) to resize the images to.
viz (tuple or None): tuple of (h, w) determining the grid size to use
with :func:`build_patch_list` for visualization. No visualization will happen by
default.
flipRGB (bool): apply a RGB<->BGR conversion or not.
exit_after (bool): ``sys.exit()`` after this function.
""" """
if output_dir: if output_dir:
mkdir_p(output_dir) mkdir_p(output_dir)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment