Commit 6087698d authored by Yuxin Wu's avatar Yuxin Wu

a3c supports cpu

parent 61384a65
......@@ -80,6 +80,7 @@ class Model(ModelDesc):
wrong = tf.cast(tf.not_equal(pred, edgemap), tf.float32)
wrong = tf.reduce_mean(wrong, name='train_error')
if get_current_tower_context().is_training:
wd_w = tf.train.exponential_decay(2e-4, get_global_step_var(),
80000, 0.7, True)
wd_cost = tf.mul(wd_w, regularize_cost('.*/W', tf.nn.l2_loss), name='wd_cost')
......
......@@ -69,7 +69,7 @@ def run_submission(cfg, output, nr):
if k != 0:
player.restart_episode()
score = play_one_episode(player, predfunc)
print("Total:", score)
print("Score:", score)
def do_submit(output):
gym.upload(output, api_key='xxx')
......
......@@ -242,16 +242,23 @@ if __name__ == '__main__':
elif args.task == 'eval':
eval_model_multithread(cfg, EVAL_EPISODE)
else:
if args.gpu:
nr_gpu = get_nr_gpu()
if nr_gpu > 1:
predict_tower = range(nr_gpu)[-nr_gpu/2:]
else:
predict_tower = [0]
PREDICTOR_THREAD = len(predict_tower) * PREDICTOR_THREAD_PER_GPU
train_tower = range(nr_gpu)[:-nr_gpu/2] or [0]
logger.info("[BA3C] Train on gpu {} and infer on gpu {}".format(
','.join(map(str, train_tower)), ','.join(map(str, predict_tower))))
else:
nr_gpu = 0
PREDICTOR_THREAD = 1
predict_tower = [0]
train_tower = [0]
config = get_config()
if args.load:
config.session_init = SaverRestore(args.load)
config.tower = range(nr_gpu)[:-nr_gpu/2] or [0]
logger.info("[BA3C] Train on gpu {} and infer on gpu {}".format(
','.join(map(str, config.tower)), ','.join(map(str, predict_tower))))
config.tower = train_tower
AsyncMultiGPUTrainer(config, predict_tower=predict_tower).train()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment