Commit defced98 authored by Yuxin Wu's avatar Yuxin Wu

small cleanups

parent d2048681
...@@ -22,7 +22,7 @@ def GroupNorm(x, group, gamma_initializer=tf.constant_initializer(1.)): ...@@ -22,7 +22,7 @@ def GroupNorm(x, group, gamma_initializer=tf.constant_initializer(1.)):
""" """
shape = x.get_shape().as_list() shape = x.get_shape().as_list()
ndims = len(shape) ndims = len(shape)
assert ndims in [2, 4] assert ndims == 4, shape
chan = shape[1] chan = shape[1]
assert chan % group == 0, chan assert chan % group == 0, chan
group_size = chan // group group_size = chan // group
...@@ -61,8 +61,7 @@ class Model(ImageNetModel): ...@@ -61,8 +61,7 @@ class Model(ImageNetModel):
weight_decay = 5e-4 weight_decay = 5e-4
def get_logits(self, image): def get_logits(self, image):
with argscope(Conv2D, kernel_size=3, with argscope(Conv2D, kernel_initializer=tf.variance_scaling_initializer(scale=2.)), \
kernel_initializer=tf.variance_scaling_initializer(scale=2.)), \
argscope([Conv2D, MaxPooling, BatchNorm], data_format='channels_first'): argscope([Conv2D, MaxPooling, BatchNorm], data_format='channels_first'):
logits = (LinearWrap(image) logits = (LinearWrap(image)
.apply(convnormrelu, 'conv1_1', 64) .apply(convnormrelu, 'conv1_1', 64)
...@@ -104,7 +103,6 @@ class Model(ImageNetModel): ...@@ -104,7 +103,6 @@ class Model(ImageNetModel):
def get_data(name, batch): def get_data(name, batch):
isTrain = name == 'train' isTrain = name == 'train'
global args
augmentors = fbresnet_augmentor(isTrain) augmentors = fbresnet_augmentor(isTrain)
return get_imagenet_dataflow(args.data, name, batch, augmentors) return get_imagenet_dataflow(args.data, name, batch, augmentors)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment