Commit 8f797c63 authored by Yuxin Wu's avatar Yuxin Wu

api docs for utils/

parent 1b73d9cc
......@@ -10,3 +10,4 @@ cd "$PROG_DIR"
make clean
#sphinx-apidoc -o modules ../tensorpack -f -d 10
make html
xdotool windowactivate --sync $(xdotool search --desktop 0 Chromium) key F5
......@@ -6,3 +6,4 @@ msgpack-python
msgpack-numpy
pyzmq
subprocess32; python_version < '3.0'
functools32; python_version < '3.0'
......@@ -4,20 +4,23 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import operator
import inspect
import six
import functools
import collections
from . import logger
if six.PY2:
import functools32 as functools
else:
import functools
__all__ = ['map_arg', 'memoized', 'shape2d', 'memoized_ignoreargs', 'log_once']
def map_arg(**maps):
"""
Apply a mapping on certains argument before calling original function.
maps: {key: map_func}
Apply a mapping on certains argument before calling the original function.
Args:
maps (dict): {key: map_func}
"""
def deco(func):
@functools.wraps(func)
......@@ -31,45 +34,18 @@ def map_arg(**maps):
return deco
class memoized(object):
'''Decorator. Caches a function's return value each time it is called.
If called later with the same arguments, the cached value is returned
(not reevaluated).
'''
def __init__(self, func):
self.func = func
self.cache = {}
def __call__(self, *args, **kwargs):
kwlist = tuple(sorted(list(kwargs), key=operator.itemgetter(0)))
if not isinstance(args, collections.Hashable) or \
not isinstance(kwlist, collections.Hashable):
logger.warn("Arguments to memoized call is unhashable!")
# uncacheable. a list, for instance.
# better to not cache than blow up.
return self.func(*args, **kwargs)
key = (args, kwlist)
if key in self.cache:
return self.cache[key]
else:
value = self.func(*args, **kwargs)
self.cache[key] = value
return value
def __repr__(self):
'''Return the function's docstring.'''
return self.func.__doc__
def __get__(self, obj, objtype):
'''Support instance methods.'''
return functools.partial(self.__call__, obj)
memoized = functools.lru_cache(maxsize=None)
""" Equivalent to :func:`functools.lru_cache` """
_MEMOIZED_NOARGS = {}
def memoized_ignoreargs(func):
"""
A decorator. It performs memoization ignoring the arguments used to call
the function.
"""
hash(func) # make sure it is hashable. TODO is it necessary?
def wrapper(*args, **kwargs):
......@@ -93,7 +69,13 @@ def memoized_ignoreargs(func):
def shape2d(a):
"""
Ensure a 2D shape.
Args:
a: a int or tuple/list of length 2
Returns:
list: of length 2. if ``a`` is a int, return ``[a, a]``.
"""
if type(a) == int:
return [a, a]
......@@ -105,4 +87,12 @@ def shape2d(a):
@memoized
def log_once(message, func):
"""
Log certain message only once. Call this function more than one times with
the same message will result in no-op.
Args:
message(str): message to log
func(str): the name of the logger method. e.g. "info", "warn", "error".
"""
getattr(logger, func)(message)
......@@ -36,15 +36,18 @@ class StoppableThread(threading.Thread):
self._stop_evt = threading.Event()
def stop(self):
""" stop the thread"""
""" Stop the thread"""
self._stop_evt.set()
def stopped(self):
""" check whether the thread is stopped or not"""
"""
Returns:
bool: whether the thread is stopped or not
"""
return self._stop_evt.isSet()
def queue_put_stoppable(self, q, obj):
""" put obj to queue, but will give up if the thread is stopped"""
""" Put obj to queue, but will give up when the thread is stopped"""
while not self.stopped():
try:
q.put(obj, timeout=5)
......@@ -53,7 +56,7 @@ class StoppableThread(threading.Thread):
pass
def queue_get_stoppable(self, q):
""" take obj from queue, but will give up if the thread is stopped"""
""" Take obj from queue, but will give up when the thread is stopped"""
while not self.stopped():
try:
return q.get(timeout=5)
......@@ -66,7 +69,8 @@ class LoopThread(StoppableThread):
def __init__(self, func, pausable=True):
"""
:param func: the function to run
Args:
func: the function to run
"""
super(LoopThread, self).__init__()
self._func = func
......@@ -83,10 +87,12 @@ class LoopThread(StoppableThread):
self._func()
def pause(self):
""" Pause the loop """
assert self._pausable
self._lock.acquire()
def resume(self):
""" Resume the loop """
assert self._pausable
self._lock.release()
......@@ -97,6 +103,12 @@ class DIE(object):
def ensure_proc_terminate(proc):
"""
Make sure processes terminate when main process exit.
Args:
proc (multiprocessing.Process or list)
"""
if isinstance(proc, list):
for p in proc:
ensure_proc_terminate(p)
......@@ -117,12 +129,21 @@ def ensure_proc_terminate(proc):
@contextmanager
def mask_sigint():
"""
Returns:
a context where ``SIGINT`` is ignored.
"""
sigint_handler = signal.signal(signal.SIGINT, signal.SIG_IGN)
yield
signal.signal(signal.SIGINT, sigint_handler)
def start_proc_mask_signal(proc):
""" Start process(es) with SIGINT ignored.
Args:
proc: (multiprocessing.Process or list)
"""
if not isinstance(proc, list):
proc = [proc]
......@@ -132,6 +153,12 @@ def start_proc_mask_signal(proc):
def subproc_call(cmd, timeout=None):
"""
Execute a command with timeout, and return both STDOUT/STDERR.
Args:
cmd(str): the command to execute.
timeout(float): timeout in seconds.
"""
try:
output = subprocess.check_output(
cmd, stderr=subprocess.STDOUT,
......@@ -147,15 +174,28 @@ def subproc_call(cmd, timeout=None):
class OrderedContainer(object):
"""
Like a priority queue, but will always wait for item with index (x+1) before producing (x+2).
Like a queue, but will always wait to receive item with rank
(x+1) and produce (x+1) before producing (x+2).
Warning:
It is not thread-safe.
"""
def __init__(self, start=0):
"""
Args:
start(int): the starting rank.
"""
self.ranks = []
self.data = []
self.wait_for = start
def put(self, rank, val):
"""
Args:
rank(int): rank of th element. All elements must have different ranks.
val: an object
"""
idx = bisect.bisect(self.ranks, rank)
self.ranks.insert(idx, rank)
self.data.insert(idx, val)
......@@ -183,9 +223,11 @@ class OrderedResultGatherProc(multiprocessing.Process):
def __init__(self, data_queue, nr_producer, start=0):
"""
:param data_queue: a multiprocessing.Queue to produce input dp
:param nr_producer: number of producer processes. Will terminate after receiving this many of DIE sentinel.
:param start: the first task index
Args:
data_queue(multiprocessing.Queue): a queue which contains datapoints.
nr_producer(int): number of producer processes. This process will
terminate after receiving this many of :class:`DIE` sentinel.
start(int): the rank of the first object
"""
super(OrderedResultGatherProc, self).__init__()
self.data_queue = data_queue
......
......@@ -9,6 +9,7 @@ __all__ = ['enable_call_trace']
def enable_call_trace():
""" Enable trace for calls to any function. """
def tracer(frame, event, arg):
if event == 'call':
co = frame.f_code
......
......@@ -29,12 +29,12 @@ class Discretizer1D(Discretizer):
class UniformDiscretizer1D(Discretizer1D):
def __init__(self, minv, maxv, spacing):
"""
:params minv: minimum value of the first bin
:params maxv: maximum value of the last bin
:param spacing: width of a bin
Args:
minv(float): minimum value of the first bin
maxv(float): maximum value of the last bin
spacing(float): width of a bin
"""
self.minv = float(minv)
self.maxv = float(maxv)
......@@ -42,9 +42,19 @@ class UniformDiscretizer1D(Discretizer1D):
self.nr_bin = int(np.ceil((self.maxv - self.minv) / self.spacing))
def get_nr_bin(self):
"""
Returns:
int: number of bins
"""
return self.nr_bin
def get_bin(self, v):
"""
Args:
v(float): value
Returns:
int: the bin index for value ``v``.
"""
if v < self.minv:
log_once("UniformDiscretizer1D: value smaller than min!", 'warn')
return 0
......@@ -56,10 +66,23 @@ class UniformDiscretizer1D(Discretizer1D):
0, self.nr_bin - 1))
def get_bin_center(self, bin_id):
"""
Args:
bin_id(int)
Returns:
float: the center of this bin.
"""
return self.minv + self.spacing * (bin_id + 0.5)
def get_distribution(self, v, smooth_factor=0.05, smooth_radius=2):
""" return a smoothed one-hot distribution of the sample v.
"""
Args:
v(float): a sample
smooth_factor(float):
smooth_radius(int):
Returns:
numpy.ndarray: array of length ``self.nr_bin``, a smoothed one-hot
distribution centered at the bin of sample ``v``.
"""
b = self.get_bin(v)
ret = np.zeros((self.nr_bin, ), dtype='float32')
......@@ -78,10 +101,11 @@ class UniformDiscretizer1D(Discretizer1D):
class UniformDiscretizerND(Discretizer):
""" A combination of several :class:`UniformDiscretizer1D`. """
def __init__(self, *min_max_spacing):
"""
:params min_max_spacing: (minv, maxv, spacing) for each dimension
Args:
min_max_spacing: (minv, maxv, spacing) for each dimension
"""
self.n = len(min_max_spacing)
self.discretizers = [UniformDiscretizer1D(*k) for k in min_max_spacing]
......
......@@ -13,7 +13,11 @@ __all__ = ['mkdir_p', 'download', 'recursive_walk']
def mkdir_p(dirname):
""" make a dir recursively, but do nothing if the dir exists"""
""" Make a dir recursively, but do nothing if the dir exists
Args:
dirname(str):
"""
assert dirname is not None
if dirname == '' or os.path.isdir(dirname):
return
......@@ -25,6 +29,10 @@ def mkdir_p(dirname):
def download(url, dir):
"""
Download URL to a directory. Will figure out the filename automatically
from URL.
"""
mkdir_p(dir)
fname = url.split('/')[-1]
fpath = os.path.join(dir, fname)
......@@ -50,6 +58,10 @@ def download(url, dir):
def recursive_walk(rootdir):
"""
Yields:
str: All files in rootdir, recursively.
"""
for r, dirs, files in os.walk(rootdir):
for f in files:
yield os.path.join(r, f)
......
......@@ -20,8 +20,10 @@ globalns = NS()
def use_global_argument(args):
"""
Add the content of argparse.Namespace to globalns
:param args: Argument
Add the content of :class:`argparse.Namespace` to globalns.
Args:
args (argparse.Namespace): arguments
"""
assert isinstance(args, argparse.Namespace), type(args)
for k, v in six.iteritems(vars(args)):
......
......@@ -10,6 +10,10 @@ __all__ = ['change_gpu', 'get_nr_gpu', 'get_gpus']
def change_gpu(val):
"""
Returns:
a context where ``CUDA_VISIBLE_DEVICES=val``.
"""
val = str(val)
if val == '-1':
val = ''
......@@ -17,13 +21,20 @@ def change_gpu(val):
def get_nr_gpu():
"""
Returns:
int: the number of GPU from ``CUDA_VISIBLE_DEVICES``.
"""
env = os.environ.get('CUDA_VISIBLE_DEVICES', None)
assert env is not None, 'gpu not set!' # TODO
return len(env.split(','))
def get_gpus():
""" return a list of GPU physical id"""
"""
Returns:
list: a list of int of GPU id.
"""
env = os.environ.get('CUDA_VISIBLE_DEVICES', None)
assert env is not None, 'gpu not set!' # TODO
return map(int, env.strip().split(','))
......@@ -94,7 +94,13 @@ class CaffeLayerProcessor(object):
def load_caffe(model_desc, model_file):
"""
:return: a dict of params
Load a caffe model. You must be able to ``import caffe`` to use this
function.
Args:
model_desc (str): path to caffe model description file (.prototxt).
model_file (str): path to caffe model parameter file (.caffemodel).
Returns:
dict: the parameters.
"""
with change_env('GLOG_minloglevel', '2'):
import caffe
......@@ -107,6 +113,11 @@ def load_caffe(model_desc, model_file):
def get_caffe_pb():
"""
Get caffe protobuf.
Returns:
The imported caffe protobuf module.
"""
dir = get_dataset_path('caffe')
caffe_pb_file = os.path.join(dir, 'caffe_pb2.py')
if not os.path.isfile(caffe_pb_file):
......
......@@ -73,8 +73,10 @@ def _set_file(path):
def set_logger_dir(dirname, action=None):
"""
Set the directory for global logging.
:param dirname: log directory
:param action: an action (k/b/d/n) to be performed. Will ask user by default.
Args:
dirname(str): log directory
action(str): an action of ("k","b","d","n") to be performed. Will ask user by default.
"""
global LOG_FILE, LOG_DIR
if os.path.isdir(dirname):
......@@ -108,13 +110,14 @@ If you're resuming from a previous run you can choose to keep it.""")
def disable_logger():
""" disable all logging ability from this moment"""
""" Disable all logging ability from this moment"""
for func in _LOGGING_METHOD:
globals()[func] = lambda x: None
def auto_set_dir(action=None, overwrite=False):
""" set log directory to a subdir inside 'train_log', with the name being
"""
Set log directory to a subdir inside "train_log", with the name being
the main python file currently running"""
if LOG_DIR is not None and not overwrite:
# dir already set
......@@ -128,4 +131,5 @@ def auto_set_dir(action=None, overwrite=False):
def warn_dependency(name, dependencies):
""" Print warning about an import failure due to missing dependencies. """
warn("Failed to import '{}', {} won't be available'".format(dependencies, name)) # noqa: F821
......@@ -9,8 +9,13 @@ __all__ = ['LookUpTable']
class LookUpTable(object):
""" Maintain mapping from index to objects. """
def __init__(self, objlist):
"""
Args:
objlist(list): list of objects
"""
self.idx2obj = dict(enumerate(objlist))
self.obj2idx = {v: k for k, v in six.iteritems(self.idx2obj)}
......
......@@ -8,8 +8,9 @@ import numpy as np
class Rect(object):
"""
A Rectangle.
Note that x1 = x+w, not x+w-1 or something
A rectangle class.
Note that x1 = x + w, not x+w-1 or something else.
"""
__slots__ = ['x', 'y', 'w', 'h']
......@@ -51,9 +52,11 @@ class Rect(object):
def validate(self, shape=None):
"""
Is a valid bounding box within this shape
:param shape: [h, w]
:returns: boolean
Check that this rect is a valid bounding box within this shape.
Args:
shape: [h, w]
Returns:
bool
"""
if min(self.x, self.y) < 0:
return False
......
......@@ -11,8 +11,18 @@ __all__ = ['loads', 'dumps']
def dumps(obj):
"""
Serialize an object.
Returns:
str
"""
return msgpack.dumps(obj, use_bin_type=True)
def loads(buf):
"""
Args:
buf (str): serialized object.
"""
return msgpack.loads(buf)
......@@ -3,7 +3,7 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import numpy as np
__all__ = ['StatCounter', 'Accuracy', 'BinaryStatistics', 'RatioCounter',
__all__ = ['StatCounter', 'BinaryStatistics', 'RatioCounter', 'Accuracy',
'OnlineMoments']
......@@ -14,6 +14,10 @@ class StatCounter(object):
self.reset()
def feed(self, v):
"""
Args:
v(float or np.ndarray): has to be the same shape between calls.
"""
self._values.append(v)
def reset(self):
......@@ -40,7 +44,7 @@ class StatCounter(object):
class RatioCounter(object):
""" A counter to count ratio of something"""
""" A counter to count ratio of something. """
def __init__(self):
self.reset()
......@@ -50,6 +54,11 @@ class RatioCounter(object):
self._cnt = 0
def feed(self, cnt, tot=1):
"""
Args:
cnt(int): the count of some event of interest.
tot(int): the total number of events.
"""
self._tot += tot
self._cnt += cnt
......@@ -61,6 +70,10 @@ class RatioCounter(object):
@property
def count(self):
"""
Returns:
int: the total
"""
return self._tot
......@@ -90,8 +103,9 @@ class BinaryStatistics(object):
def feed(self, pred, label):
"""
:param pred: 0/1 np array
:param label: 0/1 np array of the same size
Args:
pred (np.ndarray): binary array.
label (np.ndarray): binary array of the same size.
"""
assert pred.shape == label.shape, "{} != {}".format(pred.shape, label.shape)
self.nr_pos += (label == 1).sum()
......@@ -127,7 +141,8 @@ class BinaryStatistics(object):
class OnlineMoments(object):
"""Compute 1st and 2nd moments online
"""Compute 1st and 2nd moments online (to avoid storing all elements).
See algorithm at: https://www.wikiwand.com/en/Algorithms_for_calculating_variance#/Online_algorithm
"""
......@@ -137,6 +152,10 @@ class OnlineMoments(object):
self._n = 0
def feed(self, x):
"""
Args:
x (float or np.ndarray): must have the same shape.
"""
self._n += 1
delta = x - self._mean
self._mean += delta * (1.0 / self._n)
......
......@@ -17,30 +17,27 @@ __all__ = ['total_timer', 'timed_operation',
'print_total_timer', 'IterSpeedCounter']
class IterSpeedCounter(object):
""" To count how often some code gets reached"""
@contextmanager
def timed_operation(msg, log_start=False):
"""
Surround a context with a timer.
def __init__(self, print_every, name=None):
self.cnt = 0
self.print_every = int(print_every)
self.name = name if name else 'IterSpeed'
Args:
msg(str): the log to print.
log_start(bool): whether to print also at the beginning.
def reset(self):
self.start = time.time()
Example:
.. code-block:: python
def __call__(self):
if self.cnt == 0:
self.reset()
self.cnt += 1
if self.cnt % self.print_every != 0:
return
t = time.time() - self.start
logger.info("{}: {:.2f} sec, {} times, {:.3g} sec/time".format(
self.name, t, self.cnt, t / self.cnt))
with timed_operation('Good Stuff'):
time.sleep(1)
Will print:
@contextmanager
def timed_operation(msg, log_start=False):
.. code-block:: python
Good stuff finished, time:1sec.
"""
if log_start:
logger.info('Start {} ...'.format(msg))
start = time.time()
......@@ -54,6 +51,7 @@ _TOTAL_TIMER_DATA = defaultdict(StatCounter)
@contextmanager
def total_timer(msg):
""" A context which add the time spent inside to TotalTimer. """
start = time.time()
yield
t = time.time() - start
......@@ -61,6 +59,10 @@ def total_timer(msg):
def print_total_timer():
"""
Print the content of the TotalTimer, if it's not empty. This function will automatically get
called when program exits.
"""
if len(_TOTAL_TIMER_DATA) == 0:
return
for k, v in six.iteritems(_TOTAL_TIMER_DATA):
......@@ -69,3 +71,41 @@ def print_total_timer():
atexit.register(print_total_timer)
class IterSpeedCounter(object):
""" Test how often some code gets reached.
Example:
Print the speed of the iteration every 100 times.
.. code-block:: python
speed = IterSpeedCounter(100)
for k in range(1000):
# do something
speed()
"""
def __init__(self, print_every, name=None):
"""
Args:
print_every(int): interval to print.
name(str): name to used when print.
"""
self.cnt = 0
self.print_every = int(print_every)
self.name = name if name else 'IterSpeed'
def reset(self):
self.start = time.time()
def __call__(self):
if self.cnt == 0:
self.reset()
self.cnt += 1
if self.cnt % self.print_every != 0:
return
t = time.time() - self.start
logger.info("{}: {:.2f} sec, {} times, {:.3g} sec/time".format(
self.name, t, self.cnt, t / self.cnt))
......@@ -22,6 +22,14 @@ __all__ = ['change_env',
@contextmanager
def change_env(name, val):
"""
Args:
name(str), val(str):
Returns:
a context where the environment variable ``name`` being set to
``val``. It will be set back after the context exits.
"""
oldval = os.environ.get(name, None)
os.environ[name] = val
yield
......@@ -32,7 +40,14 @@ def change_env(name, val):
def get_rng(obj=None):
""" obj: some object to use to generate random seed"""
"""
Get a good RNG.
Args:
obj: some object to use to generate random seed.
Returns:
np.random.RandomState: the RNG.
"""
seed = (id(obj) + os.getpid() +
int(datetime.now().strftime("%Y%m%d%H%M%S%f"))) % 4294967295
return np.random.RandomState(seed)
......@@ -43,10 +58,18 @@ _EXECUTE_HISTORY = set()
def execute_only_once():
"""
when called with:
Each called in the code to this function is guranteed to return True the
first time and False afterwards.
Returns:
bool: whether this is the first time this function gets called from
this line of code.
Example:
.. code-block:: python
if execute_only_once():
# do something
The body is guranteed to be executed only the first time.
# do something only once
"""
f = inspect.currentframe().f_back
ident = (f.f_code.co_filename, f.f_lineno)
......@@ -57,6 +80,15 @@ def execute_only_once():
def get_dataset_path(*args):
"""
Get the path to some dataset under ``$TENSORPACK_DATASET``.
Args:
args: strings to be joined to form path.
Returns:
str: path to the dataset.
"""
d = os.environ.get('TENSORPACK_DATASET', None)
if d is None:
d = os.path.abspath(os.path.join(
......@@ -69,6 +101,14 @@ def get_dataset_path(*args):
def get_tqdm_kwargs(**kwargs):
"""
Return default arguments to be used with tqdm.
Args:
kwargs: extra arguments to be used.
Returns:
dict:
"""
default = dict(
smoothing=0.5,
dynamic_ncols=True,
......@@ -85,9 +125,15 @@ def get_tqdm_kwargs(**kwargs):
def get_tqdm(**kwargs):
""" Similar to :func:`get_tqdm_kwargs`, but returns the tqdm object
directly. """
return tqdm(**get_tqdm_kwargs(**kwargs))
def building_rtfd():
"""
Returns:
bool: if tensorpack is imported to generate docs now.
"""
return os.environ.get('READTHEDOCS') == 'True' \
or os.environ.get('TENSORPACK_DOC_BUILDING')
......@@ -16,11 +16,12 @@ try:
except ImportError:
pass
__all__ = ['pyplot2img', 'build_patch_list', 'pyplot_viz',
'dump_dataflow_images', 'interactive_imshow']
__all__ = ['pyplot2img', 'interactive_imshow', 'build_patch_list',
'pyplot_viz', 'dump_dataflow_images']
def pyplot2img(plt):
""" Convert a pyplot instance to image """
buf = io.BytesIO()
plt.axis('off')
plt.savefig(buf, format='png', bbox_inches='tight', pad_inches=0)
......@@ -32,8 +33,13 @@ def pyplot2img(plt):
def pyplot_viz(img, shape=None):
""" use pyplot to visualize the image
Note: this is quite slow. and the returned image will have a border
""" Use pyplot to visualize the image. e.g., when input is grayscale, the result
will automatically have a colormap.
Returns:
np.ndarray: an image.
Note:
this is quite slow. and the returned image will have a border
"""
plt.clf()
plt.axes([0, 0, 1, 1])
......@@ -54,8 +60,17 @@ def minnone(x, y):
def interactive_imshow(img, lclick_cb=None, rclick_cb=None, **kwargs):
"""
:param lclick_cb: a callback(img, x, y) for left click
:param kwargs: can be {key_cb_a ... key_cb_z: callback(img)}
Args:
img (np.ndarray): an image to show.
lclick_cb: a callback func(img, x, y) for left click event.
kwargs: can be {key_cb_a: callback_img, key_cb_b: callback_img}, to
specify a callback func(img) for keypress.
Some existing keypress event handler:
* q: destroy the current window
* x: execute ``sys.exit()``
* s: save image to "out.png"
"""
name = 'random_window_name'
cv2.imshow(name, img)
......@@ -84,15 +99,27 @@ def build_patch_list(patch_list,
shuffle=False, bgcolor=255,
viz=False, lclick_cb=None):
"""
Generate patches.
:param patch_list: bhw or bhwc images in [0,255]
:param border: defaults to 0.1 * min(image_width, image_height)
:param nr_row, nr_col: rows and cols of the grid
:parma max_width, max_height: if nr_row/col are not given, use this to infer the rows and cols
:param shuffle: shuffle the images
:param bgcolor: background color
:param viz: use interactive imshow to visualize the results
:param lclick_cb: only useful when viz=True. a callback(patch, idx)
Stacked patches into grid, to produce visualizations like the following:
.. image:: https://github.com/ppwwyyxx/tensorpack/raw/master/examples/GAN/demo/CelebA-samples.jpg
Args:
patch_list(np.ndarray): NHW or NHWC images in [0,255].
nr_row(int), nr_col(int): rows and cols of the grid.
border(int): border length between images.
Defaults to ``0.1 * min(image_w, image_h)``.
max_width(int), max_height(int): Maximum allowed size of the
visualization image. If ``nr_row/nr_col`` are not given, will use this to infer the rows and cols.
shuffle(bool): shuffle the images inside ``patch_list``.
bgcolor(int): background color in [0, 255].
viz(bool): whether to use :func:`interactive_imshow` to visualize the results.
lclick_cb: A callback function to get called when ``viz==True`` and an
image get clicked. It takes the image patch and its index in
``patch_list`` as arguments. (The index is invalid when
``shuffle==True``.)
Yields:
np.ndarray: the visualization image.
"""
# setup parameters
patch_list = np.asarray(patch_list)
......@@ -156,16 +183,22 @@ def dump_dataflow_images(df, index=0, batched=True,
scale=1, resize=None, viz=None,
flipRGB=False, exit_after=True):
"""
:param df: a DataFlow
:param index: the index of the image component
:param batched: whether the component contains batched images or not
:param number: how many datapoint to take from the DataFlow
:param output_dir: output directory to save images, default to not save.
:param scale: scale the value, usually either 1 or 255
:param resize: (h, w) or Nne, resize the images
:param viz: (h, w) or None, visualize the images in grid with imshow
:param flipRGB: apply a RGB<->BGR conversion or not
:param exit_after: exit the process after this function
Dump or visualize images of a :class:`DataFlow`.
Args:
df (DataFlow): the DataFlow.
index (int): the index of the image component.
batched (bool): whether the component contains batched images (NHW or
NHWC) or not (HW or HWC).
number (int): how many datapoint to take from the DataFlow.
output_dir (str): output directory to save images, default to not save.
scale (float): scale the value, usually either 1 or 255.
resize (tuple or None): tuple of (h, w) to resize the images to.
viz (tuple or None): tuple of (h, w) determining the grid size to use
with :func:`build_patch_list` for visualization. No visualization will happen by
default.
flipRGB (bool): apply a RGB<->BGR conversion or not.
exit_after (bool): ``sys.exit()`` after this function.
"""
if output_dir:
mkdir_p(output_dir)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment