Commit fec86fec authored by Yuxin Wu's avatar Yuxin Wu

Generalize imagenet_utils

parent 0dabefa2
......@@ -24,12 +24,16 @@ class GoogleNetResize(imgaug.ImageAugmentor):
crop 8%~100% of the original image
See `Going Deeper with Convolutions` by Google.
"""
def __init__(self, crop_area_fraction=0.08,
aspect_ratio_low=0.75, aspect_ratio_high=1.333):
self._init(locals())
def _augment(self, img, _):
h, w = img.shape[:2]
area = h * w
for _ in range(10):
targetArea = self.rng.uniform(0.08, 1.0) * area
aspectR = self.rng.uniform(0.75, 1.333)
targetArea = self.rng.uniform(self.crop_area_fraction, 1.0) * area
aspectR = self.rng.uniform(self.aspect_ratio_low, self.aspect_ratio_high)
ww = int(np.sqrt(targetArea * aspectR) + 0.5)
hh = int(np.sqrt(targetArea / aspectR) + 0.5)
if self.rng.uniform() < 0.5:
......@@ -160,16 +164,20 @@ def compute_loss_and_error(logits, label):
class ImageNetModel(ModelDesc):
def __init__(self, data_format='NCHW', image_dtype=tf.uint8):
weight_decay = 1e-4
"""
uint8 instead of float32 is used as input type to reduce copy overhead.
It might hurt the performance a liiiitle bit.
The pretrained models were trained with float32.
"""
image_dtype = tf.uint8
def __init__(self, data_format='NCHW'):
if data_format == 'NCHW':
assert tf.test.is_gpu_available()
self.data_format = data_format
# uint8 instead of float32 is used as input type to reduce copy overhead.
# It might hurt the performance a liiiitle bit.
# The pretrained models were trained with float32.
self.image_dtype = image_dtype
def _get_inputs(self):
return [InputDesc(self.image_dtype, [None, 224, 224, 3], 'input'),
InputDesc(tf.int32, [None], 'label')]
......@@ -182,7 +190,8 @@ class ImageNetModel(ModelDesc):
logits = self.get_logits(image)
loss = compute_loss_and_error(logits, label)
wd_loss = regularize_cost('.*/W', tf.contrib.layers.l2_regularizer(1e-4), name='l2_regularize_loss')
wd_loss = regularize_cost('.*/W', tf.contrib.layers.l2_regularizer(self.weight_decay),
name='l2_regularize_loss')
add_moving_summary(loss, wd_loss)
self.cost = tf.add_n([loss, wd_loss], name='cost')
......
......@@ -31,6 +31,7 @@ def regularize_cost(regex, func, name='regularize_cost'):
Args:
regex (str): a regex to match variable names, e.g. "conv.*/W"
func: the regularization function, which takes a tensor and returns a scalar tensor.
E.g., ``tf.contrib.layers.l2_regularizer``.
Returns:
tf.Tensor: the total regularization cost.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment