Commit 2a0e96e0 authored by Yuxin Wu's avatar Yuxin Wu

fix train-atari py3 compatbility

parent d4421aee
......@@ -9,7 +9,7 @@ You can actually train them and reproduce the performance... not just to see how
+ [InceptionV3 on ImageNet](examples/Inception/inceptionv3.py)
+ [Fully-convolutional Network for Holistically-Nested Edge Detection](examples/HED)
+ [Spatial Transformer Networks on MNIST addition](examples/SpatialTransformer)
+ [Deep Convolutional Generative Adversarial Networks](examples/GAN)
+ [Generative Adversarial Networks & Image to Image Translation](examples/GAN)
+ [DQN variants on Atari games](examples/Atari2600)
+ [Asynchronous Advantage Actor-Critic(A3C) with demos on OpenAI Gym](examples/OpenAIGym)
+ [char-rnn language model](examples/char-rnn)
......
......@@ -240,11 +240,11 @@ if __name__ == '__main__':
if args.gpu:
nr_gpu = get_nr_gpu()
if nr_gpu > 1:
predict_tower = range(nr_gpu)[-nr_gpu/2:]
predict_tower = range(nr_gpu)[-nr_gpu//2:]
else:
predict_tower = [0]
PREDICTOR_THREAD = len(predict_tower) * PREDICTOR_THREAD_PER_GPU
train_tower = range(nr_gpu)[:-nr_gpu/2] or [0]
train_tower = range(nr_gpu)[:-nr_gpu//2] or [0]
logger.info("[BA3C] Train on gpu {} and infer on gpu {}".format(
','.join(map(str, train_tower)), ','.join(map(str, predict_tower))))
else:
......
......@@ -2,7 +2,6 @@ pillow
scipy
nltk
h5py
pyzmq
tornado; python_version < '3.0'
lmdb
matplotlib
......
......@@ -4,4 +4,5 @@ termcolor
tqdm>4.6.1
msgpack-python
msgpack-numpy
pyzmq
subprocess32; python_version < '3.0'
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment