Commit dabebf69 authored by Yuxin Wu's avatar Yuxin Wu

some a3c performance notes

parent c280473d
......@@ -10,8 +10,10 @@ Most of them are the best reproducible results on gym.
`./train-atari.py --env Breakout-v0 --gpu 0`
In each iteration it trains on a batch of 128 new states.
The speed is about 6~10 iterations/s on 1 GPU plus 12+ CPU cores.
In each iteration it trains on a batch of 128 new states. The network architecture is larger than what's used in the original paper.
With 2 TitanX + 20+ CPU cores, by setting `SIMULATOR_PROC=240, PREDICT_BATCH_SIZE=30, PREDICTOR_THREAD_PER_GPU=6`, it can improve to 16 it/s (2K images/s).
Note that the network architecture is larger than what's used in the original paper.
The pre-trained models are all trained with 4 GPUs for about 2 days.
But on simple games like Breakout, you can get good performance within several hours.
......
......@@ -17,12 +17,6 @@ import cv2
import tensorflow as tf
import six
from six.moves import queue
if six.PY3:
from concurrent import futures # py3
CancelledError = futures.CancelledError
else:
CancelledError = Exception
from tensorpack import *
from tensorpack.utils.concurrency import *
......@@ -36,6 +30,12 @@ from simulator import *
import common
from common import (play_model, Evaluator, eval_model_multithread, play_one_episode)
if six.PY3:
from concurrent import futures
CancelledError = futures.CancelledError
else:
CancelledError = Exception
IMAGE_SIZE = (84, 84)
FRAME_HISTORY = 4
GAMMA = 0.99
......@@ -46,6 +46,7 @@ LOCAL_TIME_MAX = 5
STEPS_PER_EPOCH = 6000
EVAL_EPISODE = 50
BATCH_SIZE = 128
PREDICT_BATCH_SIZE = 15 # batch for efficient forward
SIMULATOR_PROC = 50
PREDICTOR_THREAD_PER_GPU = 3
PREDICTOR_THREAD = None
......@@ -154,7 +155,7 @@ class MySimulatorMaster(SimulatorMaster, Callback):
def _setup_graph(self):
self.async_predictor = MultiThreadAsyncPredictor(
self.trainer.get_predictors(['state'], ['policy_explore', 'pred_value'],
PREDICTOR_THREAD), batch_size=15)
PREDICTOR_THREAD), batch_size=PREDICT_BATCH_SIZE)
def _before_train(self):
self.async_predictor.start()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment