Commit 8a341291 authored by dongzhuoyao's avatar dongzhuoyao Committed by Yuxin Wu

load model function in shufflenet (#479)

parent 8c778cc2
...@@ -83,3 +83,4 @@ docs/_build/ ...@@ -83,3 +83,4 @@ docs/_build/
target/ target/
*.dat *.dat
.idea/
...@@ -177,6 +177,7 @@ if __name__ == '__main__': ...@@ -177,6 +177,7 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.') parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
parser.add_argument('--data', help='ILSVRC dataset dir') parser.add_argument('--data', help='ILSVRC dataset dir')
parser.add_argument('--load', help='load model')
parser.add_argument('--flops', action='store_true', help='print flops and exit') parser.add_argument('--flops', action='store_true', help='print flops and exit')
args = parser.parse_args() args = parser.parse_args()
...@@ -206,4 +207,6 @@ if __name__ == '__main__': ...@@ -206,4 +207,6 @@ if __name__ == '__main__':
nr_tower = max(get_nr_gpu(), 1) nr_tower = max(get_nr_gpu(), 1)
config = get_config(model, nr_tower) config = get_config(model, nr_tower)
if args.load:
config.session_init = get_model_loader(args.load)
launch_train_with_config(config, SyncMultiGPUTrainerParameterServer(nr_tower)) launch_train_with_config(config, SyncMultiGPUTrainerParameterServer(nr_tower))
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment