Commit 2a0e96e0 authored by Yuxin Wu's avatar Yuxin Wu

fix train-atari py3 compatbility

parent d4421aee
...@@ -9,7 +9,7 @@ You can actually train them and reproduce the performance... not just to see how ...@@ -9,7 +9,7 @@ You can actually train them and reproduce the performance... not just to see how
+ [InceptionV3 on ImageNet](examples/Inception/inceptionv3.py) + [InceptionV3 on ImageNet](examples/Inception/inceptionv3.py)
+ [Fully-convolutional Network for Holistically-Nested Edge Detection](examples/HED) + [Fully-convolutional Network for Holistically-Nested Edge Detection](examples/HED)
+ [Spatial Transformer Networks on MNIST addition](examples/SpatialTransformer) + [Spatial Transformer Networks on MNIST addition](examples/SpatialTransformer)
+ [Deep Convolutional Generative Adversarial Networks](examples/GAN) + [Generative Adversarial Networks & Image to Image Translation](examples/GAN)
+ [DQN variants on Atari games](examples/Atari2600) + [DQN variants on Atari games](examples/Atari2600)
+ [Asynchronous Advantage Actor-Critic(A3C) with demos on OpenAI Gym](examples/OpenAIGym) + [Asynchronous Advantage Actor-Critic(A3C) with demos on OpenAI Gym](examples/OpenAIGym)
+ [char-rnn language model](examples/char-rnn) + [char-rnn language model](examples/char-rnn)
......
...@@ -240,11 +240,11 @@ if __name__ == '__main__': ...@@ -240,11 +240,11 @@ if __name__ == '__main__':
if args.gpu: if args.gpu:
nr_gpu = get_nr_gpu() nr_gpu = get_nr_gpu()
if nr_gpu > 1: if nr_gpu > 1:
predict_tower = range(nr_gpu)[-nr_gpu/2:] predict_tower = range(nr_gpu)[-nr_gpu//2:]
else: else:
predict_tower = [0] predict_tower = [0]
PREDICTOR_THREAD = len(predict_tower) * PREDICTOR_THREAD_PER_GPU PREDICTOR_THREAD = len(predict_tower) * PREDICTOR_THREAD_PER_GPU
train_tower = range(nr_gpu)[:-nr_gpu/2] or [0] train_tower = range(nr_gpu)[:-nr_gpu//2] or [0]
logger.info("[BA3C] Train on gpu {} and infer on gpu {}".format( logger.info("[BA3C] Train on gpu {} and infer on gpu {}".format(
','.join(map(str, train_tower)), ','.join(map(str, predict_tower)))) ','.join(map(str, train_tower)), ','.join(map(str, predict_tower))))
else: else:
......
...@@ -2,7 +2,6 @@ pillow ...@@ -2,7 +2,6 @@ pillow
scipy scipy
nltk nltk
h5py h5py
pyzmq
tornado; python_version < '3.0' tornado; python_version < '3.0'
lmdb lmdb
matplotlib matplotlib
......
...@@ -4,4 +4,5 @@ termcolor ...@@ -4,4 +4,5 @@ termcolor
tqdm>4.6.1 tqdm>4.6.1
msgpack-python msgpack-python
msgpack-numpy msgpack-numpy
pyzmq
subprocess32; python_version < '3.0' subprocess32; python_version < '3.0'
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment