Commit fec86fec authored by Yuxin Wu

Generalize imagenet_utils

parent 0dabefa2
...@@ -24,12 +24,16 @@ class GoogleNetResize(imgaug.ImageAugmentor): ...@@ -24,12 +24,16 @@ class GoogleNetResize(imgaug.ImageAugmentor):
crop 8%~100% of the original image crop 8%~100% of the original image
See `Going Deeper with Convolutions` by Google. See `Going Deeper with Convolutions` by Google.
""" """
def __init__(self, crop_area_fraction=0.08,
aspect_ratio_low=0.75, aspect_ratio_high=1.333):
self._init(locals())
def _augment(self, img, _): def _augment(self, img, _):
h, w = img.shape[:2] h, w = img.shape[:2]
area = h * w area = h * w
for _ in range(10): for _ in range(10):
targetArea = self.rng.uniform(0.08, 1.0) * area targetArea = self.rng.uniform(self.crop_area_fraction, 1.0) * area
aspectR = self.rng.uniform(0.75, 1.333) aspectR = self.rng.uniform(self.aspect_ratio_low, self.aspect_ratio_high)
ww = int(np.sqrt(targetArea * aspectR) + 0.5) ww = int(np.sqrt(targetArea * aspectR) + 0.5)
hh = int(np.sqrt(targetArea / aspectR) + 0.5) hh = int(np.sqrt(targetArea / aspectR) + 0.5)
if self.rng.uniform() < 0.5: if self.rng.uniform() < 0.5:
...@@ -160,16 +164,20 @@ def compute_loss_and_error(logits, label): ...@@ -160,16 +164,20 @@ def compute_loss_and_error(logits, label):
class ImageNetModel(ModelDesc): class ImageNetModel(ModelDesc):
def __init__(self, data_format='NCHW', image_dtype=tf.uint8): weight_decay = 1e-4
"""
uint8 instead of float32 is used as input type to reduce copy overhead.
It might hurt the performance a liiiitle bit.
The pretrained models were trained with float32.
"""
image_dtype = tf.uint8
def __init__(self, data_format='NCHW'):
if data_format == 'NCHW': if data_format == 'NCHW':
assert tf.test.is_gpu_available() assert tf.test.is_gpu_available()
self.data_format = data_format self.data_format = data_format
# uint8 instead of float32 is used as input type to reduce copy overhead.
# It might hurt the performance a liiiitle bit.
# The pretrained models were trained with float32.
self.image_dtype = image_dtype
def _get_inputs(self): def _get_inputs(self):
return [InputDesc(self.image_dtype, [None, 224, 224, 3], 'input'), return [InputDesc(self.image_dtype, [None, 224, 224, 3], 'input'),
InputDesc(tf.int32, [None], 'label')] InputDesc(tf.int32, [None], 'label')]
...@@ -182,7 +190,8 @@ class ImageNetModel(ModelDesc): ...@@ -182,7 +190,8 @@ class ImageNetModel(ModelDesc):
logits = self.get_logits(image) logits = self.get_logits(image)
loss = compute_loss_and_error(logits, label) loss = compute_loss_and_error(logits, label)
wd_loss = regularize_cost('.*/W', tf.contrib.layers.l2_regularizer(1e-4), name='l2_regularize_loss') wd_loss = regularize_cost('.*/W', tf.contrib.layers.l2_regularizer(self.weight_decay),
name='l2_regularize_loss')
add_moving_summary(loss, wd_loss) add_moving_summary(loss, wd_loss)
self.cost = tf.add_n([loss, wd_loss], name='cost') self.cost = tf.add_n([loss, wd_loss], name='cost')
......
...@@ -31,6 +31,7 @@ def regularize_cost(regex, func, name='regularize_cost'): ...@@ -31,6 +31,7 @@ def regularize_cost(regex, func, name='regularize_cost'):
Args: Args:
regex (str): a regex to match variable names, e.g. "conv.*/W" regex (str): a regex to match variable names, e.g. "conv.*/W"
func: the regularization function, which takes a tensor and returns a scalar tensor. func: the regularization function, which takes a tensor and returns a scalar tensor.
E.g., ``tf.contrib.layers.l2_regularizer(1e-4)`` (the factory called with a scale returns such a function).
Returns: Returns:
tf.Tensor: the total regularization cost. tf.Tensor: the total regularization cost.
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment