Commit 59cd3c77 authored by Yuxin Wu

change to python2 shebang

parent 642d106e
......@@ -9,8 +9,8 @@ You can actually train them and reproduce the performance... not just to see how
+ [ResNet for Cifar10 classification](examples/ResNet)
+ [Fully-convolutional Network for Holistically-Nested Edge Detection](examples/HED)
+ [Spatial Transformer Networks on MNIST addition](examples/SpatialTransformer)
+ [Double-DQN plays Atari games](examples/Atari2600)
+ [Batch-A3C plays Atari games with demos on OpenAI Gym](examples/OpenAIGym)
+ [Double DQN plays Atari games](examples/Atari2600)
+ [Asynchronous Advantage Actor-Critic(A3C) with demos on OpenAI Gym Atari games](examples/OpenAIGym)
+ [char-rnn language model](examples/char-rnn)
## Features:
......
#!/usr/bin/env python2
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: DQN.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
......@@ -160,8 +160,7 @@ def get_config():
dataset=dataset_train,
optimizer=tf.train.AdamOptimizer(lr, epsilon=1e-3),
callbacks=Callbacks([
StatPrinter(),
ModelSaver(),
StatPrinter(), ModelSaver(),
ScheduledHyperParamSetter('learning_rate',
[(150, 4e-4), (250, 1e-4), (350, 5e-5)]),
RunOp(lambda: M.update_target_param()),
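
The `ScheduledHyperParamSetter` call in this hunk takes a list of `(epoch, value)` pairs. Below is a minimal sketch of how such a schedule could resolve to a value. It assumes each value takes effect at its epoch and holds until the next entry. `value_at_epoch` is a hypothetical helper written for this illustration, not part of tensorpack.

```python
import bisect

# Schedule copied from the hunk above: (epoch, learning_rate) pairs.
SCHEDULE = [(150, 4e-4), (250, 1e-4), (350, 5e-5)]


def value_at_epoch(schedule, epoch):
    """Return the most recently scheduled value at `epoch`, or None.

    Hypothetical helper; assumes `schedule` is sorted by epoch.
    """
    epochs = [e for e, _ in schedule]
    # Number of entries whose epoch is less than or equal to `epoch`.
    idx = bisect.bisect_right(epochs, epoch)
    if idx == 0:
        return None  # nothing scheduled yet, the initial value stays in effect
    return schedule[idx - 1][1]


if __name__ == '__main__':
    for epoch in (1, 150, 300, 400):
        print(epoch, value_at_epoch(SCHEDULE, epoch))
```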
......
![breakout](breakout.jpg)
[video demo](https://youtu.be/o21mddZtE5Y)
Reproduce the following reinforcement learning methods:
+ Nature-DQN in:
......@@ -18,16 +20,11 @@ Claimed performance in the paper can be reproduced, on several games I've tested
DQN typically took 2 days of training to reach a score of 400 on the Breakout game.
My Batch-A3C implementation only took <2 hours.
Both were trained on one GPU with an extra GPU for simulation.
<!--
-This is probably the fastest RL trainer you'd find.
-->
The x-axis is the number of iterations, not wall time.
Iteration speed on Tesla M40 is about 9.7it/s for B-A3C.
D-DQN is faster at the beginning but will converge to 12it/s due to exploration annealing.
A demo trained with Double-DQN on breakout is available at [youtube](https://youtu.be/o21mddZtE5Y).
## How to use
Download [atari roms](https://github.com/openai/atari-py/tree/master/atari_py/atari_roms) to
......@@ -40,7 +37,7 @@ To train:
To visualize the agent:
```
./DQN.py --rom breakout.bin --task play --load pretrained.model
./DQN.py --rom breakout.bin --task play --load trained.model
```
A3C code and models for Atari games in OpenAI Gym are released in [examples/OpenAIGym](../OpenAIGym)
# A3C Code and models for my Gym submissions on Atari games.
### A3C code and models for my Gym submissions on Atari games
### To train on an Atari game:
......@@ -7,7 +7,7 @@
### To run a pretrained Atari model for 100 episodes:
1. Download models from [model zoo](https://drive.google.com/open?id=0B9IPQTvr2BBkS0VhX0xmS1c5aFk)
2. `ENV=Breakout-v0; ./run-atari.py --load "$ENV".tfmodel --env "$ENV"`
2. `ENV=Breakout-v0; ./run-atari.py --load "$ENV".tfmodel --env "$ENV" --episode 100 --output output_dir`
Models are available for the following gym atari environments (click links for videos):
......@@ -64,3 +64,5 @@ Note that atari game settings in gym are quite different from DeepMind papers, s
+ In gym, inputs are RGB instead of greyscale.
+ In gym, an episode is limited to 10000 steps.
+ The action space also seems to be different.
Also see the DQN implementation [here](../Atari2600)
#!/usr/bin/env python2
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: run-atari.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
......@@ -61,25 +61,26 @@ class Model(ModelDesc):
policy = self._get_NN_prediction(state)
self.logits = tf.nn.softmax(policy, name='logits')
def run_submission(cfg):
dirname = 'gym-submit'
player = get_player(dumpdir=dirname)
def run_submission(cfg, output, nr):
player = get_player(dumpdir=output)
predfunc = get_predict_func(cfg)
for k in range(100):
for k in range(nr):
if k != 0:
player.restart_episode()
score = play_one_episode(player, predfunc)
print("Total:", score)
def do_submit():
dirname = 'gym-submit'
gym.upload(dirname, api_key='xxx')
def do_submit(output):
gym.upload(output, api_key='xxx')
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.') # nargs='*' in multi mode
parser.add_argument('--load', help='load model', required=True)
parser.add_argument('--env', help='env', required=True)
parser.add_argument('--env', help='environment name', required=True)
parser.add_argument('--episode', help='number of episodes to run',
type=int, default=100)
parser.add_argument('--output', help='output directory', default='gym-submit')
args = parser.parse_args()
ENV_NAME = args.env
......@@ -95,4 +96,4 @@ if __name__ == '__main__':
session_init=SaverRestore(args.load),
input_var_names=['state'],
output_var_names=['logits'])
run_submission(cfg)
run_submission(cfg, args.output, args.episode)
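
The new `--episode` and `--output` options replace the hard-coded `range(100)` and `'gym-submit'`. Below is a minimal sketch of that wiring, with the player and predictor replaced by a stub callable. `build_parser` and `run_episodes` are hypothetical names used only for this illustration, and `--gpu` is omitted.

```python
import argparse


def build_parser():
    """Build a parser with the options shown in the hunk above."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--load', help='load model', required=True)
    parser.add_argument('--env', help='environment name', required=True)
    parser.add_argument('--episode', help='number of episodes to run',
                        type=int, default=100)
    parser.add_argument('--output', help='output directory',
                        default='gym-submit')
    return parser


def run_episodes(play_one_episode, nr):
    """Call `play_one_episode` `nr` times and collect the scores.

    `play_one_episode` is a stand-in for the real player/predictor pair.
    """
    scores = []
    for _ in range(nr):
        scores.append(play_one_episode())
    return scores


if __name__ == '__main__':
    # Explicit argument list so the sketch runs without a command line.
    args = build_parser().parse_args(
        ['--load', 'Breakout-v0.tfmodel', '--env', 'Breakout-v0',
         '--episode', '3'])
    scores = run_episodes(lambda: 0.0, args.episode)
    print(args.output, len(scores))
```

Keeping the defaults at `100` and `'gym-submit'` means an invocation without the new flags behaves as it did before the change.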
#!/usr/bin/env python2
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: train-atari.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
......
......@@ -4,9 +4,9 @@
Training examples with __reproducible__ and meaningful performance.
+ [An illustrative mnist example with explanation of the framework](mnist-convnet.py)
+ [A tiny SVHN ConvNet with 97.5% accuracy](svhn-digit-convnet.py)
+ [A tiny SVHN ConvNet with 97.8% accuracy](svhn-digit-convnet.py)
+ [Inception-BN with 71% accuracy](Inception/inception-bn.py)
+ [InceptionV3 with 74.5% accuracy (similar to the official code)](Inception/inceptionv3.py)
+ [InceptionV3 with 74% accuracy (similar to the official code)](Inception/inceptionv3.py)
+ [DoReFa-Net: binary / low-bitwidth CNN on ImageNet](DoReFa-Net)
+ [ResNet for Cifar10 and SVHN](ResNet)
+ [Holistically-Nested Edge Detection](HED)
......
#!/usr/bin/env python2
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# File: load-alexnet.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
......@@ -15,7 +15,7 @@ from tensorpack.dataflow.dataset import ILSVRCMeta
"""
Usage:
python2 -m tensorpack.utils.loadcaffe PATH/TO/CAFFE/{deploy.prototxt,bvlc_alexnet.caffemodel} alexnet.npy
python -m tensorpack.utils.loadcaffe PATH/TO/CAFFE/{deploy.prototxt,bvlc_alexnet.caffemodel} alexnet.npy
./load-alexnet.py --load alexnet.npy --input cat.png
"""
......
#!/usr/bin/env python2
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# File: load-vgg16.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
......@@ -23,7 +23,7 @@ from tensorpack.dataflow.dataset import ILSVRCMeta
"""
Usage:
python2 -m tensorpack.utils.loadcaffe PATH/TO/VGG/{VGG_ILSVRC_16_layers_deploy.prototxt,VGG_ILSVRC_16_layers.caffemodel} vgg16.npy
python -m tensorpack.utils.loadcaffe PATH/TO/VGG/{VGG_ILSVRC_16_layers_deploy.prototxt,VGG_ILSVRC_16_layers.caffemodel} vgg16.npy
./load-vgg16.py --load vgg16.npy --input cat.png
"""
......
#!/usr/bin/env python2
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# File: dump-model-params.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# File: dump_train_config.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# File: imgclassify.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
......
#!/usr/bin/env python2
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: expreplay.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
......
#!/usr/bin/env python2
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: simulator.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
......
#!/usr/bin/env python2
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# File: param.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
......@@ -144,6 +144,9 @@ class HumanHyperParamSetter(HyperParamSetter):
self.file_name, self.param.readable_name))
def _get_value_to_set(self):
# ignore if no such file exists
if not os.path.isfile(self.file_name):
return None
try:
with open(self.file_name) as f:
lines = f.readlines()
......@@ -152,9 +155,9 @@ class HumanHyperParamSetter(HyperParamSetter):
ret = dic[self.param.readable_name]
return ret
except:
#logger.warn(
#"Cannot find {} in {}".format(
#self.param.readable_name, self.file_name))
logger.warn(
"Cannot find {} in {}".format(
self.param.readable_name, self.file_name))
return None
class ScheduledHyperParamSetter(HyperParamSetter):
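
The two `HumanHyperParamSetter` hunks above distinguish a missing file (ignored silently) from a file that lacks the parameter (logged as a warning). Below is a standalone sketch of that behavior. The `name:value` line format is an assumption made for this example, since the parsing code is elided in the diff, and `read_param_from_file` is a hypothetical function name.

```python
import logging
import os

logger = logging.getLogger(__name__)


def read_param_from_file(file_name, param_name):
    """Return the float stored for `param_name` in `file_name`, or None.

    Assumes one `name:value` pair per line (an assumption of this sketch).
    """
    # Ignore silently if no such file exists.
    if not os.path.isfile(file_name):
        return None
    try:
        with open(file_name) as f:
            pairs = [line.strip().split(':') for line in f if line.strip()]
        values = {name.strip(): float(value) for name, value in pairs}
        return values[param_name]
    except (KeyError, ValueError):
        # The file exists but the entry is missing or malformed.
        logger.warning("Cannot find %s in %s", param_name, file_name)
        return None
```

Catching only `KeyError` and `ValueError` is a design choice of this sketch. The code in the diff uses a bare `except:`.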
......
#!/usr/bin/env python2
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# File: base.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# File: cifar.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# File: mnist.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# File: svhn.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# File: _test.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# File: _test.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# File: batch_norm.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# File: conv2d.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# File: fc.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# File: image_sample.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# File: model_desc.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# File: nonlin.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# File: pool.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# File: common.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
#!/usr/bin/env python2
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# File: gradproc.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
......
......@@ -126,7 +126,7 @@ class Trainer(object):
for self.epoch_num in range(
self.config.starting_epoch, self.config.max_epoch+1):
with timed_operation(
'Epoch {}, global_step={}'.format(
'Epoch {} (global_step {})'.format(
self.epoch_num, self.global_step + self.config.step_per_epoch)):
for step in tqdm.trange(
self.config.step_per_epoch,
......
......@@ -42,7 +42,7 @@ def timed_operation(msg, log_start=False):
logger.info('Start {} ...'.format(msg))
start = time.time()
yield
logger.info('{} finished, time={:.2f}sec.'.format(
logger.info('{} finished, time:{:.2f}sec.'.format(
msg, time.time() - start))
_TOTAL_TIMER_DATA = defaultdict(StatCounter)
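
For context, the sketch below shows how a timing context manager like `timed_operation` could be assembled around the lines in this hunk, and what the two reworded messages look like together. The `@contextmanager` decorator and the `if log_start:` guard are assumptions, since they fall outside the visible hunk.

```python
import logging
import time
from contextlib import contextmanager

logger = logging.getLogger(__name__)


@contextmanager
def timed_operation(msg, log_start=False):
    """Log how long the wrapped block took, using the new message format."""
    if log_start:  # assumed guard, not visible in the hunk
        logger.info('Start {} ...'.format(msg))
    start = time.time()
    yield
    logger.info('{} finished, time:{:.2f}sec.'.format(
        msg, time.time() - start))


if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    # Epoch message in the format introduced by the Trainer hunk.
    with timed_operation('Epoch {} (global_step {})'.format(1, 100)):
        time.sleep(0.01)
```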
......
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: viz.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
# Credit: zxytim
import numpy as np
......