Commit 5dee7231 authored by Yuxin Wu's avatar Yuxin Wu

--output argument for HED

parent 6087698d
......@@ -180,7 +180,7 @@ def get_config():
max_epoch=100,
)
def run(model_path, image_path):
def run(model_path, image_path, output):
pred_config = PredictConfig(
model=Model(),
session_init=get_model_loader(model_path),
......@@ -191,10 +191,14 @@ def run(model_path, image_path):
assert im is not None
im = cv2.resize(im, (im.shape[1] // 16 * 16, im.shape[0] // 16 * 16))
outputs = predict_func([[im.astype('float32')]])
if output is None:
for k in range(6):
pred = outputs[k][0]
cv2.imwrite("out{}.png".format(
'-fused' if k == 5 else str(k+1)), pred * 255)
else:
pred = outputs[5][0]
cv2.imwrite(output, pred * 255)
if __name__ == '__main__':
parser = argparse.ArgumentParser()
......@@ -202,6 +206,7 @@ if __name__ == '__main__':
parser.add_argument('--load', help='load model')
parser.add_argument('--view', help='view dataset', action='store_true')
parser.add_argument('--run', help='run model on images')
parser.add_argument('--output', help='fused output filename. default to out-fused.png')
args = parser.parse_args()
if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
......@@ -209,7 +214,7 @@ if __name__ == '__main__':
if args.view:
view_data()
elif args.run:
run(args.load, args.run)
run(args.load, args.run, args.output)
else:
config = get_config()
if args.load:
......
......@@ -16,12 +16,12 @@ def change_gpu(val):
def get_nr_gpu():
env = os.environ.get('CUDA_VISIBLE_DEVICES', None)
assert env is not None # TODO
assert env is not None, 'gpu not set!' # TODO
return len(env.split(','))
def get_gpus():
""" return a list of GPU physical id"""
env = os.environ.get('CUDA_VISIBLE_DEVICES', None)
assert env is not None # TODO
assert env is not None, 'gpu not set!' # TODO
return map(int, env.strip().split(','))
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment