Commit dabebf69 authored by Yuxin Wu

some a3c performance notes

parent c280473d
...@@ -10,8 +10,10 @@ Most of them are the best reproducible results on gym. ...@@ -10,8 +10,10 @@ Most of them are the best reproducible results on gym.
`./train-atari.py --env Breakout-v0 --gpu 0` `./train-atari.py --env Breakout-v0 --gpu 0`
In each iteration it trains on a batch of 128 new states.
The speed is about 6~10 iterations/s on 1 GPU plus 12+ CPU cores. The speed is about 6~10 iterations/s on 1 GPU plus 12+ CPU cores.
In each iteration it trains on a batch of 128 new states. The network architecture is larger than what's used in the original paper. With 2 TitanX + 20+ CPU cores, by setting `SIMULATOR_PROC=240, PREDICT_BATCH_SIZE=30, PREDICTOR_THREAD_PER_GPU=6`, it can improve to 16 it/s (2K images/s).
Note that the network architecture is larger than what's used in the original paper.
The pre-trained models are all trained with 4 GPUs for about 2 days. The pre-trained models are all trained with 4 GPUs for about 2 days.
But on simple games like Breakout, you can get good performance within several hours. But on simple games like Breakout, you can get good performance within several hours.
......
...@@ -17,12 +17,6 @@ import cv2 ...@@ -17,12 +17,6 @@ import cv2
import tensorflow as tf import tensorflow as tf
import six import six
from six.moves import queue from six.moves import queue
if six.PY3:
from concurrent import futures # py3
CancelledError = futures.CancelledError
else:
CancelledError = Exception
from tensorpack import * from tensorpack import *
from tensorpack.utils.concurrency import * from tensorpack.utils.concurrency import *
...@@ -36,6 +30,12 @@ from simulator import * ...@@ -36,6 +30,12 @@ from simulator import *
import common import common
from common import (play_model, Evaluator, eval_model_multithread, play_one_episode) from common import (play_model, Evaluator, eval_model_multithread, play_one_episode)
if six.PY3:
from concurrent import futures
CancelledError = futures.CancelledError
else:
CancelledError = Exception
IMAGE_SIZE = (84, 84) IMAGE_SIZE = (84, 84)
FRAME_HISTORY = 4 FRAME_HISTORY = 4
GAMMA = 0.99 GAMMA = 0.99
...@@ -46,6 +46,7 @@ LOCAL_TIME_MAX = 5 ...@@ -46,6 +46,7 @@ LOCAL_TIME_MAX = 5
STEPS_PER_EPOCH = 6000 STEPS_PER_EPOCH = 6000
EVAL_EPISODE = 50 EVAL_EPISODE = 50
BATCH_SIZE = 128 BATCH_SIZE = 128
PREDICT_BATCH_SIZE = 15 # batch for efficient forward
SIMULATOR_PROC = 50 SIMULATOR_PROC = 50
PREDICTOR_THREAD_PER_GPU = 3 PREDICTOR_THREAD_PER_GPU = 3
PREDICTOR_THREAD = None PREDICTOR_THREAD = None
...@@ -154,7 +155,7 @@ class MySimulatorMaster(SimulatorMaster, Callback): ...@@ -154,7 +155,7 @@ class MySimulatorMaster(SimulatorMaster, Callback):
def _setup_graph(self): def _setup_graph(self):
self.async_predictor = MultiThreadAsyncPredictor( self.async_predictor = MultiThreadAsyncPredictor(
self.trainer.get_predictors(['state'], ['policy_explore', 'pred_value'], self.trainer.get_predictors(['state'], ['policy_explore', 'pred_value'],
PREDICTOR_THREAD), batch_size=15) PREDICTOR_THREAD), batch_size=PREDICT_BATCH_SIZE)
def _before_train(self): def _before_train(self):
self.async_predictor.start() self.async_predictor.start()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment