Commit 35527038 authored by Yuxin Wu's avatar Yuxin Wu

Allow zero weight_decay in ImagenetModel

parent e15332fd
......@@ -271,10 +271,10 @@ def get_eval_dataflow():
if __name__ == '__main__':
config.BASEDIR = '/home/wyx/data/coco'
config.BASEDIR = '/private/home/yuxinwu/data/coco'
config.TRAIN_DATASET = ['train2014']
from tensorpack.dataflow import PrintData
ds = get_train_dataflow()
ds = get_train_dataflow(add_mask=config.MODE_MASK)
ds = PrintData(ds, 100)
TestDataSpeed(ds, 50000).start()
ds.reset_state()
......
......@@ -156,10 +156,15 @@ class ImageNetModel(ModelDesc):
logits = self.get_logits(image)
loss = ImageNetModel.compute_loss_and_error(logits, label)
if self.weight_decay > 0:
wd_loss = regularize_cost('.*/W', tf.contrib.layers.l2_regularizer(self.weight_decay),
name='l2_regularize_loss')
add_moving_summary(loss, wd_loss)
self.cost = tf.add_n([loss, wd_loss], name='cost')
else:
self.cost = tf.identity(loss, name='cost')
add_moving_summary(self.cost)
@abstractmethod
def get_logits(self, image):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment