Commit 43489024 authored by Yuxin Wu's avatar Yuxin Wu

shape2d in argtools

parent 2c8cd32f
...@@ -108,7 +108,7 @@ def annotate_min_max(data_x, data_y, ax): ...@@ -108,7 +108,7 @@ def annotate_min_max(data_x, data_y, ax):
x_max, y_max = data_y[0], data_y[0] x_max, y_max = data_y[0], data_y[0]
x_min, y_min = data_x[0], data_y[0] x_min, y_min = data_x[0], data_y[0]
for i in xrange(1, len(data_x)): for i in range(1, len(data_x)):
if data_y[i] > y_max: if data_y[i] > y_max:
y_max = data_y[i] y_max = data_y[i]
x_max = data_x[i] x_max = data_x[i]
......
...@@ -9,7 +9,7 @@ import numpy as np ...@@ -9,7 +9,7 @@ import numpy as np
from six.moves import range from six.moves import range
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from ...utils import logger, get_rng, get_dataset_path, memoized from ...utils import logger, get_rng, get_dataset_path
from ...utils.loadcaffe import get_caffe_pb from ...utils.loadcaffe import get_caffe_pb
from ...utils.fs import mkdir_p, download from ...utils.fs import mkdir_p, download
from ...utils.timer import timed_operation from ...utils.timer import timed_operation
...@@ -17,9 +17,6 @@ from ..base import RNGDataFlow ...@@ -17,9 +17,6 @@ from ..base import RNGDataFlow
__all__ = ['ILSVRCMeta', 'ILSVRC12'] __all__ = ['ILSVRCMeta', 'ILSVRC12']
@memoized
def log_once(s): logger.warn(s)
CAFFE_ILSVRC12_URL = "http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz" CAFFE_ILSVRC12_URL = "http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz"
class ILSVRCMeta(object): class ILSVRCMeta(object):
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
from .base import ImageAugmentor from .base import ImageAugmentor
from ...utils.rect import Rect from ...utils.rect import Rect
from ...utils.argtools import shape2d
from six.moves import range from six.moves import range
import numpy as np import numpy as np
...@@ -17,6 +18,7 @@ class RandomCrop(ImageAugmentor): ...@@ -17,6 +18,7 @@ class RandomCrop(ImageAugmentor):
""" """
:param crop_shape: a shape like (h, w) :param crop_shape: a shape like (h, w)
""" """
crop_shape = shape2d(crop_shape)
super(RandomCrop, self).__init__() super(RandomCrop, self).__init__()
self._init(locals()) self._init(locals())
...@@ -43,6 +45,7 @@ class CenterCrop(ImageAugmentor): ...@@ -43,6 +45,7 @@ class CenterCrop(ImageAugmentor):
""" """
:param crop_shape: a shape like (h, w) :param crop_shape: a shape like (h, w)
""" """
crop_shape = shape2d(crop_shape)
self._init(locals()) self._init(locals())
def _augment(self, img, _): def _augment(self, img, _):
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
from .base import ImageAugmentor from .base import ImageAugmentor
from ...utils import logger from ...utils import logger
from ...utils.argtools import shape2d
import numpy as np import numpy as np
import cv2 import cv2
...@@ -50,6 +51,7 @@ class Resize(ImageAugmentor): ...@@ -50,6 +51,7 @@ class Resize(ImageAugmentor):
""" """
:param shape: shape in (h, w) :param shape: shape in (h, w)
""" """
shape = tuple(shape2d(shape))
self._init(locals()) self._init(locals())
def _augment(self, img, _): def _augment(self, img, _):
......
...@@ -164,6 +164,7 @@ class PrefetchDataZMQ(ProxyDataFlow): ...@@ -164,6 +164,7 @@ class PrefetchDataZMQ(ProxyDataFlow):
for x in self.procs: for x in self.procs:
x.terminate() x.terminate()
try: try:
# TODO test if logger here would overwrite log file
print("Prefetch process exited.") print("Prefetch process exited.")
except: except:
pass pass
......
...@@ -11,6 +11,7 @@ from ..tfutils.argscope import get_arg_scope ...@@ -11,6 +11,7 @@ from ..tfutils.argscope import get_arg_scope
from ..tfutils.modelutils import get_shape_str from ..tfutils.modelutils import get_shape_str
from ..tfutils.summary import add_activation_summary from ..tfutils.summary import add_activation_summary
from ..utils import logger from ..utils import logger
from ..utils.argtools import shape2d
# make sure each layer is only logged once # make sure each layer is only logged once
_layer_logged = set() _layer_logged = set()
...@@ -93,17 +94,6 @@ def layer_register( ...@@ -93,17 +94,6 @@ def layer_register(
return wrapper return wrapper
def shape2d(a):
"""
a: a int or tuple/list of length 2
"""
if type(a) == int:
return [a, a]
if isinstance(a, (list, tuple)):
assert len(a) == 2
return list(a)
raise RuntimeError("Illegal shape: {}".format(a))
def shape4d(a): def shape4d(a):
# for use with tensorflow NHWC ops # for use with tensorflow NHWC ops
return [1] + shape2d(a) + [1] return [1] + shape2d(a) + [1]
...@@ -7,7 +7,8 @@ import numpy as np ...@@ -7,7 +7,8 @@ import numpy as np
import tensorflow as tf import tensorflow as tf
import math import math
from ._common import layer_register, shape2d, shape4d from ._common import layer_register, shape2d, shape4d
from ..utils import map_arg, logger from ..utils import logger
from ..utils.argtools import shape2d
__all__ = ['Conv2D', 'Deconv2D'] __all__ = ['Conv2D', 'Deconv2D']
......
...@@ -5,7 +5,8 @@ ...@@ -5,7 +5,8 @@
import tensorflow as tf import tensorflow as tf
import numpy as np import numpy as np
from ._common import layer_register, shape2d, shape4d from ._common import layer_register, shape4d
from ..utils.argtools import shape2d
from ..tfutils import symbolic_functions as symbf from ..tfutils import symbolic_functions as symbf
__all__ = ['MaxPooling', 'FixedUnPooling', 'AvgPooling', 'GlobalAvgPooling', __all__ = ['MaxPooling', 'FixedUnPooling', 'AvgPooling', 'GlobalAvgPooling',
......
...@@ -6,7 +6,7 @@ import tensorflow as tf ...@@ -6,7 +6,7 @@ import tensorflow as tf
import re import re
from ..utils import logger from ..utils import logger
from ..utils.utils import memoized from ..utils.argtools import memoized
from ..tfutils.tower import get_current_tower_context from ..tfutils.tower import get_current_tower_context
from ._common import layer_register from ._common import layer_register
......
...@@ -6,7 +6,7 @@ import six ...@@ -6,7 +6,7 @@ import six
import tensorflow as tf import tensorflow as tf
import re import re
from ..utils import * from ..utils.argtools import memoized
from .tower import get_current_tower_context from .tower import get_current_tower_context
from . import get_global_step_var from . import get_global_step_var
from .symbolic_functions import rms from .symbolic_functions import rms
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: argtools.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import inspect, six, functools
import collections
__all__ = [ 'map_arg', 'memoized', 'shape2d']
def map_arg(**maps):
"""
Apply a mapping on certains argument before calling original function.
maps: {key: map_func}
"""
def deco(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
argmap = inspect.getcallargs(func, *args, **kwargs)
for k, map_func in six.iteritems(maps):
if k in argmap:
argmap[k] = map_func(argmap[k])
return func(**argmap)
return wrapper
return deco
class memoized(object):
'''Decorator. Caches a function's return value each time it is called.
If called later with the same arguments, the cached value is returned
(not reevaluated).
'''
def __init__(self, func):
self.func = func
self.cache = {}
def __call__(self, *args):
if not isinstance(args, collections.Hashable):
# uncacheable. a list, for instance.
# better to not cache than blow up.
return self.func(*args)
if args in self.cache:
return self.cache[args]
else:
value = self.func(*args)
self.cache[args] = value
return value
def __repr__(self):
'''Return the function's docstring.'''
return self.func.__doc__
def __get__(self, obj, objtype):
'''Support instance methods.'''
return functools.partial(self.__call__, obj)
#_GLOBAL_MEMOIZED_CACHE = dict()
#def global_memoized(func):
#""" Make sure that the same `memoized` object is returned on different
#calls to global_memoized(func)
#"""
#ret = _GLOBAL_MEMOIZED_CACHE.get(func, None)
#if ret is None:
#ret = _GLOBAL_MEMOIZED_CACHE[func] = memoized(func)
#return ret
def shape2d(a):
"""
a: a int or tuple/list of length 2
"""
if type(a) == int:
return [a, a]
if isinstance(a, (list, tuple)):
assert len(a) == 2
return list(a)
raise RuntimeError("Illegal shape: {}".format(a))
...@@ -3,7 +3,8 @@ ...@@ -3,7 +3,8 @@
# File: discretize.py # File: discretize.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
from . import logger, memoized from . import logger
from .argtools import memoized
from abc import abstractmethod, ABCMeta from abc import abstractmethod, ABCMeta
import numpy as np import numpy as np
from six.moves import range from six.moves import range
......
...@@ -10,8 +10,6 @@ from datetime import datetime ...@@ -10,8 +10,6 @@ from datetime import datetime
from six.moves import input from six.moves import input
import sys import sys
from .utils import memoized
__all__ = [] __all__ = []
class _MyFormatter(logging.Formatter): class _MyFormatter(logging.Formatter):
......
...@@ -4,16 +4,13 @@ ...@@ -4,16 +4,13 @@
import os, sys import os, sys
from contextlib import contextmanager from contextlib import contextmanager
import inspect, functools import inspect
from datetime import datetime from datetime import datetime
import time import time
import collections
import numpy as np import numpy as np
import six
__all__ = ['change_env', __all__ = ['change_env',
'map_arg', 'get_rng',
'get_rng', 'memoized',
'get_dataset_path', 'get_dataset_path',
'get_tqdm_kwargs', 'get_tqdm_kwargs',
'execute_only_once' 'execute_only_once'
...@@ -29,62 +26,6 @@ def change_env(name, val): ...@@ -29,62 +26,6 @@ def change_env(name, val):
else: else:
os.environ[name] = oldval os.environ[name] = oldval
class memoized(object):
'''Decorator. Caches a function's return value each time it is called.
If called later with the same arguments, the cached value is returned
(not reevaluated).
'''
def __init__(self, func):
self.func = func
self.cache = {}
def __call__(self, *args):
if not isinstance(args, collections.Hashable):
# uncacheable. a list, for instance.
# better to not cache than blow up.
return self.func(*args)
if args in self.cache:
return self.cache[args]
else:
value = self.func(*args)
self.cache[args] = value
return value
def __repr__(self):
'''Return the function's docstring.'''
return self.func.__doc__
def __get__(self, obj, objtype):
'''Support instance methods.'''
return functools.partial(self.__call__, obj)
#_GLOBAL_MEMOIZED_CACHE = dict()
#def global_memoized(func):
#""" Make sure that the same `memoized` object is returned on different
#calls to global_memoized(func)
#"""
#ret = _GLOBAL_MEMOIZED_CACHE.get(func, None)
#if ret is None:
#ret = _GLOBAL_MEMOIZED_CACHE[func] = memoized(func)
#return ret
def map_arg(**maps):
"""
Apply a mapping on certains argument before calling original function.
maps: {key: map_func}
"""
def deco(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
argmap = inspect.getcallargs(func, *args, **kwargs)
for k, map_func in six.iteritems(maps):
if k in argmap:
argmap[k] = map_func(argmap[k])
return func(**argmap)
return wrapper
return deco
def get_rng(obj=None): def get_rng(obj=None):
""" obj: some object to use to generate random seed""" """ obj: some object to use to generate random seed"""
seed = (id(obj) + os.getpid() + seed = (id(obj) + os.getpid() +
......
...@@ -4,15 +4,19 @@ ...@@ -4,15 +4,19 @@
# Credit: zxytim # Credit: zxytim
import numpy as np import numpy as np
import os, sys
import io import io
import cv2 import cv2
from .fs import mkdir_p
from .argtools import shape2d
try: try:
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
except ImportError: except ImportError:
pass pass
__all__ = ['pyplot2img', 'build_patch_list', 'pyplot_viz'] __all__ = ['pyplot2img', 'build_patch_list', 'pyplot_viz',
'dump_dataflow_images']
def pyplot2img(plt): def pyplot2img(plt):
buf = io.BytesIO() buf = io.BytesIO()
...@@ -46,6 +50,7 @@ def build_patch_list(patch_list, ...@@ -46,6 +50,7 @@ def build_patch_list(patch_list,
max_width=1000, max_height=1000, max_width=1000, max_height=1000,
shuffle=False, bgcolor=255): shuffle=False, bgcolor=255):
""" """
This is a generator.
patch_list: bhw or bhwc patch_list: bhw or bhwc
""" """
patch_list = np.asarray(patch_list) patch_list = np.asarray(patch_list)
...@@ -88,6 +93,53 @@ def build_patch_list(patch_list, ...@@ -88,6 +93,53 @@ def build_patch_list(patch_list,
yield canvas yield canvas
start = end start = end
def dump_dataflow_images(df, index=0, batched=True,
number=300, output_dir=None,
scale=1, resize=None, viz=None, flipRGB=False, exit_after=True):
if output_dir:
mkdir_p(output_dir)
if viz is not None:
viz = shape2d(viz)
vizsize = viz[0] * viz[1]
if resize is not None:
resize = tuple(shape2d(resize))
vizlist = []
df.reset_state()
cnt = 0
while True:
for dp in df.get_data():
if not batched:
imgbatch = [dp[index]]
else:
imgbatch = dp[index]
for img in imgbatch:
cnt += 1
if cnt == number:
if exit_after:
sys.exit()
else:
return
if scale != 1:
img = img * scale
if resize is not None:
img = cv2.resize(img, resize)
if flipRGB:
img = img[:,:,::-1]
if output_dir:
fname = os.path.join(output_dir, '{:03d}.jpg'.format(cnt))
cv2.imwrite(fname, img)
if viz is not None:
vizlist.append(img)
if viz is not None and len(vizlist) >= vizsize:
patch = next(build_patch_list(
vizlist[:vizsize],
nr_row=viz[0], nr_col=viz[1]))
cv2.imshow("df-viz", patch)
cv2.waitKey()
vizlist = vizlist[vizsize:]
if __name__ == '__main__': if __name__ == '__main__':
import cv2 import cv2
imglist = [] imglist = []
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment