Commit 61384a65 authored by Yuxin Wu

small bugfix

parent b95ea88f
@@ -8,7 +8,7 @@ You can actually train them and reproduce the performance... not just to see how
 + [DoReFa-Net: training binary / low bitwidth CNN](examples/DoReFa-Net)
 + [InceptionV3 on ImageNet](examples/Inception/inceptionv3.py)
-+ [ResNet for ImageNet/Cifar10 classification](examples/ResNet)
++ [ResNet for ImageNet/Cifar10/SVHN classification](examples/ResNet)
 + [Fully-convolutional Network for Holistically-Nested Edge Detection](examples/HED)
 + [Spatial Transformer Networks on MNIST addition](examples/SpatialTransformer)
 + [Double DQN plays Atari games](examples/Atari2600)
@@ -23,9 +23,6 @@ import common
 from common import play_model, Evaluator, eval_model_multithread
 from atari import AtariPlayer
-METHOD = ['DQN', 'Double', 'Dueling'][1]
 BATCH_SIZE = 64
 IMAGE_SIZE = (84, 84)
 FRAME_HISTORY = 4
@@ -48,6 +45,7 @@ EVAL_EPISODE = 50
 NUM_ACTIONS = None
 ROM_FILE = None
+METHOD = None

 def get_player(viz=False, train=False):
     pl = AtariPlayer(ROM_FILE, frame_skip=ACTION_REPEAT,
@@ -123,7 +121,8 @@ class Model(ModelDesc):
         target = reward + (1.0 - tf.cast(isOver, tf.float32)) * GAMMA * tf.stop_gradient(best_v)
-        self.cost = tf.truediv(symbf.huber_loss(target - pred_action_value), BATCH_SIZE, name='cost')
+        self.cost = tf.truediv(symbf.huber_loss(target - pred_action_value),
+                               tf.cast(BATCH_SIZE, tf.float32), name='cost')
         summary.add_param_summary([('conv.*/W', ['histogram', 'rms']),
                                    ('fc.*/W', ['histogram', 'rms'])])   # monitor all W
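The fix above makes the divisor's dtype explicit, so `tf.truediv` receives two float32 operands. A minimal sketch of the same cost pattern (`huber_loss` below is a stand-in for tensorpack's `symbf.huber_loss`, and `tf.where` is the TF 1.x spelling of this era's `tf.select`):

```python
import tensorflow as tf

BATCH_SIZE = 64

def huber_loss(x, delta=1.0):
    # quadratic near zero, linear in the tails, summed over the batch
    quad = 0.5 * tf.square(x)
    lin = delta * (tf.abs(x) - 0.5 * delta)
    return tf.reduce_sum(tf.where(tf.abs(x) < delta, quad, lin))

err = tf.placeholder(tf.float32, [BATCH_SIZE])   # target - pred_action_value
# cast the Python int so both truediv operands are float32
cost = tf.truediv(huber_loss(err), tf.cast(BATCH_SIZE, tf.float32), name='cost')
```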
@@ -188,6 +187,8 @@ if __name__ == '__main__':
     parser.add_argument('--task', help='task to perform',
                         choices=['play', 'eval', 'train'], default='train')
     parser.add_argument('--rom', help='atari rom', required=True)
+    parser.add_argument('--algo', help='algorithm',
+                        choices=['DQN', 'Double', 'Dueling'], default='Double')
     args = parser.parse_args()
     if args.gpu:
@@ -195,6 +196,7 @@ if __name__ == '__main__':
     if args.task != 'train':
         assert args.load is not None
     ROM_FILE = args.rom
+    METHOD = args.algo

     if args.task != 'train':
         cfg = PredictConfig(
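With this change the DQN variant is picked on the command line (e.g. `./DQN.py --rom breakout.bin --algo Dueling`) instead of by editing the old hardcoded index. A self-contained sketch of the pattern:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--rom', help='atari rom', required=True)
parser.add_argument('--algo', help='algorithm',
                    choices=['DQN', 'Double', 'Dueling'], default='Double')

# simulate `./DQN.py --rom breakout.bin --algo Dueling`
args = parser.parse_args(['--rom', 'breakout.bin', '--algo', 'Dueling'])
ROM_FILE = args.rom
METHOD = args.algo      # replaces METHOD = ['DQN', 'Double', 'Dueling'][1]
print(METHOD)           # Dueling
```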
@@ -2,7 +2,7 @@
 [video demo](https://youtu.be/o21mddZtE5Y)

-Reproduce the following reinforcement learning papers:
+Reproduce the following reinforcement learning methods:

 + Nature-DQN in:
   [Human-level Control Through Deep Reinforcement Learning](http://www.nature.com/nature/journal/v518/n7540/full/nature14236.html)
@@ -29,12 +29,13 @@ D-DQN is faster at the beginning but will converge to 12it/s due to exploration

 ## How to use

-Download [atari roms](https://github.com/openai/atari-py/tree/master/atari_py/atari_roms) to
-`$TENSORPACK_DATASET/atari_rom` (defaults to tensorpack/dataflow/dataset/atari_rom).
+Download an [atari rom](https://github.com/openai/atari-py/tree/master/atari_py/atari_roms) to
+`$TENSORPACK_DATASET/atari_rom/` (defaults to tensorpack/dataflow/dataset/atari_rom/).

 To train:
 ```
-./DQN.py --rom breakout.bin --gpu 0
+./DQN.py --rom breakout.bin
+# use `--algo` to select other DQN algorithms
 ```

 To visualize the agent:
@@ -160,7 +160,6 @@ def get_config():
     dataset_train = get_data('train')
     step_per_epoch = dataset_train.size() * 40
     dataset_val = get_data('val')
-    #dataset_test = get_data('test')

     lr = tf.Variable(3e-5, trainable=False, name='learning_rate')
     tf.scalar_summary('learning_rate', lr)
@@ -169,8 +168,7 @@ def get_config():
         dataset=dataset_train,
         optimizer=tf.train.AdamOptimizer(lr, epsilon=1e-3),
         callbacks=Callbacks([
-            StatPrinter(),
-            ModelSaver(),
+            StatPrinter(), ModelSaver(),
             ScheduledHyperParamSetter('learning_rate', [(30, 6e-6), (45, 1e-6), (60, 8e-7)]),
             HumanHyperParamSetter('learning_rate'),
             InferenceRunner(dataset_val,
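For context on the schedule above: `ScheduledHyperParamSetter` sets the hyperparameter to the value of the latest (epoch, value) pair that has been reached. A rough sketch of that lookup (a simplification, not tensorpack's actual callback code):

```python
def scheduled_value(schedule, epoch):
    """Return the value of the last (epoch, value) pair reached, else None."""
    value = None
    for e, v in schedule:
        if epoch >= e:
            value = v
    return value

schedule = [(30, 6e-6), (45, 1e-6), (60, 8e-7)]
print(scheduled_value(schedule, 50))   # 1e-06 (past epoch 45, before 60)
print(scheduled_value(schedule, 10))   # None: keep the current value
```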
@@ -8,7 +8,7 @@ Training examples with __reproducible__ and meaningful performance.
 + [Inception-BN with 71% accuracy](Inception/inception-bn.py)
 + [InceptionV3 with 74% accuracy (similar to the official code)](Inception/inceptionv3.py)
 + [DoReFa-Net: binary / low-bitwidth CNN on ImageNet](DoReFa-Net)
-+ [ResNet for Cifar10 and SVHN](ResNet)
++ [ResNet for ImageNet/Cifar10/SVHN](ResNet)
 + [Holistically-Nested Edge Detection](HED)
 + [Spatial Transformer Networks on MNIST addition](SpatialTransformer)
 + [DisturbLabel, because I don't believe the paper](DisturbLabel)
@@ -59,9 +59,10 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
     ctx = get_current_tower_context()
     if use_local_stat is None:
         use_local_stat = ctx.is_training
-    assert use_local_stat == ctx.is_training
+    if use_local_stat != ctx.is_training:
+        logger.warn("[BatchNorm] use_local_stat != is_training")

-    if ctx.is_training:
+    if use_local_stat:
         # training tower
         with tf.name_scope(None):   # https://github.com/tensorflow/tensorflow/issues/2740
             ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
@@ -72,7 +73,6 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
             tf.add_to_collection(EXTRA_SAVE_VARS_KEY, ema_mean)
             tf.add_to_collection(EXTRA_SAVE_VARS_KEY, ema_var)
     else:
-        assert not use_local_stat
         if ctx.is_main_tower:
             # not training, but main tower. need to create the vars
             with tf.name_scope(None):
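The bugfix here relaxes the old assertion: `use_local_stat` may now deliberately disagree with the tower's training mode, producing only a warning, and the branch is keyed on `use_local_stat` itself rather than `ctx.is_training`. A minimal sketch of the resulting control flow (tensorpack's tower context and EMA bookkeeping are abstracted away):

```python
def batch_norm_stats(use_local_stat, is_training):
    """Sketch: decide which statistics BatchNorm should normalize with."""
    if use_local_stat is None:
        use_local_stat = is_training          # default: follow the tower's mode
    if use_local_stat != is_training:
        print("warn: [BatchNorm] use_local_stat != is_training")  # was an assert
    if use_local_stat:
        return 'per-batch mean/var (training path, EMA updated)'
    return 'EMA mean/var (inference path)'

print(batch_norm_stats(None, True))     # per-batch mean/var, no warning
print(batch_norm_stats(False, True))    # EMA mean/var, after a warning
```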
@@ -159,7 +159,7 @@ class ParamRestore(SessionInit):
         logger.info("Params to restore: {}".format(
             ', '.join(map(str, intersect))))
         for k in variable_names - param_names:
-            if not is_training_specific_name(k):
+            if not is_training_name(k):
                 logger.warn("Variable {} in the graph not found in the dict!".format(k))
         for k in param_names - variable_names:
             logger.warn("Variable {} in the dict not found in the graph!".format(k))
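The rename reflects what the filter is for: variables that exist only in the training graph (optimizer slots, EMA shadow copies, the step counter) should not produce a missing-variable warning when restoring inference weights. A hypothetical version of such a predicate (the substring list below is an assumption, not tensorpack's actual implementation):

```python
# hypothetical patterns; tensorpack's real is_training_name may differ
_TRAINING_ONLY = ('Adam', 'Momentum', 'global_step', 'EMA')

def is_training_name(name):
    """Heuristic: True if a variable only exists during training."""
    return any(p in name for p in _TRAINING_ONLY)

variable_names = {'conv0/W', 'conv1/W', 'conv0/W/Adam', 'global_step'}
param_names = {'conv0/W'}
for k in variable_names - param_names:
    if not is_training_name(k):
        print("Variable {} in the graph not found in the dict!".format(k))
# only the genuinely missing weight (conv1/W) is reported;
# optimizer slots and counters stay quiet
```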