Commit 8f797c63 authored by Yuxin Wu's avatar Yuxin Wu

api docs for utils/

parent 1b73d9cc
......@@ -10,3 +10,4 @@ cd "$PROG_DIR"
make clean
#sphinx-apidoc -o modules ../tensorpack -f -d 10
make html
xdotool windowactivate --sync $(xdotool search --desktop 0 Chromium) key F5
......@@ -6,3 +6,4 @@ msgpack-python
msgpack-numpy
pyzmq
subprocess32; python_version < '3.0'
functools32; python_version < '3.0'
......@@ -4,20 +4,23 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import operator
import inspect
import six
import functools
import collections
from . import logger
if six.PY2:
import functools32 as functools
else:
import functools
__all__ = ['map_arg', 'memoized', 'shape2d', 'memoized_ignoreargs', 'log_once']
def map_arg(**maps):
"""
Apply a mapping on certains argument before calling original function.
maps: {key: map_func}
Apply a mapping on certains argument before calling the original function.
Args:
maps (dict): {key: map_func}
"""
def deco(func):
@functools.wraps(func)
......@@ -31,45 +34,18 @@ def map_arg(**maps):
return deco
class memoized(object):
'''Decorator. Caches a function's return value each time it is called.
If called later with the same arguments, the cached value is returned
(not reevaluated).
'''
def __init__(self, func):
self.func = func
self.cache = {}
def __call__(self, *args, **kwargs):
kwlist = tuple(sorted(list(kwargs), key=operator.itemgetter(0)))
if not isinstance(args, collections.Hashable) or \
not isinstance(kwlist, collections.Hashable):
logger.warn("Arguments to memoized call is unhashable!")
# uncacheable. a list, for instance.
# better to not cache than blow up.
return self.func(*args, **kwargs)
key = (args, kwlist)
if key in self.cache:
return self.cache[key]
else:
value = self.func(*args, **kwargs)
self.cache[key] = value
return value
def __repr__(self):
'''Return the function's docstring.'''
return self.func.__doc__
def __get__(self, obj, objtype):
'''Support instance methods.'''
return functools.partial(self.__call__, obj)
memoized = functools.lru_cache(maxsize=None)
""" Equivalent to :func:`functools.lru_cache` """
_MEMOIZED_NOARGS = {}
def memoized_ignoreargs(func):
"""
A decorator. It performs memoization ignoring the arguments used to call
the function.
"""
hash(func) # make sure it is hashable. TODO is it necessary?
def wrapper(*args, **kwargs):
......@@ -93,7 +69,13 @@ def memoized_ignoreargs(func):
def shape2d(a):
"""
a: a int or tuple/list of length 2
Ensure a 2D shape.
Args:
a: a int or tuple/list of length 2
Returns:
list: of length 2. if ``a`` is a int, return ``[a, a]``.
"""
if type(a) == int:
return [a, a]
......@@ -105,4 +87,12 @@ def shape2d(a):
@memoized
def log_once(message, func):
"""
Log certain message only once. Call this function more than one times with
the same message will result in no-op.
Args:
message(str): message to log
func(str): the name of the logger method. e.g. "info", "warn", "error".
"""
getattr(logger, func)(message)
......@@ -36,15 +36,18 @@ class StoppableThread(threading.Thread):
self._stop_evt = threading.Event()
def stop(self):
""" stop the thread"""
""" Stop the thread"""
self._stop_evt.set()
def stopped(self):
""" check whether the thread is stopped or not"""
"""
Returns:
bool: whether the thread is stopped or not
"""
return self._stop_evt.isSet()
def queue_put_stoppable(self, q, obj):
""" put obj to queue, but will give up if the thread is stopped"""
""" Put obj to queue, but will give up when the thread is stopped"""
while not self.stopped():
try:
q.put(obj, timeout=5)
......@@ -53,7 +56,7 @@ class StoppableThread(threading.Thread):
pass
def queue_get_stoppable(self, q):
""" take obj from queue, but will give up if the thread is stopped"""
""" Take obj from queue, but will give up when the thread is stopped"""
while not self.stopped():
try:
return q.get(timeout=5)
......@@ -66,7 +69,8 @@ class LoopThread(StoppableThread):
def __init__(self, func, pausable=True):
"""
:param func: the function to run
Args:
func: the function to run
"""
super(LoopThread, self).__init__()
self._func = func
......@@ -83,10 +87,12 @@ class LoopThread(StoppableThread):
self._func()
def pause(self):
""" Pause the loop """
assert self._pausable
self._lock.acquire()
def resume(self):
""" Resume the loop """
assert self._pausable
self._lock.release()
......@@ -97,6 +103,12 @@ class DIE(object):
def ensure_proc_terminate(proc):
"""
Make sure processes terminate when main process exit.
Args:
proc (multiprocessing.Process or list)
"""
if isinstance(proc, list):
for p in proc:
ensure_proc_terminate(p)
......@@ -117,12 +129,21 @@ def ensure_proc_terminate(proc):
@contextmanager
def mask_sigint():
"""
Returns:
a context where ``SIGINT`` is ignored.
"""
sigint_handler = signal.signal(signal.SIGINT, signal.SIG_IGN)
yield
signal.signal(signal.SIGINT, sigint_handler)
def start_proc_mask_signal(proc):
""" Start process(es) with SIGINT ignored.
Args:
proc: (multiprocessing.Process or list)
"""
if not isinstance(proc, list):
proc = [proc]
......@@ -132,6 +153,12 @@ def start_proc_mask_signal(proc):
def subproc_call(cmd, timeout=None):
"""
Execute a command with timeout, and return both STDOUT/STDERR.
Args:
cmd(str): the command to execute.
timeout(float): timeout in seconds.
"""
try:
output = subprocess.check_output(
cmd, stderr=subprocess.STDOUT,
......@@ -147,15 +174,28 @@ def subproc_call(cmd, timeout=None):
class OrderedContainer(object):
"""
Like a priority queue, but will always wait for item with index (x+1) before producing (x+2).
Like a queue, but will always wait to receive item with rank
(x+1) and produce (x+1) before producing (x+2).
Warning:
It is not thread-safe.
"""
def __init__(self, start=0):
"""
Args:
start(int): the starting rank.
"""
self.ranks = []
self.data = []
self.wait_for = start
def put(self, rank, val):
"""
Args:
rank(int): rank of th element. All elements must have different ranks.
val: an object
"""
idx = bisect.bisect(self.ranks, rank)
self.ranks.insert(idx, rank)
self.data.insert(idx, val)
......@@ -183,9 +223,11 @@ class OrderedResultGatherProc(multiprocessing.Process):
def __init__(self, data_queue, nr_producer, start=0):
"""
:param data_queue: a multiprocessing.Queue to produce input dp
:param nr_producer: number of producer processes. Will terminate after receiving this many of DIE sentinel.
:param start: the first task index
Args:
data_queue(multiprocessing.Queue): a queue which contains datapoints.
nr_producer(int): number of producer processes. This process will
terminate after receiving this many of :class:`DIE` sentinel.
start(int): the rank of the first object
"""
super(OrderedResultGatherProc, self).__init__()
self.data_queue = data_queue
......
......@@ -9,6 +9,7 @@ __all__ = ['enable_call_trace']
def enable_call_trace():
""" Enable trace for calls to any function. """
def tracer(frame, event, arg):
if event == 'call':
co = frame.f_code
......
......@@ -29,12 +29,12 @@ class Discretizer1D(Discretizer):
class UniformDiscretizer1D(Discretizer1D):
def __init__(self, minv, maxv, spacing):
"""
:params minv: minimum value of the first bin
:params maxv: maximum value of the last bin
:param spacing: width of a bin
Args:
minv(float): minimum value of the first bin
maxv(float): maximum value of the last bin
spacing(float): width of a bin
"""
self.minv = float(minv)
self.maxv = float(maxv)
......@@ -42,9 +42,19 @@ class UniformDiscretizer1D(Discretizer1D):
self.nr_bin = int(np.ceil((self.maxv - self.minv) / self.spacing))
def get_nr_bin(self):
"""
Returns:
int: number of bins
"""
return self.nr_bin
def get_bin(self, v):
"""
Args:
v(float): value
Returns:
int: the bin index for value ``v``.
"""
if v < self.minv:
log_once("UniformDiscretizer1D: value smaller than min!", 'warn')
return 0
......@@ -56,10 +66,23 @@ class UniformDiscretizer1D(Discretizer1D):
0, self.nr_bin - 1))
def get_bin_center(self, bin_id):
"""
Args:
bin_id(int)
Returns:
float: the center of this bin.
"""
return self.minv + self.spacing * (bin_id + 0.5)
def get_distribution(self, v, smooth_factor=0.05, smooth_radius=2):
""" return a smoothed one-hot distribution of the sample v.
"""
Args:
v(float): a sample
smooth_factor(float):
smooth_radius(int):
Returns:
numpy.ndarray: array of length ``self.nr_bin``, a smoothed one-hot
distribution centered at the bin of sample ``v``.
"""
b = self.get_bin(v)
ret = np.zeros((self.nr_bin, ), dtype='float32')
......@@ -78,10 +101,11 @@ class UniformDiscretizer1D(Discretizer1D):
class UniformDiscretizerND(Discretizer):
""" A combination of several :class:`UniformDiscretizer1D`. """
def __init__(self, *min_max_spacing):
"""
:params min_max_spacing: (minv, maxv, spacing) for each dimension
Args:
min_max_spacing: (minv, maxv, spacing) for each dimension
"""
self.n = len(min_max_spacing)
self.discretizers = [UniformDiscretizer1D(*k) for k in min_max_spacing]
......
......@@ -13,7 +13,11 @@ __all__ = ['mkdir_p', 'download', 'recursive_walk']
def mkdir_p(dirname):
""" make a dir recursively, but do nothing if the dir exists"""
""" Make a dir recursively, but do nothing if the dir exists
Args:
dirname(str):
"""
assert dirname is not None
if dirname == '' or os.path.isdir(dirname):
return
......@@ -25,6 +29,10 @@ def mkdir_p(dirname):
def download(url, dir):
"""
Download URL to a directory. Will figure out the filename automatically
from URL.
"""
mkdir_p(dir)
fname = url.split('/')[-1]
fpath = os.path.join(dir, fname)
......@@ -50,6 +58,10 @@ def download(url, dir):
def recursive_walk(rootdir):
"""
Yields:
str: All files in rootdir, recursively.
"""
for r, dirs, files in os.walk(rootdir):
for f in files:
yield os.path.join(r, f)
......
......@@ -20,8 +20,10 @@ globalns = NS()
def use_global_argument(args):
"""
Add the content of argparse.Namespace to globalns
:param args: Argument
Add the content of :class:`argparse.Namespace` to globalns.
Args:
args (argparse.Namespace): arguments
"""
assert isinstance(args, argparse.Namespace), type(args)
for k, v in six.iteritems(vars(args)):
......
......@@ -10,6 +10,10 @@ __all__ = ['change_gpu', 'get_nr_gpu', 'get_gpus']
def change_gpu(val):
"""
Returns:
a context where ``CUDA_VISIBLE_DEVICES=val``.
"""
val = str(val)
if val == '-1':
val = ''
......@@ -17,13 +21,20 @@ def change_gpu(val):
def get_nr_gpu():
"""
Returns:
int: the number of GPU from ``CUDA_VISIBLE_DEVICES``.
"""
env = os.environ.get('CUDA_VISIBLE_DEVICES', None)
assert env is not None, 'gpu not set!' # TODO
return len(env.split(','))
def get_gpus():
""" return a list of GPU physical id"""
"""
Returns:
list: a list of int of GPU id.
"""
env = os.environ.get('CUDA_VISIBLE_DEVICES', None)
assert env is not None, 'gpu not set!' # TODO
return map(int, env.strip().split(','))
......@@ -94,7 +94,13 @@ class CaffeLayerProcessor(object):
def load_caffe(model_desc, model_file):
"""
:return: a dict of params
Load a caffe model. You must be able to ``import caffe`` to use this
function.
Args:
model_desc (str): path to caffe model description file (.prototxt).
model_file (str): path to caffe model parameter file (.caffemodel).
Returns:
dict: the parameters.
"""
with change_env('GLOG_minloglevel', '2'):
import caffe
......@@ -107,6 +113,11 @@ def load_caffe(model_desc, model_file):
def get_caffe_pb():
"""
Get caffe protobuf.
Returns:
The imported caffe protobuf module.
"""
dir = get_dataset_path('caffe')
caffe_pb_file = os.path.join(dir, 'caffe_pb2.py')
if not os.path.isfile(caffe_pb_file):
......
......@@ -73,8 +73,10 @@ def _set_file(path):
def set_logger_dir(dirname, action=None):
"""
Set the directory for global logging.
:param dirname: log directory
:param action: an action (k/b/d/n) to be performed. Will ask user by default.
Args:
dirname(str): log directory
action(str): an action of ("k","b","d","n") to be performed. Will ask user by default.
"""
global LOG_FILE, LOG_DIR
if os.path.isdir(dirname):
......@@ -108,13 +110,14 @@ If you're resuming from a previous run you can choose to keep it.""")
def disable_logger():
""" disable all logging ability from this moment"""
""" Disable all logging ability from this moment"""
for func in _LOGGING_METHOD:
globals()[func] = lambda x: None
def auto_set_dir(action=None, overwrite=False):
""" set log directory to a subdir inside 'train_log', with the name being
"""
Set log directory to a subdir inside "train_log", with the name being
the main python file currently running"""
if LOG_DIR is not None and not overwrite:
# dir already set
......@@ -128,4 +131,5 @@ def auto_set_dir(action=None, overwrite=False):
def warn_dependency(name, dependencies):
""" Print warning about an import failure due to missing dependencies. """
warn("Failed to import '{}', {} won't be available'".format(dependencies, name)) # noqa: F821
......@@ -9,8 +9,13 @@ __all__ = ['LookUpTable']
class LookUpTable(object):
""" Maintain mapping from index to objects. """
def __init__(self, objlist):
"""
Args:
objlist(list): list of objects
"""
self.idx2obj = dict(enumerate(objlist))
self.obj2idx = {v: k for k, v in six.iteritems(self.idx2obj)}
......
......@@ -8,8 +8,9 @@ import numpy as np
class Rect(object):
"""
A Rectangle.
Note that x1 = x+w, not x+w-1 or something
A rectangle class.
Note that x1 = x + w, not x+w-1 or something else.
"""
__slots__ = ['x', 'y', 'w', 'h']
......@@ -51,9 +52,11 @@ class Rect(object):
def validate(self, shape=None):
"""
Is a valid bounding box within this shape
:param shape: [h, w]
:returns: boolean
Check that this rect is a valid bounding box within this shape.
Args:
shape: [h, w]
Returns:
bool
"""
if min(self.x, self.y) < 0:
return False
......
......@@ -11,8 +11,18 @@ __all__ = ['loads', 'dumps']
def dumps(obj):
"""
Serialize an object.
Returns:
str
"""
return msgpack.dumps(obj, use_bin_type=True)
def loads(buf):
"""
Args:
buf (str): serialized object.
"""
return msgpack.loads(buf)
......@@ -3,7 +3,7 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import numpy as np
__all__ = ['StatCounter', 'Accuracy', 'BinaryStatistics', 'RatioCounter',
__all__ = ['StatCounter', 'BinaryStatistics', 'RatioCounter', 'Accuracy',
'OnlineMoments']
......@@ -14,6 +14,10 @@ class StatCounter(object):
self.reset()
def feed(self, v):
"""
Args:
v(float or np.ndarray): has to be the same shape between calls.
"""
self._values.append(v)
def reset(self):
......@@ -40,7 +44,7 @@ class StatCounter(object):
class RatioCounter(object):
""" A counter to count ratio of something"""
""" A counter to count ratio of something. """
def __init__(self):
self.reset()
......@@ -50,6 +54,11 @@ class RatioCounter(object):
self._cnt = 0
def feed(self, cnt, tot=1):
"""
Args:
cnt(int): the count of some event of interest.
tot(int): the total number of events.
"""
self._tot += tot
self._cnt += cnt
......@@ -61,6 +70,10 @@ class RatioCounter(object):
@property
def count(self):
"""
Returns:
int: the total
"""
return self._tot
......@@ -90,8 +103,9 @@ class BinaryStatistics(object):
def feed(self, pred, label):
"""
:param pred: 0/1 np array
:param label: 0/1 np array of the same size
Args:
pred (np.ndarray): binary array.
label (np.ndarray): binary array of the same size.
"""
assert pred.shape == label.shape, "{} != {}".format(pred.shape, label.shape)
self.nr_pos += (label == 1).sum()
......@@ -127,7 +141,8 @@ class BinaryStatistics(object):
class OnlineMoments(object):
"""Compute 1st and 2nd moments online
"""Compute 1st and 2nd moments online (to avoid storing all elements).
See algorithm at: https://www.wikiwand.com/en/Algorithms_for_calculating_variance#/Online_algorithm
"""
......@@ -137,6 +152,10 @@ class OnlineMoments(object):
self._n = 0
def feed(self, x):
"""
Args:
x (float or np.ndarray): must have the same shape.
"""
self._n += 1
delta = x - self._mean
self._mean += delta * (1.0 / self._n)
......
......@@ -17,30 +17,27 @@ __all__ = ['total_timer', 'timed_operation',
'print_total_timer', 'IterSpeedCounter']
class IterSpeedCounter(object):
""" To count how often some code gets reached"""
@contextmanager
def timed_operation(msg, log_start=False):
"""
Surround a context with a timer.
def __init__(self, print_every, name=None):
self.cnt = 0
self.print_every = int(print_every)
self.name = name if name else 'IterSpeed'
Args:
msg(str): the log to print.
log_start(bool): whether to print also at the beginning.
def reset(self):
self.start = time.time()
Example:
.. code-block:: python
def __call__(self):
if self.cnt == 0:
self.reset()
self.cnt += 1
if self.cnt % self.print_every != 0:
return
t = time.time() - self.start
logger.info("{}: {:.2f} sec, {} times, {:.3g} sec/time".format(
self.name, t, self.cnt, t / self.cnt))
with timed_operation('Good Stuff'):
time.sleep(1)
Will print:
@contextmanager
def timed_operation(msg, log_start=False):
.. code-block:: python
Good stuff finished, time:1sec.
"""
if log_start:
logger.info('Start {} ...'.format(msg))
start = time.time()
......@@ -54,6 +51,7 @@ _TOTAL_TIMER_DATA = defaultdict(StatCounter)
@contextmanager
def total_timer(msg):
""" A context which add the time spent inside to TotalTimer. """
start = time.time()
yield
t = time.time() - start
......@@ -61,6 +59,10 @@ def total_timer(msg):
def print_total_timer():
"""
Print the content of the TotalTimer, if it's not empty. This function will automatically get
called when program exits.
"""
if len(_TOTAL_TIMER_DATA) == 0:
return
for k, v in six.iteritems(_TOTAL_TIMER_DATA):
......@@ -69,3 +71,41 @@ def print_total_timer():
atexit.register(print_total_timer)
class IterSpeedCounter(object):
""" Test how often some code gets reached.
Example:
Print the speed of the iteration every 100 times.
.. code-block:: python
speed = IterSpeedCounter(100)
for k in range(1000):
# do something
speed()
"""
def __init__(self, print_every, name=None):
"""
Args:
print_every(int): interval to print.
name(str): name to used when print.
"""
self.cnt = 0
self.print_every = int(print_every)
self.name = name if name else 'IterSpeed'
def reset(self):
self.start = time.time()
def __call__(self):
if self.cnt == 0:
self.reset()
self.cnt += 1
if self.cnt % self.print_every != 0:
return
t = time.time() - self.start
logger.info("{}: {:.2f} sec, {} times, {:.3g} sec/time".format(
self.name, t, self.cnt, t / self.cnt))
......@@ -22,6 +22,14 @@ __all__ = ['change_env',
@contextmanager
def change_env(name, val):
"""
Args:
name(str), val(str):
Returns:
a context where the environment variable ``name`` being set to
``val``. It will be set back after the context exits.
"""
oldval = os.environ.get(name, None)
os.environ[name] = val
yield
......@@ -32,7 +40,14 @@ def change_env(name, val):
def get_rng(obj=None):
""" obj: some object to use to generate random seed"""
"""
Get a good RNG.
Args:
obj: some object to use to generate random seed.
Returns:
np.random.RandomState: the RNG.
"""
seed = (id(obj) + os.getpid() +
int(datetime.now().strftime("%Y%m%d%H%M%S%f"))) % 4294967295
return np.random.RandomState(seed)
......@@ -43,10 +58,18 @@ _EXECUTE_HISTORY = set()
def execute_only_once():
"""
when called with:
if execute_only_once():
# do something
The body is guranteed to be executed only the first time.
Each called in the code to this function is guranteed to return True the
first time and False afterwards.
Returns:
bool: whether this is the first time this function gets called from
this line of code.
Example:
.. code-block:: python
if execute_only_once():
# do something only once
"""
f = inspect.currentframe().f_back
ident = (f.f_code.co_filename, f.f_lineno)
......@@ -57,6 +80,15 @@ def execute_only_once():
def get_dataset_path(*args):
"""
Get the path to some dataset under ``$TENSORPACK_DATASET``.
Args:
args: strings to be joined to form path.
Returns:
str: path to the dataset.
"""
d = os.environ.get('TENSORPACK_DATASET', None)
if d is None:
d = os.path.abspath(os.path.join(
......@@ -69,6 +101,14 @@ def get_dataset_path(*args):
def get_tqdm_kwargs(**kwargs):
"""
Return default arguments to be used with tqdm.
Args:
kwargs: extra arguments to be used.
Returns:
dict:
"""
default = dict(
smoothing=0.5,
dynamic_ncols=True,
......@@ -85,9 +125,15 @@ def get_tqdm_kwargs(**kwargs):
def get_tqdm(**kwargs):
""" Similar to :func:`get_tqdm_kwargs`, but returns the tqdm object
directly. """
return tqdm(**get_tqdm_kwargs(**kwargs))
def building_rtfd():
"""
Returns:
bool: if tensorpack is imported to generate docs now.
"""
return os.environ.get('READTHEDOCS') == 'True' \
or os.environ.get('TENSORPACK_DOC_BUILDING')
......@@ -16,11 +16,12 @@ try:
except ImportError:
pass
__all__ = ['pyplot2img', 'build_patch_list', 'pyplot_viz',
'dump_dataflow_images', 'interactive_imshow']
__all__ = ['pyplot2img', 'interactive_imshow', 'build_patch_list',
'pyplot_viz', 'dump_dataflow_images']
def pyplot2img(plt):
""" Convert a pyplot instance to image """
buf = io.BytesIO()
plt.axis('off')
plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
......@@ -32,8 +33,13 @@ def pyplot2img(plt):
def pyplot_viz(img, shape=None):
""" use pyplot to visualize the image
Note: this is quite slow. and the returned image will have a border
""" Use pyplot to visualize the image. e.g., when input is grayscale, the result
will automatically have a colormap.
Returns:
np.ndarray: an image.
Note:
this is quite slow. and the returned image will have a border
"""
plt.clf()
plt.axes([0, 0, 1, 1])
......@@ -54,8 +60,17 @@ def minnone(x, y):
def interactive_imshow(img, lclick_cb=None, rclick_cb=None, **kwargs):
"""
:param lclick_cb: a callback(img, x, y) for left click
:param kwargs: can be {key_cb_a ... key_cb_z: callback(img)}
Args:
img (np.ndarray): an image to show.
lclick_cb: a callback func(img, x, y) for left click event.
kwargs: can be {key_cb_a: callback_img, key_cb_b: callback_img}, to
specify a callback func(img) for keypress.
Some existing keypress event handler:
* q: destroy the current window
* x: execute ``sys.exit()``
* s: save image to "out.png"
"""
name = 'random_window_name'
cv2.imshow(name, img)
......@@ -84,15 +99,27 @@ def build_patch_list(patch_list,
shuffle=False, bgcolor=255,
viz=False, lclick_cb=None):
"""
Generate patches.
:param patch_list: bhw or bhwc images in [0,255]
:param border: defaults to 0.1 * min(image_width, image_height)
:param nr_row, nr_col: rows and cols of the grid
:parma max_width, max_height: if nr_row/col are not given, use this to infer the rows and cols
:param shuffle: shuffle the images
:param bgcolor: background color
:param viz: use interactive imshow to visualize the results
:param lclick_cb: only useful when viz=True. a callback(patch, idx)
Stacked patches into grid, to produce visualizations like the following:
.. image:: https://github.com/ppwwyyxx/tensorpack/raw/master/examples/GAN/demo/CelebA-samples.jpg
Args:
patch_list(np.ndarray): NHW or NHWC images in [0,255].
nr_row(int), nr_col(int): rows and cols of the grid.
border(int): border length between images.
Defaults to ``0.1 * min(image_w, image_h)``.
max_width(int), max_height(int): Maximum allowed size of the
visualization image. If ``nr_row/nr_col`` are not given, will use this to infer the rows and cols.
shuffle(bool): shuffle the images inside ``patch_list``.
bgcolor(int): background color in [0, 255].
viz(bool): whether to use :func:`interactive_imshow` to visualize the results.
lclick_cb: A callback function to get called when ``viz==True`` and an
image get clicked. It takes the image patch and its index in
``patch_list`` as arguments. (The index is invalid when
``shuffle==True``.)
Yields:
np.ndarray: the visualization image.
"""
# setup parameters
patch_list = np.asarray(patch_list)
......@@ -156,16 +183,22 @@ def dump_dataflow_images(df, index=0, batched=True,
scale=1, resize=None, viz=None,
flipRGB=False, exit_after=True):
"""
:param df: a DataFlow
:param index: the index of the image component
:param batched: whether the component contains batched images or not
:param number: how many datapoint to take from the DataFlow
:param output_dir: output directory to save images, default to not save.
:param scale: scale the value, usually either 1 or 255
:param resize: (h, w) or Nne, resize the images
:param viz: (h, w) or None, visualize the images in grid with imshow
:param flipRGB: apply a RGB<->BGR conversion or not
:param exit_after: exit the process after this function
Dump or visualize images of a :class:`DataFlow`.
Args:
df (DataFlow): the DataFlow.
index (int): the index of the image component.
batched (bool): whether the component contains batched images (NHW or
NHWC) or not (HW or HWC).
number (int): how many datapoint to take from the DataFlow.
output_dir (str): output directory to save images, default to not save.
scale (float): scale the value, usually either 1 or 255.
resize (tuple or None): tuple of (h, w) to resize the images to.
viz (tuple or None): tuple of (h, w) determining the grid size to use
with :func:`build_patch_list` for visualization. No visualization will happen by
default.
flipRGB (bool): apply a RGB<->BGR conversion or not.
exit_after (bool): ``sys.exit()`` after this function.
"""
if output_dir:
mkdir_p(output_dir)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment