Commit 8a341291 authored by dongzhuoyao's avatar dongzhuoyao Committed by Yuxin Wu

load model function in shufflenet (#479)

parent 8c778cc2
......@@ -83,3 +83,4 @@ docs/_build/
target/
*.dat
.idea/
......@@ -177,6 +177,7 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
parser.add_argument('--data', help='ILSVRC dataset dir')
parser.add_argument('--load', help='load model')
parser.add_argument('--flops', action='store_true', help='print flops and exit')
args = parser.parse_args()
......@@ -206,4 +207,6 @@ if __name__ == '__main__':
nr_tower = max(get_nr_gpu(), 1)
config = get_config(model, nr_tower)
if args.load:
config.session_init = get_model_loader(args.load)
launch_train_with_config(config, SyncMultiGPUTrainerParameterServer(nr_tower))
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment