Commit 3f19c1b8 authored by Yuxin Wu's avatar Yuxin Wu

fix global_step scope problem

parent 7e9d09a8
......@@ -44,14 +44,11 @@ Usage:
"""
class Model(ModelDesc):
def __init__(self, is_training=True):
self.isTrain = is_training
def _get_input_vars(self):
return [InputVar(tf.float32, [None, None, None] + [3], 'image'),
InputVar(tf.int32, [None, None, None], 'edgemap') ]
def _build_graph(self, input_vars, is_training):
def _build_graph(self, input_vars):
image, edgemap = input_vars
image = image - tf.constant([104, 116, 122], dtype='float32')
......@@ -91,8 +88,12 @@ class Model(ModelDesc):
l = Conv2D('conv5_3', l, 512)
b5 = branch('branch5', l, 16)
final_map = tf.squeeze(tf.mul(0.2, b1 + b2 + b3 + b4 + b5),
[3], name='predmap')
final_map = Conv2D('convfcweight',
tf.concat(3, [b1, b2, b3, b4, b5]), 1, 1,
W_init=tf.constant_initializer(0.2), use_bias=False)
final_map = tf.squeeze(final_map, [3], name='predmap')
#final_map = tf.squeeze(tf.mul(0.2, b1 + b2 + b3 + b4 + b5),
#[3], name='predmap')
costs = []
for idx, b in enumerate([b1, b2, b3, b4, b5, final_map]):
output = tf.nn.sigmoid(b, name='output{}'.format(idx+1))
......@@ -114,6 +115,10 @@ class Model(ModelDesc):
add_param_summary([('.*/W', ['histogram'])]) # monitor W
self.cost = tf.add_n(costs, name='cost')
def get_gradient_processor(self):
return [ScaleGradient([('convfc.*', 0.1), ('conv5_.*', 100)]),
SummaryGradient()]
def get_data(name):
isTrain = name == 'train'
ds = dataset.BSDS500(name, shuffle=True)
......@@ -149,8 +154,8 @@ def get_data(name):
ds = AugmentImageComponents(ds, shape_aug, (0, 1))
def f(m):
m[m>=0.49] = 1
m[m<0.49] = 0
m[m>=0.51] = 1
m[m<0.51] = 0
return m
ds = MapDataComponent(ds, f, 1)
......@@ -161,9 +166,10 @@ def get_data(name):
imgaug.GaussianNoise(),
]
ds = AugmentImageComponent(ds, augmentors)
ds = BatchDataByShape(ds, 8, idx=0)
if isTrain:
ds = BatchDataByShape(ds, 8, idx=0)
ds = PrefetchDataZMQ(ds, 1)
else:
ds = BatchData(ds, 1)
return ds
def view_data():
......@@ -180,27 +186,28 @@ def view_data():
def get_config():
logger.auto_set_dir()
dataset_train = get_data('train')
step_per_epoch = dataset_train.size() * 100
step_per_epoch = dataset_train.size() * 40
dataset_val = get_data('val')
#dataset_test = get_data('test')
lr = tf.Variable(1e-5, trainable=False, name='learning_rate')
lr = tf.Variable(5e-6, trainable=False, name='learning_rate')
tf.scalar_summary('learning_rate', lr)
return TrainConfig(
dataset=dataset_train,
optimizer=tf.train.AdamOptimizer(lr, epsilon=1e-3),
#optimizer=tf.train.AdamOptimizer(lr, epsilon=1e-3),
optimizer=tf.train.MomentumOptimizer(lr, 0.9),
callbacks=Callbacks([
StatPrinter(),
ModelSaver(),
ScheduledHyperParamSetter('learning_rate', [(25, 3e-6)]),
ScheduledHyperParamSetter('learning_rate', [(100, 3e-6), (200, 8e-7)]),
HumanHyperParamSetter('learning_rate'),
InferenceRunner(dataset_val,
BinaryClassificationStats('prediction', 'edgemap'))
]),
model=Model(),
step_per_epoch=step_per_epoch,
max_epoch=100,
max_epoch=300,
)
def run(model_path, image_path):
......
......@@ -60,7 +60,8 @@ class Resize(ImageAugmentor):
class RandomResize(ImageAugmentor):
""" randomly rescale w and h of the image"""
def __init__(self, xrange, yrange, minimum=(0,0), aspect_ratio_thres=0.15):
def __init__(self, xrange, yrange, minimum=(0,0), aspect_ratio_thres=0.15,
interp=cv2.INTER_CUBIC):
"""
:param xrange: (min, max) scaling ratio
:param yrange: (min, max) scaling ratio
......@@ -88,5 +89,5 @@ class RandomResize(ImageAugmentor):
return img.shape[1], img.shape[0]
def _augment(self, img, dsize):
return cv2.resize(img, dsize, interpolation=cv2.INTER_CUBIC)
return cv2.resize(img, dsize, interpolation=self.interp)
......@@ -45,8 +45,8 @@ def get_global_step_var():
scope = tf.get_variable_scope()
assert scope.name == '', \
"Creating global_step_var under a variable scope would cause problems!"
var = tf.Variable(
0, trainable=False, name=GLOBAL_STEP_OP_NAME)
var = tf.get_variable(GLOBAL_STEP_OP_NAME, shape=[],
initializer=tf.constant_initializer(), trainable=False)
return var
def get_global_step():
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment