Commit 90dd3ef4 authored by Yuxin Wu

update resnet with better init

parent 2d720b60
......@@ -46,7 +46,7 @@ class Model(ModelDesc):
def conv(name, l, channel, stride):
return Conv2D(name, l, channel, 3, stride=stride,
nl=tf.identity, use_bias=False,
W_init=tf.contrib.layers.xavier_initializer_conv2d(False))
W_init=tf.random_normal_initializer(stddev=2.0/9/channel))
def residual(name, l, increase_dim=False):
shape = l.get_shape().as_list()
......@@ -124,7 +124,6 @@ def get_data(train_or_test):
imgaug.RandomCrop((32, 32)),
imgaug.Flip(horiz=True),
imgaug.BrightnessAdd(20),
imgaug.Contrast((0.6,1.4)),
imgaug.MapImage(lambda x: x - pp_mean),
]
else:
......@@ -147,7 +146,8 @@ def get_config():
sess_config = get_default_sess_config(0.9)
lr = tf.Variable(0.1, trainable=False, name='learning_rate')
# warm up with small LR for 1 epoch
lr = tf.Variable(0.01, trainable=False, name='learning_rate')
tf.scalar_summary('learning_rate', lr)
return TrainConfig(
......@@ -158,7 +158,7 @@ def get_config():
PeriodicSaver(),
ValidationError(dataset_test, prefix='test'),
ScheduledHyperParamSetter('learning_rate',
[(82, 0.01), (123, 0.001), (300, 0.0001)])
[(1, 0.1), (82, 0.01), (123, 0.001), (300, 0.0001)])
]),
session_config=sess_config,
model=Model(n=18),
......
......@@ -56,6 +56,9 @@ class Callback(object):
return self.trainer.global_step
def trigger_epoch(self):
"""
epoch_num is the number of epoch finished.
"""
self.epoch_num += 1
self._trigger_epoch()
......
......@@ -25,7 +25,6 @@ class HyperParamSetter(Callback):
def _before_train(self):
all_vars = tf.all_variables()
for v in all_vars:
print v.name
if v.name == self.var_name:
self.var = v
break
......
......@@ -49,7 +49,7 @@ class Trainer(object):
if not hasattr(logger, 'LOG_DIR'):
raise RuntimeError("Please use logger.set_logger_dir at the beginning of your script.")
self.summary_writer = tf.train.SummaryWriter(
logger.LOG_DIR, graph_def=self.sess.graph_def)
logger.LOG_DIR, graph_def=self.sess.graph)
self.summary_op = tf.merge_all_summaries()
# create an empty StatHolder
self.stat_holder = StatHolder(logger.LOG_DIR, [])
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment