Commit defced98 authored by Yuxin Wu's avatar Yuxin Wu

small cleanups

parent d2048681
......@@ -22,7 +22,7 @@ def GroupNorm(x, group, gamma_initializer=tf.constant_initializer(1.)):
"""
shape = x.get_shape().as_list()
ndims = len(shape)
assert ndims in [2, 4]
assert ndims == 4, shape
chan = shape[1]
assert chan % group == 0, chan
group_size = chan // group
......@@ -61,8 +61,7 @@ class Model(ImageNetModel):
weight_decay = 5e-4
def get_logits(self, image):
with argscope(Conv2D, kernel_size=3,
kernel_initializer=tf.variance_scaling_initializer(scale=2.)), \
with argscope(Conv2D, kernel_initializer=tf.variance_scaling_initializer(scale=2.)), \
argscope([Conv2D, MaxPooling, BatchNorm], data_format='channels_first'):
logits = (LinearWrap(image)
.apply(convnormrelu, 'conv1_1', 64)
......@@ -104,7 +103,6 @@ class Model(ImageNetModel):
def get_data(name, batch):
isTrain = name == 'train'
global args
augmentors = fbresnet_augmentor(isTrain)
return get_imagenet_dataflow(args.data, name, batch, augmentors)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment