Commit bf0ef32e authored by Yuxin Wu

resnet improvements

parent e7884dd0
...@@ -20,9 +20,13 @@ from tensorpack.dataflow import imgaug ...@@ -20,9 +20,13 @@ from tensorpack.dataflow import imgaug
""" """
CIFAR10-resnet example. CIFAR10-resnet example.
I can reproduce the results in:
Deep Residual Learning for Image Recognition, arxiv:1512.03385 Deep Residual Learning for Image Recognition, arxiv:1512.03385
for n=5 and 18 (6.5% val error) using the variants proposed in:
Identity Mappings in Deep Residual Networks, arxiv:1603.05027
I can reproduce the results
for n=5 (about 7.7% val error) and 18 (about 6.4% val error)
This model uses the whole training set instead of a 95:5 train-val split.
""" """
BATCH_SIZE = 128 BATCH_SIZE = 128
...@@ -44,9 +48,9 @@ class Model(ModelDesc): ...@@ -44,9 +48,9 @@ class Model(ModelDesc):
def conv(name, l, channel, stride): def conv(name, l, channel, stride):
return Conv2D(name, l, channel, 3, stride=stride, return Conv2D(name, l, channel, 3, stride=stride,
nl=tf.identity, use_bias=False, nl=tf.identity, use_bias=False,
W_init=tf.random_normal_initializer(stddev=2.0/9/channel)) W_init=tf.random_normal_initializer(stddev=np.sqrt(2.0/9/channel)))
def residual(name, l, increase_dim=False): def residual(name, l, increase_dim=False, first=False):
shape = l.get_shape().as_list() shape = l.get_shape().as_list()
in_channel = shape[3] in_channel = shape[3]
...@@ -58,25 +62,29 @@ class Model(ModelDesc): ...@@ -58,25 +62,29 @@ class Model(ModelDesc):
stride1 = 1 stride1 = 1
with tf.variable_scope(name) as scope: with tf.variable_scope(name) as scope:
c1 = conv('conv1', l, out_channel, stride1) if not first:
b1 = BatchNorm('bn1', c1, is_training) b1 = BatchNorm('bn1', l, is_training)
b1 = tf.nn.relu(b1) b1 = tf.nn.relu(b1)
c2 = conv('conv2', b1, out_channel, 1) else:
b2 = BatchNorm('bn2', c2, is_training) b1 = l
c1 = conv('conv1', b1, out_channel, stride1)
b2 = BatchNorm('bn2', c1, is_training)
b2 = tf.nn.relu(b2)
c2 = conv('conv2', b2, out_channel, 1)
if increase_dim: if increase_dim:
l = AvgPooling('pool', l, 2) l = AvgPooling('pool', l, 2)
l = tf.pad(l, [[0,0], [0,0], [0,0], [in_channel//2, in_channel//2]]) l = tf.pad(l, [[0,0], [0,0], [0,0], [in_channel//2, in_channel//2]])
l = b2 + l l = c2 + l
l = tf.nn.relu(l)
return l return l
l = conv('conv1', image, 16, 1) l = conv('conv0', image, 16, 1)
l = BatchNorm('bn1', l, is_training) l = BatchNorm('bn0', l, is_training)
l = tf.nn.relu(l) l = tf.nn.relu(l)
for k in range(self.n): l = residual('res1.0', l, first=True)
for k in range(1, self.n):
l = residual('res1.{}'.format(k), l) l = residual('res1.{}'.format(k), l)
# 32,c=16 # 32,c=16
...@@ -88,6 +96,8 @@ class Model(ModelDesc): ...@@ -88,6 +96,8 @@ class Model(ModelDesc):
l = residual('res3.0', l, increase_dim=True) l = residual('res3.0', l, increase_dim=True)
for k in range(1, self.n): for k in range(1, self.n):
l = residual('res3.' + str(k), l) l = residual('res3.' + str(k), l)
l = BatchNorm('bnlast', l, is_training)
l = tf.nn.relu(l)
# 8,c=64 # 8,c=64
l = GlobalAvgPooling('gap', l) l = GlobalAvgPooling('gap', l)
logits = FullyConnected('linear', l, out_dim=10, summary_activation=False, logits = FullyConnected('linear', l, out_dim=10, summary_activation=False,
...@@ -123,6 +133,7 @@ def get_data(train_or_test): ...@@ -123,6 +133,7 @@ def get_data(train_or_test):
imgaug.RandomCrop((32, 32)), imgaug.RandomCrop((32, 32)),
imgaug.Flip(horiz=True), imgaug.Flip(horiz=True),
imgaug.BrightnessAdd(20), imgaug.BrightnessAdd(20),
#imgaug.Contrast((0.6,1.4)),
imgaug.MapImage(lambda x: x - pp_mean), imgaug.MapImage(lambda x: x - pp_mean),
] ]
else: else:
...@@ -145,7 +156,6 @@ def get_config(): ...@@ -145,7 +156,6 @@ def get_config():
sess_config = get_default_sess_config(0.9) sess_config = get_default_sess_config(0.9)
# warm up with small LR for 1 epoch
lr = tf.Variable(0.01, trainable=False, name='learning_rate') lr = tf.Variable(0.01, trainable=False, name='learning_rate')
tf.scalar_summary('learning_rate', lr) tf.scalar_summary('learning_rate', lr)
...@@ -157,7 +167,7 @@ def get_config(): ...@@ -157,7 +167,7 @@ def get_config():
PeriodicSaver(), PeriodicSaver(),
ValidationError(dataset_test, prefix='test'), ValidationError(dataset_test, prefix='test'),
ScheduledHyperParamSetter('learning_rate', ScheduledHyperParamSetter('learning_rate',
[(1, 0.1), (82, 0.01), (123, 0.001), (300, 0.0001)]) [(1, 0.1), (82, 0.01), (123, 0.001), (300, 0.0002)])
]), ]),
session_config=sess_config, session_config=sess_config,
model=Model(n=18), model=Model(n=18),
......
...@@ -49,7 +49,7 @@ class Trainer(object): ...@@ -49,7 +49,7 @@ class Trainer(object):
if not hasattr(logger, 'LOG_DIR'): if not hasattr(logger, 'LOG_DIR'):
raise RuntimeError("Please use logger.set_logger_dir at the beginning of your script.") raise RuntimeError("Please use logger.set_logger_dir at the beginning of your script.")
self.summary_writer = tf.train.SummaryWriter( self.summary_writer = tf.train.SummaryWriter(
logger.LOG_DIR, graph_def=self.sess.graph) logger.LOG_DIR, graph=self.sess.graph)
self.summary_op = tf.merge_all_summaries() self.summary_op = tf.merge_all_summaries()
# create an empty StatHolder # create an empty StatHolder
self.stat_holder = StatHolder(logger.LOG_DIR, []) self.stat_holder = StatHolder(logger.LOG_DIR, [])
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment