Commit 466b5192 authored by Yuxin Wu

simplify pre-resnet model code

parent 02f41b25
......@@ -112,25 +112,21 @@ def get_imagenet_dataflow(
return ds
def resnet_shortcut(l, n_out, stride):
def resnet_shortcut(l, n_out, stride, nl=tf.identity):
data_format = get_arg_scope()['Conv2D']['data_format']
n_in = l.get_shape().as_list()[1 if data_format == 'NCHW' else 3]
if n_in != n_out: # change dimension when channel is not the same
return Conv2D('convshortcut', l, n_out, 1, stride=stride)
return Conv2D('convshortcut', l, n_out, 1, stride=stride, nl=nl)
else:
return l
def apply_preactivation(l, preact):
"""
'no_preact' for the first resblock only, because the input is activated already.
'both_preact' for the first block in each group, due to the projection shortcut.
'no_preact' for the first resblock in each group only, because the input is activated already.
'default' for all the non-first blocks, where identity mapping is preserved on shortcut path.
"""
if preact == 'both_preact':
l = BNReLU('preact', l)
shortcut = l
elif preact == 'default':
if preact == 'default':
shortcut = l
l = BNReLU('preact', l)
else:
......@@ -153,15 +149,17 @@ def resnet_bottleneck(l, ch_out, stride, preact):
return l + resnet_shortcut(shortcut, ch_out * 4, stride)
def resnet_group(l, name, block_func, features, count, stride, first=False):
def preresnet_group(l, name, block_func, features, count, stride):
with tf.variable_scope(name):
with tf.variable_scope('block0'):
l = block_func(l, features, stride,
'no_preact' if first else 'both_preact')
for i in range(1, count):
for i in range(0, count):
with tf.variable_scope('block{}'.format(i)):
l = block_func(l, features, 1, 'default')
return l
# first block doesn't need activation
l = block_func(l, features,
stride if i == 0 else 1,
'no_preact' if i == 0 else 'default')
# end of each group needs an extra activation
l = BNReLU('bnlast', l)
return l
def resnet_backbone(image, num_blocks, block_func):
......@@ -170,11 +168,10 @@ def resnet_backbone(image, num_blocks, block_func):
logits = (LinearWrap(image)
.Conv2D('conv0', 64, 7, stride=2, nl=BNReLU)
.MaxPooling('pool0', shape=3, stride=2, padding='SAME')
.apply(resnet_group, 'group0', block_func, 64, num_blocks[0], 1, first=True)
.apply(resnet_group, 'group1', block_func, 128, num_blocks[1], 2)
.apply(resnet_group, 'group2', block_func, 256, num_blocks[2], 2)
.apply(resnet_group, 'group3', block_func, 512, num_blocks[3], 2)
.BNReLU('bnlast')
.apply(preresnet_group, 'group0', block_func, 64, num_blocks[0], 1)
.apply(preresnet_group, 'group1', block_func, 128, num_blocks[1], 2)
.apply(preresnet_group, 'group2', block_func, 256, num_blocks[2], 2)
.apply(preresnet_group, 'group3', block_func, 512, num_blocks[3], 2)
.GlobalAvgPooling('gap')
.FullyConnected('linear', 1000, nl=tf.identity)())
return logits
......
......@@ -20,7 +20,7 @@ from tensorpack.utils.gpu import get_nr_gpu
from tensorpack.utils import viz
from imagenet_resnet_utils import (
fbresnet_augmentor, resnet_basicblock, resnet_bottleneck, resnet_group,
fbresnet_augmentor, resnet_basicblock, preresnet_group,
image_preprocess, compute_loss_and_error)
......@@ -42,8 +42,6 @@ class Model(ModelDesc):
cfg = {
18: ([2, 2, 2, 2], resnet_basicblock),
34: ([3, 4, 6, 3], resnet_basicblock),
50: ([3, 4, 6, 3], resnet_bottleneck),
101: ([3, 4, 23, 3], resnet_bottleneck)
}
defs, block_func = cfg[DEPTH]
......@@ -53,11 +51,10 @@ class Model(ModelDesc):
convmaps = (LinearWrap(image)
.Conv2D('conv0', 64, 7, stride=2, nl=BNReLU)
.MaxPooling('pool0', shape=3, stride=2, padding='SAME')
.apply(resnet_group, 'group0', block_func, 64, defs[0], 1, first=True)
.apply(resnet_group, 'group1', block_func, 128, defs[1], 2)
.apply(resnet_group, 'group2', block_func, 256, defs[2], 2)
.apply(resnet_group, 'group3new', block_func, 512, defs[3], 1)
.BNReLU('bnlast')())
.apply(preresnet_group, 'group0', block_func, 64, defs[0], 1)
.apply(preresnet_group, 'group1', block_func, 128, defs[1], 2)
.apply(preresnet_group, 'group2', block_func, 256, defs[2], 2)
.apply(preresnet_group, 'group3new', block_func, 512, defs[3], 1)())
print(convmaps)
logits = (LinearWrap(convmaps)
.GlobalAvgPooling('gap')
......@@ -125,7 +122,7 @@ def viz_cam(model_file, data_dir):
model=Model(),
session_init=get_model_loader(model_file),
input_names=['input', 'label'],
output_names=['wrong-top1', 'bnlast/Relu', 'linearnew/W'],
output_names=['wrong-top1', 'group3new/bnlast/Relu', 'linearnew/W'],
return_input=True
)
meta = dataset.ILSVRCMeta().get_synset_words_1000()
......
......@@ -317,6 +317,7 @@ class ThreadedMapData(ProxyDataFlow):
self.buffer_size = buffer_size
self.map_func = map_func
self._threads = []
self._evt = None
def reset_state(self):
super(ThreadedMapData, self).reset_state()
......@@ -372,6 +373,7 @@ class ThreadedMapData(ProxyDataFlow):
yield self._out_queue.get()
def __del__(self):
self._evt.set()
if self._evt is not None:
self._evt.set()
for p in self._threads:
p.join()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment