Commit 5dee7231 authored by Yuxin Wu's avatar Yuxin Wu

--output argument for HED

parent 6087698d
...@@ -180,7 +180,7 @@ def get_config(): ...@@ -180,7 +180,7 @@ def get_config():
max_epoch=100, max_epoch=100,
) )
def run(model_path, image_path): def run(model_path, image_path, output):
pred_config = PredictConfig( pred_config = PredictConfig(
model=Model(), model=Model(),
session_init=get_model_loader(model_path), session_init=get_model_loader(model_path),
...@@ -191,10 +191,14 @@ def run(model_path, image_path): ...@@ -191,10 +191,14 @@ def run(model_path, image_path):
assert im is not None assert im is not None
im = cv2.resize(im, (im.shape[1] // 16 * 16, im.shape[0] // 16 * 16)) im = cv2.resize(im, (im.shape[1] // 16 * 16, im.shape[0] // 16 * 16))
outputs = predict_func([[im.astype('float32')]]) outputs = predict_func([[im.astype('float32')]])
for k in range(6): if output is None:
pred = outputs[k][0] for k in range(6):
cv2.imwrite("out{}.png".format( pred = outputs[k][0]
'-fused' if k == 5 else str(k+1)), pred * 255) cv2.imwrite("out{}.png".format(
'-fused' if k == 5 else str(k+1)), pred * 255)
else:
pred = outputs[5][0]
cv2.imwrite(output, pred * 255)
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
...@@ -202,6 +206,7 @@ if __name__ == '__main__': ...@@ -202,6 +206,7 @@ if __name__ == '__main__':
parser.add_argument('--load', help='load model') parser.add_argument('--load', help='load model')
parser.add_argument('--view', help='view dataset', action='store_true') parser.add_argument('--view', help='view dataset', action='store_true')
parser.add_argument('--run', help='run model on images') parser.add_argument('--run', help='run model on images')
parser.add_argument('--output', help='fused output filename. default to out-fused.png')
args = parser.parse_args() args = parser.parse_args()
if args.gpu: if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
...@@ -209,7 +214,7 @@ if __name__ == '__main__': ...@@ -209,7 +214,7 @@ if __name__ == '__main__':
if args.view: if args.view:
view_data() view_data()
elif args.run: elif args.run:
run(args.load, args.run) run(args.load, args.run, args.output)
else: else:
config = get_config() config = get_config()
if args.load: if args.load:
......
...@@ -16,12 +16,12 @@ def change_gpu(val): ...@@ -16,12 +16,12 @@ def change_gpu(val):
def get_nr_gpu(): def get_nr_gpu():
env = os.environ.get('CUDA_VISIBLE_DEVICES', None) env = os.environ.get('CUDA_VISIBLE_DEVICES', None)
assert env is not None # TODO assert env is not None, 'gpu not set!' # TODO
return len(env.split(',')) return len(env.split(','))
def get_gpus(): def get_gpus():
""" return a list of GPU physical id""" """ return a list of GPU physical id"""
env = os.environ.get('CUDA_VISIBLE_DEVICES', None) env = os.environ.get('CUDA_VISIBLE_DEVICES', None)
assert env is not None # TODO assert env is not None, 'gpu not set!' # TODO
return map(int, env.strip().split(',')) return map(int, env.strip().split(','))
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment