Commit 59cd3c77 authored by Yuxin Wu

change to python2 shebang

parent 642d106e
...@@ -9,8 +9,8 @@ You can actually train them and reproduce the performance... not just to see how ...@@ -9,8 +9,8 @@ You can actually train them and reproduce the performance... not just to see how
+ [ResNet for Cifar10 classification](examples/ResNet) + [ResNet for Cifar10 classification](examples/ResNet)
+ [Fully-convolutional Network for Holistically-Nested Edge Detection](examples/HED) + [Fully-convolutional Network for Holistically-Nested Edge Detection](examples/HED)
+ [Spatial Transformer Networks on MNIST addition](examples/SpatialTransformer) + [Spatial Transformer Networks on MNIST addition](examples/SpatialTransformer)
+ [Double-DQN plays Atari games](examples/Atari2600) + [Double DQN plays Atari games](examples/Atari2600)
+ [Batch-A3C plays Atari games with demos on OpenAI Gym](examples/OpenAIGym) + [Asynchronous Advantage Actor-Critic(A3C) with demos on OpenAI Gym Atari games](examples/OpenAIGym)
+ [char-rnn language model](examples/char-rnn) + [char-rnn language model](examples/char-rnn)
## Features: ## Features:
......
#!/usr/bin/env python2 #!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: DQN.py # File: DQN.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
...@@ -160,8 +160,7 @@ def get_config(): ...@@ -160,8 +160,7 @@ def get_config():
dataset=dataset_train, dataset=dataset_train,
optimizer=tf.train.AdamOptimizer(lr, epsilon=1e-3), optimizer=tf.train.AdamOptimizer(lr, epsilon=1e-3),
callbacks=Callbacks([ callbacks=Callbacks([
StatPrinter(), StatPrinter(), ModelSaver(),
ModelSaver(),
ScheduledHyperParamSetter('learning_rate', ScheduledHyperParamSetter('learning_rate',
[(150, 4e-4), (250, 1e-4), (350, 5e-5)]), [(150, 4e-4), (250, 1e-4), (350, 5e-5)]),
RunOp(lambda: M.update_target_param()), RunOp(lambda: M.update_target_param()),
......
![breakout](breakout.jpg) ![breakout](breakout.jpg)
[video demo](https://youtu.be/o21mddZtE5Y)
Reproduce the following reinforcement learning methods: Reproduce the following reinforcement learning methods:
+ Nature-DQN in: + Nature-DQN in:
...@@ -18,16 +20,11 @@ Claimed performance in the paper can be reproduced, on several games I've tested ...@@ -18,16 +20,11 @@ Claimed performance in the paper can be reproduced, on several games I've tested
DQN typically took 2 days of training to reach a score of 400 on breakout game. DQN typically took 2 days of training to reach a score of 400 on breakout game.
My Batch-A3C implementation only took <2 hours. My Batch-A3C implementation only took <2 hours.
Both were trained on one GPU with an extra GPU for simulation. Both were trained on one GPU with an extra GPU for simulation.
<!--
-This is probably the fastest RL trainer you'd find.
-->
The x-axis is the number of iterations, not wall time. The x-axis is the number of iterations, not wall time.
Iteration speed on Tesla M40 is about 9.7it/s for B-A3C. Iteration speed on Tesla M40 is about 9.7it/s for B-A3C.
D-DQN is faster at the beginning but will converge to 12it/s due to exploration annealing. D-DQN is faster at the beginning but will converge to 12it/s due to exploration annealing.
A demo trained with Double-DQN on breakout is available at [youtube](https://youtu.be/o21mddZtE5Y).
## How to use ## How to use
Download [atari roms](https://github.com/openai/atari-py/tree/master/atari_py/atari_roms) to Download [atari roms](https://github.com/openai/atari-py/tree/master/atari_py/atari_roms) to
...@@ -40,7 +37,7 @@ To train: ...@@ -40,7 +37,7 @@ To train:
To visualize the agent: To visualize the agent:
``` ```
./DQN.py --rom breakout.bin --task play --load pretrained.model ./DQN.py --rom breakout.bin --task play --load trained.model
``` ```
A3C code and models for Atari games in OpenAI Gym are released in [examples/OpenAIGym](../OpenAIGym) A3C code and models for Atari games in OpenAI Gym are released in [examples/OpenAIGym](../OpenAIGym)
# A3C Code and models for my Gym submissions on Atari games. ### A3C code and models for my Gym submissions on Atari games
### To train on an Atari game: ### To train on an Atari game:
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
### To run a pretrained Atari model for 100 episodes: ### To run a pretrained Atari model for 100 episodes:
1. Download models from [model zoo](https://drive.google.com/open?id=0B9IPQTvr2BBkS0VhX0xmS1c5aFk) 1. Download models from [model zoo](https://drive.google.com/open?id=0B9IPQTvr2BBkS0VhX0xmS1c5aFk)
2. `ENV=Breakout-v0; ./run-atari.py --load "$ENV".tfmodel --env "$ENV"` 2. `ENV=Breakout-v0; ./run-atari.py --load "$ENV".tfmodel --env "$ENV" --episode 100 --output output_dir`
Models are available for the following gym atari environments (click links for videos): Models are available for the following gym atari environments (click links for videos):
...@@ -64,3 +64,5 @@ Note that atari game settings in gym are quite different from DeepMind papers, s ...@@ -64,3 +64,5 @@ Note that atari game settings in gym are quite different from DeepMind papers, s
+ In gym, inputs are RGB instead of greyscale. + In gym, inputs are RGB instead of greyscale.
+ In gym, an episode is limited to 10000 steps. + In gym, an episode is limited to 10000 steps.
+ The action space also seems to be different. + The action space also seems to be different.
Also see the DQN implementation [here](../Atari2600)
#!/usr/bin/env python2 #!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: run-atari.py # File: run-atari.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
...@@ -61,25 +61,26 @@ class Model(ModelDesc): ...@@ -61,25 +61,26 @@ class Model(ModelDesc):
policy = self._get_NN_prediction(state) policy = self._get_NN_prediction(state)
self.logits = tf.nn.softmax(policy, name='logits') self.logits = tf.nn.softmax(policy, name='logits')
def run_submission(cfg): def run_submission(cfg, output, nr):
dirname = 'gym-submit' player = get_player(dumpdir=output)
player = get_player(dumpdir=dirname)
predfunc = get_predict_func(cfg) predfunc = get_predict_func(cfg)
for k in range(100): for k in range(nr):
if k != 0: if k != 0:
player.restart_episode() player.restart_episode()
score = play_one_episode(player, predfunc) score = play_one_episode(player, predfunc)
print("Total:", score) print("Total:", score)
def do_submit(): def do_submit(output):
dirname = 'gym-submit' gym.upload(output, api_key='xxx')
gym.upload(dirname, api_key='xxx')
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.') # nargs='*' in multi mode parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.') # nargs='*' in multi mode
parser.add_argument('--load', help='load model', required=True) parser.add_argument('--load', help='load model', required=True)
parser.add_argument('--env', help='env', required=True) parser.add_argument('--env', help='environment name', required=True)
parser.add_argument('--episode', help='number of episodes to run',
type=int, default=100)
parser.add_argument('--output', help='output directory', default='gym-submit')
args = parser.parse_args() args = parser.parse_args()
ENV_NAME = args.env ENV_NAME = args.env
...@@ -95,4 +96,4 @@ if __name__ == '__main__': ...@@ -95,4 +96,4 @@ if __name__ == '__main__':
session_init=SaverRestore(args.load), session_init=SaverRestore(args.load),
input_var_names=['state'], input_var_names=['state'],
output_var_names=['logits']) output_var_names=['logits'])
run_submission(cfg) run_submission(cfg, args.output, args.episode)
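The command-line change to run-atari.py above can be tried in isolation. This is a minimal sketch rather than the file itself: `build_parser` and `run_episodes` are hypothetical helper names, and the `play_one` callable is an assumed stand-in for `play_one_episode(player, predfunc)`, which needs a real gym player and a loaded model.

```python
import argparse


def build_parser():
    """Mirror the arguments shown in the hunk above (the --gpu flag is omitted)."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--load', help='load model', required=True)
    parser.add_argument('--env', help='environment name', required=True)
    parser.add_argument('--episode', help='number of episodes to run',
                        type=int, default=100)
    parser.add_argument('--output', help='output directory', default='gym-submit')
    return parser


def run_episodes(play_one, nr):
    """Run `nr` episodes and collect their scores.

    `play_one` is a hypothetical callable taking the episode index; in the
    real script the loop calls play_one_episode(player, predfunc) instead.
    """
    scores = []
    for k in range(nr):
        scores.append(play_one(k))
    return scores


if __name__ == '__main__':
    # Parse an explicit argument list so the sketch does not depend on sys.argv.
    args = build_parser().parse_args(
        ['--load', 'Breakout-v0.tfmodel', '--env', 'Breakout-v0', '--episode', '3'])
    print(args.episode, args.output)
    print(run_episodes(lambda k: 0, args.episode))
```

Leaving out `--episode` and `--output` falls back to 100 episodes and the `gym-submit` directory, which matches the values that were hard-coded before this commit.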
#!/usr/bin/env python2 #!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: train-atari.py # File: train-atari.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
......
...@@ -4,9 +4,9 @@ ...@@ -4,9 +4,9 @@
Training examples with __reproducible__ and meaningful performance. Training examples with __reproducible__ and meaningful performance.
+ [An illustrative mnist example with explanation of the framework](mnist-convnet.py) + [An illustrative mnist example with explanation of the framework](mnist-convnet.py)
+ [A tiny SVHN ConvNet with 97.5% accuracy](svhn-digit-convnet.py) + [A tiny SVHN ConvNet with 97.8% accuracy](svhn-digit-convnet.py)
+ [Inception-BN with 71% accuracy](Inception/inception-bn.py) + [Inception-BN with 71% accuracy](Inception/inception-bn.py)
+ [InceptionV3 with 74.5% accuracy (similar to the official code)](Inception/inceptionv3.py) + [InceptionV3 with 74% accuracy (similar to the official code)](Inception/inceptionv3.py)
+ [DoReFa-Net: binary / low-bitwidth CNN on ImageNet](DoReFa-Net) + [DoReFa-Net: binary / low-bitwidth CNN on ImageNet](DoReFa-Net)
+ [ResNet for Cifar10 and SVHN](ResNet) + [ResNet for Cifar10 and SVHN](ResNet)
+ [Holistically-Nested Edge Detection](HED) + [Holistically-Nested Edge Detection](HED)
......
#!/usr/bin/env python2 #!/usr/bin/env python
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: load-alexnet.py # File: load-alexnet.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
...@@ -15,7 +15,7 @@ from tensorpack.dataflow.dataset import ILSVRCMeta ...@@ -15,7 +15,7 @@ from tensorpack.dataflow.dataset import ILSVRCMeta
""" """
Usage: Usage:
python2 -m tensorpack.utils.loadcaffe PATH/TO/CAFFE/{deploy.prototxt,bvlc_alexnet.caffemodel} alexnet.npy python -m tensorpack.utils.loadcaffe PATH/TO/CAFFE/{deploy.prototxt,bvlc_alexnet.caffemodel} alexnet.npy
./load-alexnet.py --load alexnet.npy --input cat.png ./load-alexnet.py --load alexnet.npy --input cat.png
""" """
......
#!/usr/bin/env python2 #!/usr/bin/env python
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: load-vgg16.py # File: load-vgg16.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
...@@ -23,7 +23,7 @@ from tensorpack.dataflow.dataset import ILSVRCMeta ...@@ -23,7 +23,7 @@ from tensorpack.dataflow.dataset import ILSVRCMeta
""" """
Usage: Usage:
python2 -m tensorpack.utils.loadcaffe PATH/TO/VGG/{VGG_ILSVRC_16_layers_deploy.prototxt,VGG_ILSVRC_16_layers.caffemodel} vgg16.npy python -m tensorpack.utils.loadcaffe PATH/TO/VGG/{VGG_ILSVRC_16_layers_deploy.prototxt,VGG_ILSVRC_16_layers.caffemodel} vgg16.npy
./load-vgg16.py --load vgg16.npy --input cat.png ./load-vgg16.py --load vgg16.npy --input cat.png
""" """
......
#!/usr/bin/env python2 #!/usr/bin/env python
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: dump-model-params.py # File: dump-model-params.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2 #!/usr/bin/env python
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: dump_train_config.py # File: dump_train_config.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2 #!/usr/bin/env python
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: imgclassify.py # File: imgclassify.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2 #!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
""" """
......
#!/usr/bin/env python2 #!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: expreplay.py # File: expreplay.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
......
#!/usr/bin/env python2 #!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: simulator.py # File: simulator.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
......
#!/usr/bin/env python2 #!/usr/bin/env python
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: param.py # File: param.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
...@@ -144,6 +144,9 @@ class HumanHyperParamSetter(HyperParamSetter): ...@@ -144,6 +144,9 @@ class HumanHyperParamSetter(HyperParamSetter):
self.file_name, self.param.readable_name)) self.file_name, self.param.readable_name))
def _get_value_to_set(self): def _get_value_to_set(self):
# ignore if no such file exists
if not os.path.isfile(self.file_name):
return None
try: try:
with open(self.file_name) as f: with open(self.file_name) as f:
lines = f.readlines() lines = f.readlines()
...@@ -152,9 +155,9 @@ class HumanHyperParamSetter(HyperParamSetter): ...@@ -152,9 +155,9 @@ class HumanHyperParamSetter(HyperParamSetter):
ret = dic[self.param.readable_name] ret = dic[self.param.readable_name]
return ret return ret
except: except:
#logger.warn( logger.warn(
#"Cannot find {} in {}".format( "Cannot find {} in {}".format(
#self.param.readable_name, self.file_name)) self.param.readable_name, self.file_name))
return None return None
class ScheduledHyperParamSetter(HyperParamSetter): class ScheduledHyperParamSetter(HyperParamSetter):
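The behaviour added to `_get_value_to_set` above (return `None` when the file does not exist, return `None` when the parameter is missing or malformed) can be sketched as a standalone function. The parsing step is elided in the hunk, so the `name: value` line format below is an assumption, and `read_param` is a hypothetical name, not part of tensorpack.

```python
import os


def read_param(file_name, param_name):
    """Return the value of `param_name` from `file_name`, or None.

    Assumed file format (not confirmed by the diff): one `name: value`
    pair per line, with a numeric value.
    """
    # Ignore silently if no such file exists, as in the hunk above.
    if not os.path.isfile(file_name):
        return None
    try:
        with open(file_name) as f:
            lines = f.readlines()
        dic = {}
        for line in lines:
            key, value = line.strip().split(':', 1)
            dic[key.strip()] = float(value)
        return dic[param_name]
    except (KeyError, ValueError):
        # Missing parameter or malformed line; the real code logs a warning here.
        return None
```

Catching only `KeyError` and `ValueError` is a choice made for this sketch; the code in the diff uses a bare `except:`.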
......
#!/usr/bin/env python2 #!/usr/bin/env python
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: base.py # File: base.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2 #!/usr/bin/env python
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: cifar.py # File: cifar.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2 #!/usr/bin/env python
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: mnist.py # File: mnist.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2 #!/usr/bin/env python
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: svhn.py # File: svhn.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2 #!/usr/bin/env python
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: _test.py # File: _test.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2 #!/usr/bin/env python
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: _test.py # File: _test.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2 #!/usr/bin/env python
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: batch_norm.py # File: batch_norm.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2 #!/usr/bin/env python
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: conv2d.py # File: conv2d.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2 #!/usr/bin/env python
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: fc.py # File: fc.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2 #!/usr/bin/env python
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: image_sample.py # File: image_sample.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2 #!/usr/bin/env python
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: model_desc.py # File: model_desc.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2 #!/usr/bin/env python
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: nonlin.py # File: nonlin.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2 #!/usr/bin/env python
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: pool.py # File: pool.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2 #!/usr/bin/env python
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: common.py # File: common.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2 #!/usr/bin/env python
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: gradproc.py # File: gradproc.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
...@@ -126,7 +126,7 @@ class Trainer(object): ...@@ -126,7 +126,7 @@ class Trainer(object):
for self.epoch_num in range( for self.epoch_num in range(
self.config.starting_epoch, self.config.max_epoch+1): self.config.starting_epoch, self.config.max_epoch+1):
with timed_operation( with timed_operation(
'Epoch {}, global_step={}'.format( 'Epoch {} (global_step {})'.format(
self.epoch_num, self.global_step + self.config.step_per_epoch)): self.epoch_num, self.global_step + self.config.step_per_epoch)):
for step in tqdm.trange( for step in tqdm.trange(
self.config.step_per_epoch, self.config.step_per_epoch,
......
...@@ -42,7 +42,7 @@ def timed_operation(msg, log_start=False): ...@@ -42,7 +42,7 @@ def timed_operation(msg, log_start=False):
logger.info('Start {} ...'.format(msg)) logger.info('Start {} ...'.format(msg))
start = time.time() start = time.time()
yield yield
logger.info('{} finished, time={:.2f}sec.'.format( logger.info('{} finished, time:{:.2f}sec.'.format(
msg, time.time() - start)) msg, time.time() - start))
_TOTAL_TIMER_DATA = defaultdict(StatCounter) _TOTAL_TIMER_DATA = defaultdict(StatCounter)
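To see what the reworded messages look like at runtime, here is a minimal sketch of `timed_operation` as a standalone context manager. It follows the lines visible in the hunk, but the standard-library `logging` module is used in place of tensorpack's own logger, and the `if log_start:` guard is an assumption because that line is not shown in the diff.

```python
import logging
import time
from contextlib import contextmanager

logger = logging.getLogger('timed_operation_sketch')


@contextmanager
def timed_operation(msg, log_start=False):
    """Log how long the enclosed block took, using the new message format."""
    if log_start:  # assumed guard; this line is not part of the hunk
        logger.info('Start {} ...'.format(msg))
    start = time.time()
    yield
    logger.info('{} finished, time:{:.2f}sec.'.format(
        msg, time.time() - start))


if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    # The message mirrors the new 'Epoch {} (global_step {})' format from the trainer.
    with timed_operation('Epoch {} (global_step {})'.format(1, 100)):
        time.sleep(0.1)
```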
......
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# File: viz.py # File: viz.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Credit: zxytim
import numpy as np import numpy as np
......