Commit 43489024 authored by Yuxin Wu's avatar Yuxin Wu

shape2d in argtools

parent 2c8cd32f
......@@ -108,7 +108,7 @@ def annotate_min_max(data_x, data_y, ax):
x_max, y_max = data_y[0], data_y[0]
x_min, y_min = data_x[0], data_y[0]
for i in xrange(1, len(data_x)):
for i in range(1, len(data_x)):
if data_y[i] > y_max:
y_max = data_y[i]
x_max = data_x[i]
......
......@@ -9,7 +9,7 @@ import numpy as np
from six.moves import range
import xml.etree.ElementTree as ET
from ...utils import logger, get_rng, get_dataset_path, memoized
from ...utils import logger, get_rng, get_dataset_path
from ...utils.loadcaffe import get_caffe_pb
from ...utils.fs import mkdir_p, download
from ...utils.timer import timed_operation
......@@ -17,9 +17,6 @@ from ..base import RNGDataFlow
__all__ = ['ILSVRCMeta', 'ILSVRC12']
@memoized
def log_once(s): logger.warn(s)
CAFFE_ILSVRC12_URL = "http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz"
class ILSVRCMeta(object):
......
......@@ -4,6 +4,7 @@
from .base import ImageAugmentor
from ...utils.rect import Rect
from ...utils.argtools import shape2d
from six.moves import range
import numpy as np
......@@ -17,6 +18,7 @@ class RandomCrop(ImageAugmentor):
"""
:param crop_shape: a shape like (h, w)
"""
crop_shape = shape2d(crop_shape)
super(RandomCrop, self).__init__()
self._init(locals())
......@@ -43,6 +45,7 @@ class CenterCrop(ImageAugmentor):
"""
:param crop_shape: a shape like (h, w)
"""
crop_shape = shape2d(crop_shape)
self._init(locals())
def _augment(self, img, _):
......
......@@ -4,6 +4,7 @@
from .base import ImageAugmentor
from ...utils import logger
from ...utils.argtools import shape2d
import numpy as np
import cv2
......@@ -50,6 +51,7 @@ class Resize(ImageAugmentor):
"""
:param shape: shape in (h, w)
"""
shape = tuple(shape2d(shape))
self._init(locals())
def _augment(self, img, _):
......
......@@ -164,6 +164,7 @@ class PrefetchDataZMQ(ProxyDataFlow):
for x in self.procs:
x.terminate()
try:
# TODO test if logger here would overwrite log file
print("Prefetch process exited.")
except:
pass
......
......@@ -11,6 +11,7 @@ from ..tfutils.argscope import get_arg_scope
from ..tfutils.modelutils import get_shape_str
from ..tfutils.summary import add_activation_summary
from ..utils import logger
from ..utils.argtools import shape2d
# make sure each layer is only logged once
_layer_logged = set()
......@@ -93,17 +94,6 @@ def layer_register(
return wrapper
def shape2d(a):
"""
a: a int or tuple/list of length 2
"""
if type(a) == int:
return [a, a]
if isinstance(a, (list, tuple)):
assert len(a) == 2
return list(a)
raise RuntimeError("Illegal shape: {}".format(a))
def shape4d(a):
# for use with tensorflow NHWC ops
return [1] + shape2d(a) + [1]
......@@ -7,7 +7,8 @@ import numpy as np
import tensorflow as tf
import math
from ._common import layer_register, shape2d, shape4d
from ..utils import map_arg, logger
from ..utils import logger
from ..utils.argtools import shape2d
__all__ = ['Conv2D', 'Deconv2D']
......
......@@ -5,7 +5,8 @@
import tensorflow as tf
import numpy as np
from ._common import layer_register, shape2d, shape4d
from ._common import layer_register, shape4d
from ..utils.argtools import shape2d
from ..tfutils import symbolic_functions as symbf
__all__ = ['MaxPooling', 'FixedUnPooling', 'AvgPooling', 'GlobalAvgPooling',
......
......@@ -6,7 +6,7 @@ import tensorflow as tf
import re
from ..utils import logger
from ..utils.utils import memoized
from ..utils.argtools import memoized
from ..tfutils.tower import get_current_tower_context
from ._common import layer_register
......
......@@ -6,7 +6,7 @@ import six
import tensorflow as tf
import re
from ..utils import *
from ..utils.argtools import memoized
from .tower import get_current_tower_context
from . import get_global_step_var
from .symbolic_functions import rms
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: argtools.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import inspect, six, functools
import collections
__all__ = [ 'map_arg', 'memoized', 'shape2d']
def map_arg(**maps):
"""
Apply a mapping on certains argument before calling original function.
maps: {key: map_func}
"""
def deco(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
argmap = inspect.getcallargs(func, *args, **kwargs)
for k, map_func in six.iteritems(maps):
if k in argmap:
argmap[k] = map_func(argmap[k])
return func(**argmap)
return wrapper
return deco
class memoized(object):
'''Decorator. Caches a function's return value each time it is called.
If called later with the same arguments, the cached value is returned
(not reevaluated).
'''
def __init__(self, func):
self.func = func
self.cache = {}
def __call__(self, *args):
if not isinstance(args, collections.Hashable):
# uncacheable. a list, for instance.
# better to not cache than blow up.
return self.func(*args)
if args in self.cache:
return self.cache[args]
else:
value = self.func(*args)
self.cache[args] = value
return value
def __repr__(self):
'''Return the function's docstring.'''
return self.func.__doc__
def __get__(self, obj, objtype):
'''Support instance methods.'''
return functools.partial(self.__call__, obj)
#_GLOBAL_MEMOIZED_CACHE = dict()
#def global_memoized(func):
#""" Make sure that the same `memoized` object is returned on different
#calls to global_memoized(func)
#"""
#ret = _GLOBAL_MEMOIZED_CACHE.get(func, None)
#if ret is None:
#ret = _GLOBAL_MEMOIZED_CACHE[func] = memoized(func)
#return ret
def shape2d(a):
"""
a: a int or tuple/list of length 2
"""
if type(a) == int:
return [a, a]
if isinstance(a, (list, tuple)):
assert len(a) == 2
return list(a)
raise RuntimeError("Illegal shape: {}".format(a))
......@@ -3,7 +3,8 @@
# File: discretize.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from . import logger, memoized
from . import logger
from .argtools import memoized
from abc import abstractmethod, ABCMeta
import numpy as np
from six.moves import range
......
......@@ -10,8 +10,6 @@ from datetime import datetime
from six.moves import input
import sys
from .utils import memoized
__all__ = []
class _MyFormatter(logging.Formatter):
......
......@@ -4,16 +4,13 @@
import os, sys
from contextlib import contextmanager
import inspect, functools
import inspect
from datetime import datetime
import time
import collections
import numpy as np
import six
__all__ = ['change_env',
'map_arg',
'get_rng', 'memoized',
'get_rng',
'get_dataset_path',
'get_tqdm_kwargs',
'execute_only_once'
......@@ -29,62 +26,6 @@ def change_env(name, val):
else:
os.environ[name] = oldval
class memoized(object):
'''Decorator. Caches a function's return value each time it is called.
If called later with the same arguments, the cached value is returned
(not reevaluated).
'''
def __init__(self, func):
self.func = func
self.cache = {}
def __call__(self, *args):
if not isinstance(args, collections.Hashable):
# uncacheable. a list, for instance.
# better to not cache than blow up.
return self.func(*args)
if args in self.cache:
return self.cache[args]
else:
value = self.func(*args)
self.cache[args] = value
return value
def __repr__(self):
'''Return the function's docstring.'''
return self.func.__doc__
def __get__(self, obj, objtype):
'''Support instance methods.'''
return functools.partial(self.__call__, obj)
#_GLOBAL_MEMOIZED_CACHE = dict()
#def global_memoized(func):
#""" Make sure that the same `memoized` object is returned on different
#calls to global_memoized(func)
#"""
#ret = _GLOBAL_MEMOIZED_CACHE.get(func, None)
#if ret is None:
#ret = _GLOBAL_MEMOIZED_CACHE[func] = memoized(func)
#return ret
def map_arg(**maps):
"""
Apply a mapping on certains argument before calling original function.
maps: {key: map_func}
"""
def deco(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
argmap = inspect.getcallargs(func, *args, **kwargs)
for k, map_func in six.iteritems(maps):
if k in argmap:
argmap[k] = map_func(argmap[k])
return func(**argmap)
return wrapper
return deco
def get_rng(obj=None):
""" obj: some object to use to generate random seed"""
seed = (id(obj) + os.getpid() +
......
......@@ -4,15 +4,19 @@
# Credit: zxytim
import numpy as np
import os, sys
import io
import cv2
from .fs import mkdir_p
from .argtools import shape2d
try:
import matplotlib.pyplot as plt
except ImportError:
pass
__all__ = ['pyplot2img', 'build_patch_list', 'pyplot_viz']
__all__ = ['pyplot2img', 'build_patch_list', 'pyplot_viz',
'dump_dataflow_images']
def pyplot2img(plt):
buf = io.BytesIO()
......@@ -46,6 +50,7 @@ def build_patch_list(patch_list,
max_width=1000, max_height=1000,
shuffle=False, bgcolor=255):
"""
This is a generator.
patch_list: bhw or bhwc
"""
patch_list = np.asarray(patch_list)
......@@ -88,6 +93,53 @@ def build_patch_list(patch_list,
yield canvas
start = end
def dump_dataflow_images(df, index=0, batched=True,
number=300, output_dir=None,
scale=1, resize=None, viz=None, flipRGB=False, exit_after=True):
if output_dir:
mkdir_p(output_dir)
if viz is not None:
viz = shape2d(viz)
vizsize = viz[0] * viz[1]
if resize is not None:
resize = tuple(shape2d(resize))
vizlist = []
df.reset_state()
cnt = 0
while True:
for dp in df.get_data():
if not batched:
imgbatch = [dp[index]]
else:
imgbatch = dp[index]
for img in imgbatch:
cnt += 1
if cnt == number:
if exit_after:
sys.exit()
else:
return
if scale != 1:
img = img * scale
if resize is not None:
img = cv2.resize(img, resize)
if flipRGB:
img = img[:,:,::-1]
if output_dir:
fname = os.path.join(output_dir, '{:03d}.jpg'.format(cnt))
cv2.imwrite(fname, img)
if viz is not None:
vizlist.append(img)
if viz is not None and len(vizlist) >= vizsize:
patch = next(build_patch_list(
vizlist[:vizsize],
nr_row=viz[0], nr_col=viz[1]))
cv2.imshow("df-viz", patch)
cv2.waitKey()
vizlist = vizlist[vizsize:]
if __name__ == '__main__':
import cv2
imglist = []
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment