Commit 6087698d authored by Yuxin Wu's avatar Yuxin Wu

a3c supports cpu

parent 61384a65
...@@ -80,6 +80,7 @@ class Model(ModelDesc): ...@@ -80,6 +80,7 @@ class Model(ModelDesc):
wrong = tf.cast(tf.not_equal(pred, edgemap), tf.float32) wrong = tf.cast(tf.not_equal(pred, edgemap), tf.float32)
wrong = tf.reduce_mean(wrong, name='train_error') wrong = tf.reduce_mean(wrong, name='train_error')
if get_current_tower_context().is_training:
wd_w = tf.train.exponential_decay(2e-4, get_global_step_var(), wd_w = tf.train.exponential_decay(2e-4, get_global_step_var(),
80000, 0.7, True) 80000, 0.7, True)
wd_cost = tf.mul(wd_w, regularize_cost('.*/W', tf.nn.l2_loss), name='wd_cost') wd_cost = tf.mul(wd_w, regularize_cost('.*/W', tf.nn.l2_loss), name='wd_cost')
......
...@@ -69,7 +69,7 @@ def run_submission(cfg, output, nr): ...@@ -69,7 +69,7 @@ def run_submission(cfg, output, nr):
if k != 0: if k != 0:
player.restart_episode() player.restart_episode()
score = play_one_episode(player, predfunc) score = play_one_episode(player, predfunc)
print("Total:", score) print("Score:", score)
def do_submit(output): def do_submit(output):
gym.upload(output, api_key='xxx') gym.upload(output, api_key='xxx')
......
...@@ -242,16 +242,23 @@ if __name__ == '__main__': ...@@ -242,16 +242,23 @@ if __name__ == '__main__':
elif args.task == 'eval': elif args.task == 'eval':
eval_model_multithread(cfg, EVAL_EPISODE) eval_model_multithread(cfg, EVAL_EPISODE)
else: else:
if args.gpu:
nr_gpu = get_nr_gpu() nr_gpu = get_nr_gpu()
if nr_gpu > 1: if nr_gpu > 1:
predict_tower = range(nr_gpu)[-nr_gpu/2:] predict_tower = range(nr_gpu)[-nr_gpu/2:]
else: else:
predict_tower = [0] predict_tower = [0]
PREDICTOR_THREAD = len(predict_tower) * PREDICTOR_THREAD_PER_GPU PREDICTOR_THREAD = len(predict_tower) * PREDICTOR_THREAD_PER_GPU
train_tower = range(nr_gpu)[:-nr_gpu/2] or [0]
logger.info("[BA3C] Train on gpu {} and infer on gpu {}".format(
','.join(map(str, train_tower)), ','.join(map(str, predict_tower))))
else:
nr_gpu = 0
PREDICTOR_THREAD = 1
predict_tower = [0]
train_tower = [0]
config = get_config() config = get_config()
if args.load: if args.load:
config.session_init = SaverRestore(args.load) config.session_init = SaverRestore(args.load)
config.tower = range(nr_gpu)[:-nr_gpu/2] or [0] config.tower = train_tower
logger.info("[BA3C] Train on gpu {} and infer on gpu {}".format(
','.join(map(str, config.tower)), ','.join(map(str, predict_tower))))
AsyncMultiGPUTrainer(config, predict_tower=predict_tower).train() AsyncMultiGPUTrainer(config, predict_tower=predict_tower).train()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment