Commit 90dd3ef4 authored by Yuxin Wu

update resnet with better init

parent 2d720b60
...@@ -46,7 +46,7 @@ class Model(ModelDesc): ...@@ -46,7 +46,7 @@ class Model(ModelDesc):
def conv(name, l, channel, stride):
    """3x3 convolution layer: no bias, no nonlinearity (identity).

    W_init is a zero-mean normal with stddev = 2.0 / 9 / channel
    (9 = 3*3 kernel elements).
    NOTE(review): He initialization normally uses sqrt(2 / (k*k*c));
    this expression omits the sqrt — confirm that is intentional.
    """
    return Conv2D(name, l, channel, 3, stride=stride,
                  nl=tf.identity, use_bias=False,
                  W_init=tf.random_normal_initializer(stddev=2.0 / 9 / channel))
def residual(name, l, increase_dim=False): def residual(name, l, increase_dim=False):
shape = l.get_shape().as_list() shape = l.get_shape().as_list()
...@@ -124,7 +124,6 @@ def get_data(train_or_test): ...@@ -124,7 +124,6 @@ def get_data(train_or_test):
imgaug.RandomCrop((32, 32)), imgaug.RandomCrop((32, 32)),
imgaug.Flip(horiz=True), imgaug.Flip(horiz=True),
imgaug.BrightnessAdd(20), imgaug.BrightnessAdd(20),
imgaug.Contrast((0.6,1.4)),
imgaug.MapImage(lambda x: x - pp_mean), imgaug.MapImage(lambda x: x - pp_mean),
] ]
else: else:
...@@ -147,7 +146,8 @@ def get_config(): ...@@ -147,7 +146,8 @@ def get_config():
sess_config = get_default_sess_config(0.9) sess_config = get_default_sess_config(0.9)
lr = tf.Variable(0.1, trainable=False, name='learning_rate') # warm up with small LR for 1 epoch
lr = tf.Variable(0.01, trainable=False, name='learning_rate')
tf.scalar_summary('learning_rate', lr) tf.scalar_summary('learning_rate', lr)
return TrainConfig( return TrainConfig(
...@@ -158,7 +158,7 @@ def get_config(): ...@@ -158,7 +158,7 @@ def get_config():
PeriodicSaver(), PeriodicSaver(),
ValidationError(dataset_test, prefix='test'), ValidationError(dataset_test, prefix='test'),
ScheduledHyperParamSetter('learning_rate', ScheduledHyperParamSetter('learning_rate',
[(82, 0.01), (123, 0.001), (300, 0.0001)]) [(1, 0.1), (82, 0.01), (123, 0.001), (300, 0.0001)])
]), ]),
session_config=sess_config, session_config=sess_config,
model=Model(n=18), model=Model(n=18),
......
...@@ -56,6 +56,9 @@ class Callback(object): ...@@ -56,6 +56,9 @@ class Callback(object):
return self.trainer.global_step return self.trainer.global_step
def trigger_epoch(self):
    """Called at the end of every epoch.

    Increments ``epoch_num`` (the count of epochs finished so far)
    before delegating to the subclass hook ``_trigger_epoch``.
    """
    self.epoch_num = self.epoch_num + 1
    self._trigger_epoch()
......
...@@ -25,7 +25,6 @@ class HyperParamSetter(Callback): ...@@ -25,7 +25,6 @@ class HyperParamSetter(Callback):
def _before_train(self):
    """Resolve ``self.var_name`` to the matching TF variable.

    Scans every variable in the default graph and stores the first
    one whose full name equals ``self.var_name`` into ``self.var``.
    NOTE(review): if no variable matches, ``self.var`` is silently
    left unset — confirm a later access surfaces the error clearly.
    """
    for candidate in tf.all_variables():
        if candidate.name != self.var_name:
            continue
        self.var = candidate
        break
......
...@@ -49,7 +49,7 @@ class Trainer(object): ...@@ -49,7 +49,7 @@ class Trainer(object):
if not hasattr(logger, 'LOG_DIR'): if not hasattr(logger, 'LOG_DIR'):
raise RuntimeError("Please use logger.set_logger_dir at the beginning of your script.") raise RuntimeError("Please use logger.set_logger_dir at the beginning of your script.")
self.summary_writer = tf.train.SummaryWriter( self.summary_writer = tf.train.SummaryWriter(
logger.LOG_DIR, graph_def=self.sess.graph_def) logger.LOG_DIR, graph_def=self.sess.graph)
self.summary_op = tf.merge_all_summaries() self.summary_op = tf.merge_all_summaries()
# create an empty StatHolder # create an empty StatHolder
self.stat_holder = StatHolder(logger.LOG_DIR, []) self.stat_holder = StatHolder(logger.LOG_DIR, [])
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment