Commit 3f19c1b8 authored by Yuxin Wu's avatar Yuxin Wu

fix global_step scope problem

parent 7e9d09a8
...@@ -44,14 +44,11 @@ Usage: ...@@ -44,14 +44,11 @@ Usage:
""" """
class Model(ModelDesc): class Model(ModelDesc):
def __init__(self, is_training=True):
self.isTrain = is_training
def _get_input_vars(self): def _get_input_vars(self):
return [InputVar(tf.float32, [None, None, None] + [3], 'image'), return [InputVar(tf.float32, [None, None, None] + [3], 'image'),
InputVar(tf.int32, [None, None, None], 'edgemap') ] InputVar(tf.int32, [None, None, None], 'edgemap') ]
def _build_graph(self, input_vars, is_training): def _build_graph(self, input_vars):
image, edgemap = input_vars image, edgemap = input_vars
image = image - tf.constant([104, 116, 122], dtype='float32') image = image - tf.constant([104, 116, 122], dtype='float32')
...@@ -91,8 +88,12 @@ class Model(ModelDesc): ...@@ -91,8 +88,12 @@ class Model(ModelDesc):
l = Conv2D('conv5_3', l, 512) l = Conv2D('conv5_3', l, 512)
b5 = branch('branch5', l, 16) b5 = branch('branch5', l, 16)
final_map = tf.squeeze(tf.mul(0.2, b1 + b2 + b3 + b4 + b5), final_map = Conv2D('convfcweight',
[3], name='predmap') tf.concat(3, [b1, b2, b3, b4, b5]), 1, 1,
W_init=tf.constant_initializer(0.2), use_bias=False)
final_map = tf.squeeze(final_map, [3], name='predmap')
#final_map = tf.squeeze(tf.mul(0.2, b1 + b2 + b3 + b4 + b5),
#[3], name='predmap')
costs = [] costs = []
for idx, b in enumerate([b1, b2, b3, b4, b5, final_map]): for idx, b in enumerate([b1, b2, b3, b4, b5, final_map]):
output = tf.nn.sigmoid(b, name='output{}'.format(idx+1)) output = tf.nn.sigmoid(b, name='output{}'.format(idx+1))
...@@ -114,6 +115,10 @@ class Model(ModelDesc): ...@@ -114,6 +115,10 @@ class Model(ModelDesc):
add_param_summary([('.*/W', ['histogram'])]) # monitor W add_param_summary([('.*/W', ['histogram'])]) # monitor W
self.cost = tf.add_n(costs, name='cost') self.cost = tf.add_n(costs, name='cost')
def get_gradient_processor(self):
return [ScaleGradient([('convfc.*', 0.1), ('conv5_.*', 100)]),
SummaryGradient()]
def get_data(name): def get_data(name):
isTrain = name == 'train' isTrain = name == 'train'
ds = dataset.BSDS500(name, shuffle=True) ds = dataset.BSDS500(name, shuffle=True)
...@@ -149,8 +154,8 @@ def get_data(name): ...@@ -149,8 +154,8 @@ def get_data(name):
ds = AugmentImageComponents(ds, shape_aug, (0, 1)) ds = AugmentImageComponents(ds, shape_aug, (0, 1))
def f(m): def f(m):
m[m>=0.49] = 1 m[m>=0.51] = 1
m[m<0.49] = 0 m[m<0.51] = 0
return m return m
ds = MapDataComponent(ds, f, 1) ds = MapDataComponent(ds, f, 1)
...@@ -161,9 +166,10 @@ def get_data(name): ...@@ -161,9 +166,10 @@ def get_data(name):
imgaug.GaussianNoise(), imgaug.GaussianNoise(),
] ]
ds = AugmentImageComponent(ds, augmentors) ds = AugmentImageComponent(ds, augmentors)
ds = BatchDataByShape(ds, 8, idx=0) ds = BatchDataByShape(ds, 8, idx=0)
if isTrain:
ds = PrefetchDataZMQ(ds, 1) ds = PrefetchDataZMQ(ds, 1)
else:
ds = BatchData(ds, 1)
return ds return ds
def view_data(): def view_data():
...@@ -180,27 +186,28 @@ def view_data(): ...@@ -180,27 +186,28 @@ def view_data():
def get_config(): def get_config():
logger.auto_set_dir() logger.auto_set_dir()
dataset_train = get_data('train') dataset_train = get_data('train')
step_per_epoch = dataset_train.size() * 100 step_per_epoch = dataset_train.size() * 40
dataset_val = get_data('val') dataset_val = get_data('val')
#dataset_test = get_data('test') #dataset_test = get_data('test')
lr = tf.Variable(1e-5, trainable=False, name='learning_rate') lr = tf.Variable(5e-6, trainable=False, name='learning_rate')
tf.scalar_summary('learning_rate', lr) tf.scalar_summary('learning_rate', lr)
return TrainConfig( return TrainConfig(
dataset=dataset_train, dataset=dataset_train,
optimizer=tf.train.AdamOptimizer(lr, epsilon=1e-3), #optimizer=tf.train.AdamOptimizer(lr, epsilon=1e-3),
optimizer=tf.train.MomentumOptimizer(lr, 0.9),
callbacks=Callbacks([ callbacks=Callbacks([
StatPrinter(), StatPrinter(),
ModelSaver(), ModelSaver(),
ScheduledHyperParamSetter('learning_rate', [(25, 3e-6)]), ScheduledHyperParamSetter('learning_rate', [(100, 3e-6), (200, 8e-7)]),
HumanHyperParamSetter('learning_rate'), HumanHyperParamSetter('learning_rate'),
InferenceRunner(dataset_val, InferenceRunner(dataset_val,
BinaryClassificationStats('prediction', 'edgemap')) BinaryClassificationStats('prediction', 'edgemap'))
]), ]),
model=Model(), model=Model(),
step_per_epoch=step_per_epoch, step_per_epoch=step_per_epoch,
max_epoch=100, max_epoch=300,
) )
def run(model_path, image_path): def run(model_path, image_path):
......
...@@ -60,7 +60,8 @@ class Resize(ImageAugmentor): ...@@ -60,7 +60,8 @@ class Resize(ImageAugmentor):
class RandomResize(ImageAugmentor): class RandomResize(ImageAugmentor):
""" randomly rescale w and h of the image""" """ randomly rescale w and h of the image"""
def __init__(self, xrange, yrange, minimum=(0,0), aspect_ratio_thres=0.15): def __init__(self, xrange, yrange, minimum=(0,0), aspect_ratio_thres=0.15,
interp=cv2.INTER_CUBIC):
""" """
:param xrange: (min, max) scaling ratio :param xrange: (min, max) scaling ratio
:param yrange: (min, max) scaling ratio :param yrange: (min, max) scaling ratio
...@@ -88,5 +89,5 @@ class RandomResize(ImageAugmentor): ...@@ -88,5 +89,5 @@ class RandomResize(ImageAugmentor):
return img.shape[1], img.shape[0] return img.shape[1], img.shape[0]
def _augment(self, img, dsize): def _augment(self, img, dsize):
return cv2.resize(img, dsize, interpolation=cv2.INTER_CUBIC) return cv2.resize(img, dsize, interpolation=self.interp)
...@@ -45,8 +45,8 @@ def get_global_step_var(): ...@@ -45,8 +45,8 @@ def get_global_step_var():
scope = tf.get_variable_scope() scope = tf.get_variable_scope()
assert scope.name == '', \ assert scope.name == '', \
"Creating global_step_var under a variable scope would cause problems!" "Creating global_step_var under a variable scope would cause problems!"
var = tf.Variable( var = tf.get_variable(GLOBAL_STEP_OP_NAME, shape=[],
0, trainable=False, name=GLOBAL_STEP_OP_NAME) initializer=tf.constant_initializer(), trainable=False)
return var return var
def get_global_step(): def get_global_step():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment