Commit a9a3b7d1 authored by Yuxin Wu's avatar Yuxin Wu

lmdb and discretizer

parent 5683e4d9
......@@ -7,17 +7,20 @@ Reproduce the following methods:
[Deep Reinforcement Learning with Double Q-learning](http://arxiv.org/abs/1509.06461)
+ A3C in [Asynchronous Methods for Deep Reinforcement Learning](http://arxiv.org/abs/1602.01783). (I
used a modified version where each batch contains transitions from different simulators, which I called "Batch A3C".)
used a modified version where each batch contains transitions from different simulators, which I called "Batch-A3C".)
Claimed performance in the paper can be reproduced, on several games I've tested with.
![DQN](curve-breakout.png)
A demo trained with Double-DQN on breakout game is available at [youtube](https://youtu.be/o21mddZtE5Y).
DQN would typically take 2~3 days of training to reach a score of 400 on breakout, but my A3C implementation only takes <2 hours on 1 GPU.
DQN was trained on 1 GPU and it typically took 2~3 days of training to reach a score of 400 on breakout game.
My Batch-A3C implementation only took <2 hours with 2 GPUs (one for training and one for simulation).
This is probably the fastest RL trainer you'd find.
The x-axis is the number of iterations, not wall time. The iteration speed is 6.7it/s for B-A3C and 7.3it/s for D-DQN.
A demo trained with Double-DQN on breakout is available at [youtube](https://youtu.be/o21mddZtE5Y).
## How to use
Download [atari roms](https://github.com/openai/atari-py/tree/master/atari_py/atari_roms) to
......@@ -29,7 +32,6 @@ To train:
```
Training speed is about 7.3 iteration/s on 1 Tesla M40
(faster than this at the beginning, but will slow down due to exploration annealing).
It takes days to learn well (see figure above).
To visualize the agent:
```
......
......@@ -100,19 +100,27 @@ class LMDBData(RNGDataFlow):
v = self._txn.get(k)
yield [k, v]
class CaffeLMDB(LMDBData):
""" Read a Caffe LMDB file where each value contains a caffe.Datum protobuf """
def __init__(self, lmdb_dir, shuffle=True):
class LMDBDataDecoder(LMDBData):
def __init__(self, lmdb_dir, decoder, shuffle=True):
"""
:param shuffle: about 3 times slower
:param decoder: a function taking k, v and return a data point,
or return None to skip
"""
super(CaffeLMDB, self).__init__(lmdb_dir, shuffle)
self.cpb = get_caffe_pb()
super(LMDBDataDecoder, self).__init__(lmdb_dir, shuffle)
self.decoder = decoder
def get_data(self):
datum = self.cpb.Datum()
def parse(k, v):
for dp in super(LMDBDataDecoder, self).get_data():
v = self.decoder(dp[0], dp[1])
if v: yield v
class CaffeLMDB(LMDBDataDecoder):
""" Read a Caffe LMDB file where each value contains a caffe.Datum protobuf """
def __init__(self, lmdb_dir, shuffle=True):
cpb = get_caffe_pb()
def decoder(k, v):
try:
datum = cpb.Datum()
datum.ParseFromString(v)
img = np.fromstring(datum.data, dtype=np.uint8)
img = img.reshape(datum.channels, datum.height, datum.width)
......@@ -121,6 +129,5 @@ class CaffeLMDB(LMDBData):
return None
return [img.transpose(1, 2, 0), datum.label]
for dp in super(CaffeLMDB, self).get_data():
v = parse(dp[0], dp[1])
if v: yield v
super(CaffeLMDB, self).__init__(
lmdb_dir, decoder=decoder, shuffle=shuffle)
......@@ -118,7 +118,7 @@ class QueueInputTrainer(Trainer):
# use a smaller queue size for now, to avoid https://github.com/tensorflow/tensorflow/issues/2942
if input_queue is None:
self.input_queue = tf.FIFOQueue(
30, [x.dtype for x in self.input_vars], name='input_queue')
50, [x.dtype for x in self.input_vars], name='input_queue')
else:
self.input_queue = input_queue
if predict_tower is None:
......
......@@ -92,14 +92,9 @@ class UniformDiscretizerND(Discretizer):
def get_bin(self, v):
assert len(v) == self.n
bin_id = [self.discretizers[k].get_bin(v[k]) for k in range(self.n)]
return self.get_bin_from_nd_bin_ids(bin_id)
acc, res = 1, 0
for k in reversed(list(range(self.n))):
res += bin_id[k] * acc
acc *= self.nr_bins[k]
return res
def _get_bin_id_nd(self, bin_id):
def get_nd_bin_ids(self, bin_id):
ret = []
for k in reversed(list(range(self.n))):
nr = self.nr_bins[k]
......@@ -108,8 +103,18 @@ class UniformDiscretizerND(Discretizer):
ret.append(v)
return list(reversed(ret))
def get_bin_from_nd_bin_ids(self, bin_ids):
    """Flatten per-dimension bin ids into one linear bin index.

    Treats ``bin_ids`` as digits of a mixed-radix (row-major) number whose
    radix along dimension k is ``self.nr_bins[k]``.

    :param bin_ids: sequence of length ``self.n`` with one bin id per dimension
    :return: the flattened (scalar) bin index
    """
    flat_id = 0
    stride = 1  # product of the sizes of all dimensions after `dim`
    for dim in range(self.n - 1, -1, -1):
        flat_id += bin_ids[dim] * stride
        stride *= self.nr_bins[dim]
    return flat_id
def get_nr_bin_nd(self):
    """Return the per-dimension bin counts (one entry per dimension)."""
    return self.nr_bins
def get_bin_center(self, bin_id):
bin_id_nd = self._get_bin_id_nd(bin_id)
bin_id_nd = self.get_nd_bin_ids(bin_id)
return [self.discretizers[k].get_bin_center(bin_id_nd[k]) for k in range(self.n)]
if __name__ == '__main__':
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment