Commit 8532e89d authored by Yuxin Wu

remove duplicate code in svhn-resnet

parent 44b723a1
@@ -14,7 +14,8 @@ from tensorpack.tfutils.summary import *
 from disturb import DisturbLabel
 import imp
-svhn_example = imp.load_source('svhn_example', '../svhn-digit-convnet.py')
+svhn_example = imp.load_source('svhn_example',
+        os.path.join(os.path.dirname(__file__), '..', 'svhn-digit-convnet.py'))
 Model = svhn_example.Model
 get_config = svhn_example.get_config
...
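Note on the hunk above: the old code resolved '../svhn-digit-convnet.py' relative to the current working directory, so the import only worked when the script was launched from its own folder; the new code anchors the path at the script's own location. A minimal standalone sketch of that pattern, assuming only the stdlib imp/os modules already used in the diff (the helper name _this_dir is purely illustrative):

import imp
import os

# resolve the sibling example relative to this file, not the cwd
_this_dir = os.path.dirname(os.path.abspath(__file__))
svhn_example = imp.load_source(
    'svhn_example',
    os.path.join(_this_dir, '..', 'svhn-digit-convnet.py'))

Model = svhn_example.Model
get_config = svhn_example.get_config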
@@ -18,98 +18,12 @@ Reach 1.8% validation error after 70 epochs, with 2 TitanX. 2it/s.
 You might need to adjust the learning rate schedule when running with 1 GPU.
 """
-BATCH_SIZE = 128
-
-class Model(ModelDesc):
-    def __init__(self, n):
-        super(Model, self).__init__()
-        self.n = n
-
-    def _get_input_vars(self):
-        return [InputVar(tf.float32, [None, 32, 32, 3], 'input'),
-                InputVar(tf.int32, [None], 'label')
-               ]
-
-    def _build_graph(self, input_vars, is_training):
-        image, label = input_vars
-        image = image / 128.0 - 1
-
-        def conv(name, l, channel, stride):
-            return Conv2D(name, l, channel, 3, stride=stride,
-                          nl=tf.identity, use_bias=False,
-                          W_init=tf.random_normal_initializer(stddev=np.sqrt(2.0/9/channel)))
-
-        def residual(name, l, increase_dim=False, first=False):
-            shape = l.get_shape().as_list()
-            in_channel = shape[3]
-
-            if increase_dim:
-                out_channel = in_channel * 2
-                stride1 = 2
-            else:
-                out_channel = in_channel
-                stride1 = 1
-
-            with tf.variable_scope(name) as scope:
-                if not first:
-                    b1 = BatchNorm('bn1', l, is_training)
-                    b1 = tf.nn.relu(b1)
-                else:
-                    b1 = l
-                c1 = conv('conv1', b1, out_channel, stride1)
-                b2 = BatchNorm('bn2', c1, is_training)
-                b2 = tf.nn.relu(b2)
-                c2 = conv('conv2', b2, out_channel, 1)
-                if increase_dim:
-                    l = AvgPooling('pool', l, 2)
-                    l = tf.pad(l, [[0,0], [0,0], [0,0], [in_channel//2, in_channel//2]])
-                l = c2 + l
-            return l
-
-        l = conv('conv0', image, 16, 1)
-        l = BatchNorm('bn0', l, is_training)
-        l = tf.nn.relu(l)
-        l = residual('res1.0', l, first=True)
-        for k in range(1, self.n):
-            l = residual('res1.{}'.format(k), l)
-        # 32,c=16
-
-        l = residual('res2.0', l, increase_dim=True)
-        for k in range(1, self.n):
-            l = residual('res2.{}'.format(k), l)
-        # 16,c=32
-
-        l = residual('res3.0', l, increase_dim=True)
-        for k in range(1, self.n):
-            l = residual('res3.' + str(k), l)
-        l = BatchNorm('bnlast', l, is_training)
-        l = tf.nn.relu(l)
-        # 8,c=64
-        l = GlobalAvgPooling('gap', l)
-
-        logits = FullyConnected('linear', l, out_dim=10, nl=tf.identity)
-        prob = tf.nn.softmax(logits, name='output')
-
-        cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, label)
-        cost = tf.reduce_mean(cost, name='cross_entropy_loss')
-        tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost)
-
-        wrong = prediction_incorrect(logits, label)
-        nr_wrong = tf.reduce_sum(wrong, name='wrong')
-        # monitor training error
-        tf.add_to_collection(
-            MOVING_SUMMARY_VARS_KEY, tf.reduce_mean(wrong, name='train_error'))
-
-        # weight decay on all W of fc layers
-        #wd_cost = regularize_cost('.*/W', l2_regularizer(0.0002), name='regularize_loss')
-        wd_w = tf.train.exponential_decay(0.0001, get_global_step_var(),
-                                          960000, 0.5, True)
-        wd_cost = tf.mul(wd_w, regularize_cost('.*/W', tf.nn.l2_loss), name='wd_cost')
-        tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, wd_cost)
-
-        add_param_summary([('.*/W', ['histogram'])])   # monitor W
-        self.cost = tf.add_n([cost, wd_cost], name='cost')
+import imp
+cifar_example = imp.load_source('cifar_example',
+        os.path.join(os.path.dirname(__file__), 'cifar10-resnet.py'))
+Model = cifar_example.Model
+
+BATCH_SIZE = 128
 
 def get_data(train_or_test):
     isTrain = train_or_test == 'train'
...
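The block removed above is the ResNet model that now exists only in cifar10-resnet.py. Its one non-obvious step is the shortcut used when a stage doubles the channel count (increase_dim=True): instead of a 1x1 projection, it average-pools the input to halve the spatial size and zero-pads the channel dimension (the tf.pad with [in_channel//2, in_channel//2]). A hedged, framework-free NumPy sketch of just that shortcut, assuming NHWC layout with even height and width (the function name is illustrative, not part of the repo):

import numpy as np

def shortcut_increase_dim(x):
    # 2x2 average pooling with stride 2, then pad c//2 zero channels on each
    # side of the channel axis -> spatial size halved, channel count doubled
    n, h, w, c = x.shape
    pooled = x.reshape(n, h // 2, 2, w // 2, 2, c).mean(axis=(2, 4))
    return np.pad(pooled, [(0, 0), (0, 0), (0, 0), (c // 2, c // 2)],
                  mode='constant')

x = np.random.rand(1, 32, 32, 16).astype('float32')
print(shortcut_increase_dim(x).shape)   # (1, 16, 16, 32), same shape as the conv branch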
@@ -14,7 +14,7 @@ __all__ = ['UniformDiscretizer1D', 'UniformDiscretizerND']
 def log_once(s):
     logger.warn(s)
 
-# just placeholder
+# just a placeholder
 class Discretizer(object):
     __metaclass__ = ABCMeta
 
@@ -55,6 +55,9 @@ class UniformDiscretizer1D(Discretizer1D):
                 (v - self.minv) / self.spacing,
                 0, self.nr_bin - 1))
 
+    def get_bin_center(self, bin_id):
+        return self.minv + self.spacing * (bin_id + 0.5)
+
     def get_distribution(self, v, smooth_factor=0.05, smooth_radius=2):
         """ return a smoothed one-hot distribution of the sample v.
         """
 
@@ -96,6 +99,19 @@ class UniformDiscretizerND(Discretizer):
             acc *= self.nr_bins[k]
         return res
 
+    def _get_bin_id_nd(self, bin_id):
+        ret = []
+        for k in reversed(list(range(self.n))):
+            nr = self.nr_bins[k]
+            v = bin_id % nr
+            bin_id = bin_id / nr
+            ret.append(v)
+        return list(reversed(ret))
+
+    def get_bin_center(self, bin_id):
+        bin_id_nd = self._get_bin_id_nd(bin_id)
+        return [self.discretizers[k].get_bin_center(bin_id_nd[k]) for k in range(self.n)]
+
 if __name__ == '__main__':
     #u = UniformDiscretizer1D(-10, 10, 0.12)
     u = UniformDiscretizerND((0, 100, 1), (0, 100, 1), (0, 100, 1))
...
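The additions to the discretizer invert the existing get_bin_id: UniformDiscretizer1D.get_bin_center maps a bin index back to that bin's midpoint, and UniformDiscretizerND.get_bin_center first decodes the flattened index into one index per dimension (_get_bin_id_nd) and then asks each per-dimension discretizer for its center. A small self-contained sketch of the same arithmetic, assuming plain Python functions rather than the tensorpack classes (it uses // where the diff relies on Python-2 integer division):

def get_bin_id(v, minv, spacing, nr_bin):
    # clip to the valid bin range, like UniformDiscretizer1D.get_bin_id
    return int(min(max((v - minv) / spacing, 0), nr_bin - 1))

def get_bin_center(bin_id, minv, spacing):
    # midpoint of a uniform bin of width `spacing` starting at `minv`
    return minv + spacing * (bin_id + 0.5)

def unflatten_bin_id(bin_id, nr_bins):
    # decode a flattened bin id with the last dimension varying fastest,
    # mirroring _get_bin_id_nd above
    ret = []
    for nr in reversed(nr_bins):
        ret.append(bin_id % nr)
        bin_id //= nr
    return list(reversed(ret))

# round-trip a sample through its bin id and back to the bin center
minv, spacing, nr_bin = 0.0, 1.0, 100
b = get_bin_id(3.7, minv, spacing, nr_bin)
print(b, get_bin_center(b, minv, spacing))        # 3 3.5
print(unflatten_bin_id(10203, [100, 100, 100]))   # [1, 2, 3]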