Commit bcee048d authored by Yuxin Wu's avatar Yuxin Wu

update hed

parent eb11e29c
......@@ -22,7 +22,7 @@ class Model(ModelDesc):
def _build_graph(self, input_vars):
image, edgemap = input_vars
image = image - tf.constant([104, 116, 122], dtype='float32')
edgemap = tf.expand_dims(edgemap, 3)
edgemap = tf.expand_dims(edgemap, 3, name='edgemap4d')
def branch(name, l, up):
with tf.variable_scope(name) as scope:
......@@ -171,7 +171,7 @@ def get_config():
ScheduledHyperParamSetter('learning_rate', [(30, 6e-6), (45, 1e-6), (60, 8e-7)]),
HumanHyperParamSetter('learning_rate'),
InferenceRunner(dataset_val,
BinaryClassificationStats('prediction', 'edgemap'))
BinaryClassificationStats('prediction', 'edgemap4d'))
]),
model=Model(),
step_per_epoch=step_per_epoch,
......
......@@ -86,7 +86,7 @@ class BinaryStatistics(object):
:param pred: 0/1 np array
:param label: 0/1 np array of the same size
"""
assert pred.shape == label.shape
assert pred.shape == label.shape, "{} != {}".format(pred.shape, label.shape)
self.nr_pos += (label == 1).sum()
self.nr_neg += (label == 0).sum()
self.nr_pred_pos += (pred == 1).sum()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment