Commit 6a1e822d authored by Yuxin Wu's avatar Yuxin Wu

bugfix

parent 7bb9ee1c
...@@ -154,5 +154,5 @@ if __name__ == '__main__': ...@@ -154,5 +154,5 @@ if __name__ == '__main__':
nr_gpu = len(args.gpu.split(',')) nr_gpu = len(args.gpu.split(','))
trainer = QueueInputTrainer() if nr_gpu <= 1 \ trainer = QueueInputTrainer() if nr_gpu <= 1 \
else SyncMultiGPUTrainerParameterServer(list(range(nr_gpu))) else SyncMultiGPUTrainerParameterServer(nr_gpu)
launch_train_with_config(config, trainer) launch_train_with_config(config, trainer)
...@@ -36,6 +36,7 @@ def _int_to_range(x): ...@@ -36,6 +36,7 @@ def _int_to_range(x):
if isinstance(x, int): if isinstance(x, int):
assert x > 0, x assert x > 0, x
return list(range(x)) return list(range(x))
return x
class SimpleTrainer(SingleCostTrainer): class SimpleTrainer(SingleCostTrainer):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment