Commit 61384a65 authored by Yuxin Wu

small bugfix

parent b95ea88f
@@ -8,7 +8,7 @@ You can actually train them and reproduce the performance... not just to see how
 + [DoReFa-Net: training binary / low bitwidth CNN](examples/DoReFa-Net)
 + [InceptionV3 on ImageNet](examples/Inception/inceptionv3.py)
-+ [ResNet for ImageNet/Cifar10 classification](examples/ResNet)
++ [ResNet for ImageNet/Cifar10/SVHN classification](examples/ResNet)
 + [Fully-convolutional Network for Holistically-Nested Edge Detection](examples/HED)
 + [Spatial Transformer Networks on MNIST addition](examples/SpatialTransformer)
 + [Double DQN plays Atari games](examples/Atari2600)
...
@@ -23,9 +23,6 @@ import common
 from common import play_model, Evaluator, eval_model_multithread
 from atari import AtariPlayer
-METHOD = ['DQN', 'Double', 'Dueling'][1]
 BATCH_SIZE = 64
 IMAGE_SIZE = (84, 84)
 FRAME_HISTORY = 4
@@ -48,6 +45,7 @@ EVAL_EPISODE = 50
 NUM_ACTIONS = None
 ROM_FILE = None
+METHOD = None
 def get_player(viz=False, train=False):
     pl = AtariPlayer(ROM_FILE, frame_skip=ACTION_REPEAT,
@@ -123,7 +121,8 @@ class Model(ModelDesc):
         target = reward + (1.0 - tf.cast(isOver, tf.float32)) * GAMMA * tf.stop_gradient(best_v)
-        self.cost = tf.truediv(symbf.huber_loss(target - pred_action_value), BATCH_SIZE, name='cost')
+        self.cost = tf.truediv(symbf.huber_loss(target - pred_action_value),
+                               tf.cast(BATCH_SIZE, tf.float32), name='cost')
         summary.add_param_summary([('conv.*/W', ['histogram', 'rms']),
                                    ('fc.*/W', ['histogram', 'rms']) ])  # monitor all W
@@ -188,6 +187,8 @@ if __name__ == '__main__':
     parser.add_argument('--task', help='task to perform',
                         choices=['play', 'eval', 'train'], default='train')
     parser.add_argument('--rom', help='atari rom', required=True)
+    parser.add_argument('--algo', help='algorithm',
+                        choices=['DQN', 'Double', 'Dueling'], default='Double')
     args = parser.parse_args()
     if args.gpu:
@@ -195,6 +196,7 @@ if __name__ == '__main__':
     if args.task != 'train':
         assert args.load is not None
     ROM_FILE = args.rom
+    METHOD = args.algo
     if args.task != 'train':
         cfg = PredictConfig(
...
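For context on what the new `--algo` switch selects: the choices differ mainly in how the bootstrap target is formed (Dueling additionally changes the network head). The sketch below is a minimal numpy illustration of the Nature-DQN versus Double-DQN targets; it is not tensorpack code, and `q_online_next`/`q_target_next` are hypothetical stand-ins for the Q-values of the online and frozen target networks at the next state:

```python
import numpy as np

def dqn_target(reward, is_over, q_online_next, q_target_next,
               gamma=0.99, double=True):
    """Bootstrap targets for a batch of transitions.

    q_online_next, q_target_next: (batch, num_actions) arrays of Q-values
    at s' from the online and the frozen target network respectively.
    """
    if double:
        # Double DQN: the online net picks the action, the target net scores it
        act = np.argmax(q_online_next, axis=1)
        best_v = q_target_next[np.arange(len(act)), act]
    else:
        # Nature-DQN: the target net both picks and scores the action
        best_v = np.max(q_target_next, axis=1)
    # (1.0 - is_over) kills the bootstrap term on terminal transitions,
    # mirroring `target = reward + (1.0 - tf.cast(isOver, ...))` in the hunk above
    return reward + (1.0 - is_over.astype(np.float32)) * gamma * best_v
```

The `tf.cast(BATCH_SIZE, tf.float32)` part of the same hunk is presumably the "small bugfix" of the commit title: the huber loss is a float32 tensor, and dividing it by a plain Python int can trip the dtype check in `tf.truediv`.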
@@ -2,7 +2,7 @@
 [video demo](https://youtu.be/o21mddZtE5Y)
-Reproduce the following reinforcement learning papers:
+Reproduce the following reinforcement learning methods:
 + Nature-DQN in:
 [Human-level Control Through Deep Reinforcement Learning](http://www.nature.com/nature/journal/v518/n7540/full/nature14236.html)
@@ -29,12 +29,13 @@ D-DQN is faster at the beginning but will converge to 12it/s due of exploration
 ## How to use
-Download [atari roms](https://github.com/openai/atari-py/tree/master/atari_py/atari_roms) to
-`$TENSORPACK_DATASET/atari_rom` (defaults to tensorpack/dataflow/dataset/atari_rom).
+Download an [atari rom](https://github.com/openai/atari-py/tree/master/atari_py/atari_roms) to
+`$TENSORPACK_DATASET/atari_rom/` (defaults to tensorpack/dataflow/dataset/atari_rom/).
 To train:
 ```
-./DQN.py --rom breakout.bin --gpu 0
+./DQN.py --rom breakout.bin
+# use `--algo` to select other DQN algorithms
 ```
 To visualize the agent:
...
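The command after "To visualize the agent:" is elided by the fold above, but given the `--task`, `--load`, and `--algo` options added to DQN.py in this commit, an invocation would presumably look like the following (the checkpoint path is a placeholder):

```
# hypothetical commands inferred from the argparse options in DQN.py
./DQN.py --rom breakout.bin --task play --load path/to/checkpoint
./DQN.py --rom breakout.bin --task eval --load path/to/checkpoint
```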
@@ -160,7 +160,6 @@ def get_config():
     dataset_train = get_data('train')
     step_per_epoch = dataset_train.size() * 40
     dataset_val = get_data('val')
-    #dataset_test = get_data('test')
     lr = tf.Variable(3e-5, trainable=False, name='learning_rate')
     tf.scalar_summary('learning_rate', lr)
@@ -169,8 +168,7 @@ def get_config():
         dataset=dataset_train,
         optimizer=tf.train.AdamOptimizer(lr, epsilon=1e-3),
         callbacks=Callbacks([
-            StatPrinter(),
-            ModelSaver(),
+            StatPrinter(), ModelSaver(),
             ScheduledHyperParamSetter('learning_rate', [(30, 6e-6), (45, 1e-6), (60, 8e-7)]),
             HumanHyperParamSetter('learning_rate'),
             InferenceRunner(dataset_val,
...
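A note on the schedule kept in this hunk: the pairs given to `ScheduledHyperParamSetter` read as (epoch, value) steps, so the learning rate drops to 6e-6 at epoch 30, 1e-6 at epoch 45, and 8e-7 at epoch 60. A minimal plain-Python sketch of that reading (an illustration of the semantics, not tensorpack's implementation):

```python
def scheduled_value(schedule, epoch, current):
    """Value a step schedule assigns at `epoch`.

    schedule: sorted (epoch, value) pairs; the value changes once the given
    epoch is reached and holds until the next entry takes over.
    """
    for e, v in schedule:
        if epoch >= e:
            current = v
    return current

# the schedule from the diff, starting from the initial lr of 3e-5
assert scheduled_value([(30, 6e-6), (45, 1e-6), (60, 8e-7)], 50, 3e-5) == 1e-6
```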
@@ -8,7 +8,7 @@ Training examples with __reproducible__ and meaningful performance.
 + [Inception-BN with 71% accuracy](Inception/inception-bn.py)
 + [InceptionV3 with 74% accuracy (similar to the official code)](Inception/inceptionv3.py)
 + [DoReFa-Net: binary / low-bitwidth CNN on ImageNet](DoReFa-Net)
-+ [ResNet for Cifar10 and SVHN](ResNet)
++ [ResNet for ImageNet/Cifar10/SVHN](ResNet)
 + [Holistically-Nested Edge Detection](HED)
 + [Spatial Transformer Networks on MNIST addition](SpatialTransformer)
 + [DisturbLabel, because I don't believe the paper](DisturbLabel)
...
@@ -59,9 +59,10 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
     ctx = get_current_tower_context()
     if use_local_stat is None:
         use_local_stat = ctx.is_training
-    assert use_local_stat == ctx.is_training
+    if use_local_stat != ctx.is_training:
+        logger.warn("[BatchNorm] use_local_stat != is_training")
-    if ctx.is_training:
+    if use_local_stat:
         # training tower
         with tf.name_scope(None):   # https://github.com/tensorflow/tensorflow/issues/2740
             ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
@@ -72,7 +73,6 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
             tf.add_to_collection(EXTRA_SAVE_VARS_KEY, ema_mean)
             tf.add_to_collection(EXTRA_SAVE_VARS_KEY, ema_var)
     else:
-        assert not use_local_stat
         if ctx.is_main_tower:
             # not training, but main tower. need to create the vars
             with tf.name_scope(None):
...
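The behavioral change in this hunk: the branch now keys on `use_local_stat` instead of `ctx.is_training`, so the two may legitimately disagree (with a warning), e.g. a training graph can be built against EMA statistics. As a minimal numpy sketch of what the two statistics modes compute (illustrative only; the names and the placement of the moving-average update are my assumptions, not the tensorpack kernel):

```python
import numpy as np

def batch_norm(x, gamma, beta, moving_mean, moving_var,
               use_local_stat, decay=0.9, epsilon=1e-5):
    """Normalize x of shape (batch, channels) with batch or moving statistics."""
    if use_local_stat:
        # local (batch) statistics; the moving averages are updated on the
        # side, as the ExponentialMovingAverage does in the training branch
        mean, var = x.mean(axis=0), x.var(axis=0)
        moving_mean[:] = decay * moving_mean + (1 - decay) * mean
        moving_var[:] = decay * moving_var + (1 - decay) * var
    else:
        # stored moving statistics, the usual inference-time path
        mean, var = moving_mean, moving_var
    return gamma * (x - mean) / np.sqrt(var + epsilon) + beta
```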
@@ -159,7 +159,7 @@ class ParamRestore(SessionInit):
         logger.info("Params to restore: {}".format(
             ', '.join(map(str, intersect))))
         for k in variable_names - param_names:
-            if not is_training_specific_name(k):
+            if not is_training_name(k):
                 logger.warn("Variable {} in the graph not found in the dict!".format(k))
         for k in param_names - variable_names:
             logger.warn("Variable {} in the dict not found in the graph!".format(k))
...
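The rename from `is_training_specific_name` to `is_training_name` points at a predicate for variables that exist only to drive training (optimizer slots, EMA shadows, step counters) and so deserve no warning when a checkpoint lacks them. A hypothetical version of such a filter, to show the shape of the idea (the suffix list is an assumption, not tensorpack's actual rule):

```python
def is_training_name(name):
    """Heuristically decide whether a variable exists only for training."""
    # optimizer slots and bookkeeping variables commonly look like these
    training_suffixes = ('/Adam', '/Adam_1', '/Momentum', 'global_step')
    return name.endswith(training_suffixes) or 'EMA' in name
```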