Commit e1cfbef8 authored by Yuxin Wu's avatar Yuxin Wu

clean-up import in ResNet

parent 08a5cf6f
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: imagenet-resnet.py # File: imagenet-resnet.py
import cv2
import sys import sys
import argparse import argparse
import numpy as np import numpy as np
...@@ -10,9 +9,8 @@ import os ...@@ -10,9 +9,8 @@ import os
import multiprocessing import multiprocessing
import tensorflow as tf import tensorflow as tf
from tensorflow.contrib.layers import variance_scaling_initializer
from tensorpack import * from tensorpack import *
from tensorpack.utils.stats import RatioCounter
from tensorpack.tfutils.symbolic_functions import * from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import * from tensorpack.tfutils.summary import *
...@@ -24,6 +22,13 @@ TOTAL_BATCH_SIZE = 256 ...@@ -24,6 +22,13 @@ TOTAL_BATCH_SIZE = 256
INPUT_SHAPE = 224 INPUT_SHAPE = 224
DEPTH = None DEPTH = None
RESNET_CONFIG = {
18: ([2, 2, 2, 2], resnet_basicblock),
34: ([3, 4, 6, 3], resnet_basicblock),
50: ([3, 4, 6, 3], resnet_bottleneck),
101: ([3, 4, 23, 3], resnet_bottleneck)
}
class Model(ModelDesc): class Model(ModelDesc):
def __init__(self, data_format='NCHW'): def __init__(self, data_format='NCHW'):
...@@ -46,23 +51,16 @@ class Model(ModelDesc): ...@@ -46,23 +51,16 @@ class Model(ModelDesc):
if self.data_format == 'NCHW': if self.data_format == 'NCHW':
image = tf.transpose(image, [0, 3, 1, 2]) image = tf.transpose(image, [0, 3, 1, 2])
defs, block_func = RESNET_CONFIG[DEPTH]
cfg = {
18: ([2, 2, 2, 2], resnet_basicblock),
34: ([3, 4, 6, 3], resnet_basicblock),
50: ([3, 4, 6, 3], resnet_bottleneck),
101: ([3, 4, 23, 3], resnet_bottleneck)
}
defs, block_func = cfg[DEPTH]
with argscope([Conv2D, MaxPooling, GlobalAvgPooling, BatchNorm], data_format=self.data_format): with argscope([Conv2D, MaxPooling, GlobalAvgPooling, BatchNorm], data_format=self.data_format):
logits = resnet_backbone(image, defs, block_func) logits = resnet_backbone(image, defs, block_func)
loss = compute_loss_and_error(logits, label) loss = compute_loss_and_error(logits, label)
wd_cost = regularize_cost('.*/W', l2_regularizer(1e-4), name='l2_regularize_loss') wd_loss = regularize_cost('.*/W', l2_regularizer(1e-4), name='l2_regularize_loss')
add_moving_summary(loss, wd_cost) add_moving_summary(loss, wd_loss)
self.cost = tf.add_n([loss, wd_cost], name='cost') self.cost = tf.add_n([loss, wd_loss], name='cost')
def _get_optimizer(self): def _get_optimizer(self):
lr = get_scalar_var('learning_rate', 0.1, summary=True) lr = get_scalar_var('learning_rate', 0.1, summary=True)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment