Commit 466b5192 authored by Yuxin Wu's avatar Yuxin Wu

simplify pre-resnet model code

parent 02f41b25
...@@ -112,25 +112,21 @@ def get_imagenet_dataflow( ...@@ -112,25 +112,21 @@ def get_imagenet_dataflow(
return ds return ds
def resnet_shortcut(l, n_out, stride): def resnet_shortcut(l, n_out, stride, nl=tf.identity):
data_format = get_arg_scope()['Conv2D']['data_format'] data_format = get_arg_scope()['Conv2D']['data_format']
n_in = l.get_shape().as_list()[1 if data_format == 'NCHW' else 3] n_in = l.get_shape().as_list()[1 if data_format == 'NCHW' else 3]
if n_in != n_out: # change dimension when channel is not the same if n_in != n_out: # change dimension when channel is not the same
return Conv2D('convshortcut', l, n_out, 1, stride=stride) return Conv2D('convshortcut', l, n_out, 1, stride=stride, nl=nl)
else: else:
return l return l
def apply_preactivation(l, preact): def apply_preactivation(l, preact):
""" """
'no_preact' for the first resblock only, because the input is activated already. 'no_preact' for the first resblock in each group only, because the input is activated already.
'both_preact' for the first block in each group, due to the projection shotcut.
'default' for all the non-first blocks, where identity mapping is preserved on shortcut path. 'default' for all the non-first blocks, where identity mapping is preserved on shortcut path.
""" """
if preact == 'both_preact': if preact == 'default':
l = BNReLU('preact', l)
shortcut = l
elif preact == 'default':
shortcut = l shortcut = l
l = BNReLU('preact', l) l = BNReLU('preact', l)
else: else:
...@@ -153,15 +149,17 @@ def resnet_bottleneck(l, ch_out, stride, preact): ...@@ -153,15 +149,17 @@ def resnet_bottleneck(l, ch_out, stride, preact):
return l + resnet_shortcut(shortcut, ch_out * 4, stride) return l + resnet_shortcut(shortcut, ch_out * 4, stride)
def resnet_group(l, name, block_func, features, count, stride, first=False): def preresnet_group(l, name, block_func, features, count, stride):
with tf.variable_scope(name): with tf.variable_scope(name):
with tf.variable_scope('block0'): for i in range(0, count):
l = block_func(l, features, stride,
'no_preact' if first else 'both_preact')
for i in range(1, count):
with tf.variable_scope('block{}'.format(i)): with tf.variable_scope('block{}'.format(i)):
l = block_func(l, features, 1, 'default') # first block doesn't need activation
return l l = block_func(l, features,
stride if i == 0 else 1,
'no_preact' if i == 0 else 'default')
# end of each group need an extra activation
l = BNReLU('bnlast', l)
return l
def resnet_backbone(image, num_blocks, block_func): def resnet_backbone(image, num_blocks, block_func):
...@@ -170,11 +168,10 @@ def resnet_backbone(image, num_blocks, block_func): ...@@ -170,11 +168,10 @@ def resnet_backbone(image, num_blocks, block_func):
logits = (LinearWrap(image) logits = (LinearWrap(image)
.Conv2D('conv0', 64, 7, stride=2, nl=BNReLU) .Conv2D('conv0', 64, 7, stride=2, nl=BNReLU)
.MaxPooling('pool0', shape=3, stride=2, padding='SAME') .MaxPooling('pool0', shape=3, stride=2, padding='SAME')
.apply(resnet_group, 'group0', block_func, 64, num_blocks[0], 1, first=True) .apply(preresnet_group, 'group0', block_func, 64, num_blocks[0], 1)
.apply(resnet_group, 'group1', block_func, 128, num_blocks[1], 2) .apply(preresnet_group, 'group1', block_func, 128, num_blocks[1], 2)
.apply(resnet_group, 'group2', block_func, 256, num_blocks[2], 2) .apply(preresnet_group, 'group2', block_func, 256, num_blocks[2], 2)
.apply(resnet_group, 'group3', block_func, 512, num_blocks[3], 2) .apply(preresnet_group, 'group3', block_func, 512, num_blocks[3], 2)
.BNReLU('bnlast')
.GlobalAvgPooling('gap') .GlobalAvgPooling('gap')
.FullyConnected('linear', 1000, nl=tf.identity)()) .FullyConnected('linear', 1000, nl=tf.identity)())
return logits return logits
......
...@@ -20,7 +20,7 @@ from tensorpack.utils.gpu import get_nr_gpu ...@@ -20,7 +20,7 @@ from tensorpack.utils.gpu import get_nr_gpu
from tensorpack.utils import viz from tensorpack.utils import viz
from imagenet_resnet_utils import ( from imagenet_resnet_utils import (
fbresnet_augmentor, resnet_basicblock, resnet_bottleneck, resnet_group, fbresnet_augmentor, resnet_basicblock, preresnet_group,
image_preprocess, compute_loss_and_error) image_preprocess, compute_loss_and_error)
...@@ -42,8 +42,6 @@ class Model(ModelDesc): ...@@ -42,8 +42,6 @@ class Model(ModelDesc):
cfg = { cfg = {
18: ([2, 2, 2, 2], resnet_basicblock), 18: ([2, 2, 2, 2], resnet_basicblock),
34: ([3, 4, 6, 3], resnet_basicblock), 34: ([3, 4, 6, 3], resnet_basicblock),
50: ([3, 4, 6, 3], resnet_bottleneck),
101: ([3, 4, 23, 3], resnet_bottleneck)
} }
defs, block_func = cfg[DEPTH] defs, block_func = cfg[DEPTH]
...@@ -53,11 +51,10 @@ class Model(ModelDesc): ...@@ -53,11 +51,10 @@ class Model(ModelDesc):
convmaps = (LinearWrap(image) convmaps = (LinearWrap(image)
.Conv2D('conv0', 64, 7, stride=2, nl=BNReLU) .Conv2D('conv0', 64, 7, stride=2, nl=BNReLU)
.MaxPooling('pool0', shape=3, stride=2, padding='SAME') .MaxPooling('pool0', shape=3, stride=2, padding='SAME')
.apply(resnet_group, 'group0', block_func, 64, defs[0], 1, first=True) .apply(preresnet_group, 'group0', block_func, 64, defs[0], 1)
.apply(resnet_group, 'group1', block_func, 128, defs[1], 2) .apply(preresnet_group, 'group1', block_func, 128, defs[1], 2)
.apply(resnet_group, 'group2', block_func, 256, defs[2], 2) .apply(preresnet_group, 'group2', block_func, 256, defs[2], 2)
.apply(resnet_group, 'group3new', block_func, 512, defs[3], 1) .apply(preresnet_group, 'group3new', block_func, 512, defs[3], 1)())
.BNReLU('bnlast')())
print(convmaps) print(convmaps)
logits = (LinearWrap(convmaps) logits = (LinearWrap(convmaps)
.GlobalAvgPooling('gap') .GlobalAvgPooling('gap')
...@@ -125,7 +122,7 @@ def viz_cam(model_file, data_dir): ...@@ -125,7 +122,7 @@ def viz_cam(model_file, data_dir):
model=Model(), model=Model(),
session_init=get_model_loader(model_file), session_init=get_model_loader(model_file),
input_names=['input', 'label'], input_names=['input', 'label'],
output_names=['wrong-top1', 'bnlast/Relu', 'linearnew/W'], output_names=['wrong-top1', 'group3new/bnlast/Relu', 'linearnew/W'],
return_input=True return_input=True
) )
meta = dataset.ILSVRCMeta().get_synset_words_1000() meta = dataset.ILSVRCMeta().get_synset_words_1000()
......
...@@ -317,6 +317,7 @@ class ThreadedMapData(ProxyDataFlow): ...@@ -317,6 +317,7 @@ class ThreadedMapData(ProxyDataFlow):
self.buffer_size = buffer_size self.buffer_size = buffer_size
self.map_func = map_func self.map_func = map_func
self._threads = [] self._threads = []
self._evt = None
def reset_state(self): def reset_state(self):
super(ThreadedMapData, self).reset_state() super(ThreadedMapData, self).reset_state()
...@@ -372,6 +373,7 @@ class ThreadedMapData(ProxyDataFlow): ...@@ -372,6 +373,7 @@ class ThreadedMapData(ProxyDataFlow):
yield self._out_queue.get() yield self._out_queue.get()
def __del__(self): def __del__(self):
self._evt.set() if self._evt is not None:
self._evt.set()
for p in self._threads: for p in self._threads:
p.join() p.join()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment