Commit 9d83d921 authored by Yuxin Wu's avatar Yuxin Wu

imgaug & fix svhn-dorefa double fw

parent abd38c59
......@@ -34,7 +34,7 @@ Multi-GPU training is ready to use by simply changing the trainer.
+ other requirements:
```
pip install --user -r requirements.txt
pip install --user -r opt-requirements.txt (some optional dependencies)
pip install --user -r opt-requirements.txt (some optional dependencies, you can install later if needed)
```
+ Use [tcmalloc](http://goog-perftools.sourceforge.net/doc/tcmalloc.html) whenever possible: see [TF issue](https://github.com/tensorflow/tensorflow/issues/2942)
+ allow `import tensorpack` everywhere:
......
......@@ -145,6 +145,7 @@ class Model(ModelDesc):
.apply(fg).BatchNorm('bn6')
.apply(cabs)
.FullyConnected('fc1', 10, nl=tf.identity)())
tf.get_variable = old_get_variable
prob = tf.nn.softmax(logits, name='output')
cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, label)
......
......@@ -4,3 +4,4 @@ nltk
h5py
pyzmq
tornado; python_version < '3.0'
lmdb
......@@ -4,6 +4,7 @@
from .base import ImageAugmentor
import numpy as np
import cv2
__all__ = ['Brightness', 'Contrast', 'MeanVarianceNormalize', 'GaussianBlur', 'Gamma']
......@@ -97,6 +98,7 @@ class Gamma(ImageAugmentor):
return self._rand_range(*self.range)
def _augment(self, img, gamma):
lut = ((np.arange(256, dtype='float32') / 255) ** (1. / (1. + gamma)) * 255).astype('uint8')
cv2.LUT(img, lut, img)
img = (img * 255.0).astype('uint8')
img = cv2.LUT(img, lut).astype('float32') / 255.0
return img
......@@ -6,12 +6,31 @@
from .base import ImageAugmentor
__all__ = ['RandomChooseAug', 'MapImage', 'Identity']
__all__ = ['RandomChooseAug', 'MapImage', 'Identity', 'RandomApplyAug']
class Identity(ImageAugmentor):
def _augment(self, img, _):
return img
class RandomApplyAug(ImageAugmentor):
""" Randomly apply the augmentor with a prob. Otherwise do nothing"""
def __init__(self, aug, prob):
self._init(locals())
def _get_augment_params(self, img):
p = self.rng.rand()
if p < self.prob:
prm = self.aug._get_augment_params(img)
return (True, prm)
else:
return (False, None)
def _augment(self, img, prm):
if not prm[0]:
return img
else:
return self.aug._augment(img, prm[1])
class RandomChooseAug(ImageAugmentor):
def __init__(self, aug_lists):
"""
......
......@@ -71,8 +71,8 @@ class RandomResize(ImageAugmentor):
while True:
sx = self._rand_range(*self.xrange)
sy = self._rand_range(*self.yrange)
destX = max(sx * img.shape[1], self.minimum[0])
destY = max(sy * img.shape[0], self.minimum[1])
destX = int(max(sx * img.shape[1], self.minimum[0]))
destY = int(max(sy * img.shape[0], self.minimum[1]))
oldr = img.shape[1] * 1.0 / img.shape[0]
newr = destX * 1.0 / destY
diff = abs(newr - oldr) / oldr
......
......@@ -8,7 +8,7 @@ from .base import ImageAugmentor
from abc import abstractmethod
import numpy as np
__all__ = [ 'CenterPaste', 'ConstantBackgroundFiller']
__all__ = [ 'CenterPaste', 'BackgroundFiller', 'ConstantBackgroundFiller']
class BackgroundFiller(object):
......
......@@ -34,15 +34,13 @@ def layer_register(summary_activation=False, log_shape=True):
#@decorator only enable me when building docs.
def wrapper(func):
@wraps(func)
def wrapped_func(*args, **kwargs):
name = args[0]
assert isinstance(name, six.string_types), \
'name must be the first argument. Args: {}'.format(args)
args = args[1:]
def wrapped_func(name, inputs, *args, **kwargs):
assert isinstance(name, six.string_types), name
do_summary = kwargs.pop(
'summary_activation', summary_activation)
inputs = args[0]
args = (inputs,) + args
# TODO use inspect.getcallargs to enhance?
# update from current argument scope
actual_args = copy.copy(get_arg_scope()[func.__name__])
actual_args.update(kwargs)
......
......@@ -7,6 +7,7 @@ import numpy as np
import tensorflow as tf
import math
from ._common import *
from ..utils import map_arg
__all__ = ['Conv2D']
......
......@@ -17,7 +17,8 @@ def describe_model():
total += ele
msg.append("{}: shape={}, dim={}".format(
v.name, shape.as_list(), ele))
msg.append("Total dim={}".format(total))
size_mb = total * 4 / 1024.0**2
msg.append("Total param={} ({:01f} MB)".format(total, size_mb))
logger.info("Model Params: {}".format('\n'.join(msg)))
......
......@@ -4,14 +4,16 @@
import os, sys
from contextlib import contextmanager
import inspect, functools
from datetime import datetime
import time
import collections
import numpy as np
import six
from . import logger
__all__ = ['change_env',
__all__ = ['change_env', 'map_arg',
'get_rng', 'memoized',
'get_nr_gpu',
'get_gpus',
......@@ -70,6 +72,22 @@ class memoized(object):
'''Support instance methods.'''
return functools.partial(self.__call__, obj)
def map_arg(**maps):
"""
Apply a mapping on certains argument before calling original function.
maps: {key: map_func}
"""
def deco(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
argmap = inspect.getcallargs(func, *args, **kwargs)
for k, map_func in six.iteritems(maps):
if k in argmap:
argmap[k] = map_func(argmap[k])
return func(**argmap)
return wrapper
return deco
def get_rng(obj=None):
""" obj: some object to use to generate random seed"""
seed = (id(obj) + os.getpid() +
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment