Commit e1cfbef8 authored by Yuxin Wu's avatar Yuxin Wu

clean-up import in ResNet

parent 08a5cf6f
......@@ -2,7 +2,6 @@
# -*- coding: UTF-8 -*-
# File: imagenet-resnet.py
import cv2
import sys
import argparse
import numpy as np
......@@ -10,9 +9,8 @@ import os
import multiprocessing
import tensorflow as tf
from tensorflow.contrib.layers import variance_scaling_initializer
from tensorpack import *
from tensorpack.utils.stats import RatioCounter
from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import *
......@@ -24,6 +22,13 @@ TOTAL_BATCH_SIZE = 256
INPUT_SHAPE = 224
DEPTH = None
RESNET_CONFIG = {
18: ([2, 2, 2, 2], resnet_basicblock),
34: ([3, 4, 6, 3], resnet_basicblock),
50: ([3, 4, 6, 3], resnet_bottleneck),
101: ([3, 4, 23, 3], resnet_bottleneck)
}
class Model(ModelDesc):
def __init__(self, data_format='NCHW'):
......@@ -46,23 +51,16 @@ class Model(ModelDesc):
if self.data_format == 'NCHW':
image = tf.transpose(image, [0, 3, 1, 2])
cfg = {
18: ([2, 2, 2, 2], resnet_basicblock),
34: ([3, 4, 6, 3], resnet_basicblock),
50: ([3, 4, 6, 3], resnet_bottleneck),
101: ([3, 4, 23, 3], resnet_bottleneck)
}
defs, block_func = cfg[DEPTH]
defs, block_func = RESNET_CONFIG[DEPTH]
with argscope([Conv2D, MaxPooling, GlobalAvgPooling, BatchNorm], data_format=self.data_format):
logits = resnet_backbone(image, defs, block_func)
loss = compute_loss_and_error(logits, label)
wd_cost = regularize_cost('.*/W', l2_regularizer(1e-4), name='l2_regularize_loss')
add_moving_summary(loss, wd_cost)
self.cost = tf.add_n([loss, wd_cost], name='cost')
wd_loss = regularize_cost('.*/W', l2_regularizer(1e-4), name='l2_regularize_loss')
add_moving_summary(loss, wd_loss)
self.cost = tf.add_n([loss, wd_loss], name='cost')
def _get_optimizer(self):
lr = get_scalar_var('learning_rate', 0.1, summary=True)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment