Commit a9a3b7d1 authored by Yuxin Wu's avatar Yuxin Wu

lmdb and discretizer

parent 5683e4d9
...@@ -7,17 +7,20 @@ Reproduce the following methods: ...@@ -7,17 +7,20 @@ Reproduce the following methods:
[Deep Reinforcement Learning with Double Q-learning](http://arxiv.org/abs/1509.06461) [Deep Reinforcement Learning with Double Q-learning](http://arxiv.org/abs/1509.06461)
+ A3C in [Asynchronous Methods for Deep Reinforcement Learning](http://arxiv.org/abs/1602.01783). (I + A3C in [Asynchronous Methods for Deep Reinforcement Learning](http://arxiv.org/abs/1602.01783). (I
used a modified version where each batch contains transitions from different simulators, which I called "Batch A3C".) used a modified version where each batch contains transitions from different simulators, which I called "Batch-A3C".)
Claimed performance in the paper can be reproduced, on several games I've tested with. Claimed performance in the paper can be reproduced, on several games I've tested with.
![DQN](curve-breakout.png) ![DQN](curve-breakout.png)
A demo trained with Double-DQN on breakout game is available at [youtube](https://youtu.be/o21mddZtE5Y). DQN was trained on 1 GPU and it typically took 2~3 days of training to reach a score of 400 on breakout game.
My Batch-A3C implementation only took <2 hours with 2 GPUs (one for training and one for simulation).
DQN would typically take 2~3 days of training to reach a score of 400 on breakout, but my A3C implementation only takes <2 hours on 1 GPU.
This is probably the fastest RL trainer you'd find. This is probably the fastest RL trainer you'd find.
The x-axis is the number of iterations not wall time. The iteration speed is 6.7it/s for B-A3C and 7.3it/s for D-DQN.
A demo trained with Double-DQN on breakout is available at [youtube](https://youtu.be/o21mddZtE5Y).
## How to use ## How to use
Download [atari roms](https://github.com/openai/atari-py/tree/master/atari_py/atari_roms) to Download [atari roms](https://github.com/openai/atari-py/tree/master/atari_py/atari_roms) to
...@@ -29,7 +32,6 @@ To train: ...@@ -29,7 +32,6 @@ To train:
``` ```
Training speed is about 7.3 iteration/s on 1 Tesla M40 Training speed is about 7.3 iteration/s on 1 Tesla M40
(faster than this at the beginning, but will slow down due to exploration annealing). (faster than this at the beginning, but will slow down due to exploration annealing).
It takes days to learn well (see figure above).
To visualize the agent: To visualize the agent:
``` ```
......
...@@ -100,19 +100,27 @@ class LMDBData(RNGDataFlow): ...@@ -100,19 +100,27 @@ class LMDBData(RNGDataFlow):
v = self._txn.get(k) v = self._txn.get(k)
yield [k, v] yield [k, v]
class CaffeLMDB(LMDBData): class LMDBDataDecoder(LMDBData):
""" Read a Caffe LMDB file where each value contains a caffe.Datum protobuf """ def __init__(self, lmdb_dir, decoder, shuffle=True):
def __init__(self, lmdb_dir, shuffle=True):
""" """
:param shuffle: about 3 times slower :param decoder: a function taking k, v and return a data point,
or return None to skip
""" """
super(CaffeLMDB, self).__init__(lmdb_dir, shuffle) super(LMDBDataDecoder, self).__init__(lmdb_dir, shuffle)
self.cpb = get_caffe_pb() self.decoder = decoder
def get_data(self): def get_data(self):
datum = self.cpb.Datum() for dp in super(LMDBDataDecoder, self).get_data():
def parse(k, v): v = self.decoder(dp[0], dp[1])
if v: yield v
class CaffeLMDB(LMDBDataDecoder):
""" Read a Caffe LMDB file where each value contains a caffe.Datum protobuf """
def __init__(self, lmdb_dir, shuffle=True):
cpb = get_caffe_pb()
def decoder(k, v):
try: try:
datum = cpb.Datum()
datum.ParseFromString(v) datum.ParseFromString(v)
img = np.fromstring(datum.data, dtype=np.uint8) img = np.fromstring(datum.data, dtype=np.uint8)
img = img.reshape(datum.channels, datum.height, datum.width) img = img.reshape(datum.channels, datum.height, datum.width)
...@@ -121,6 +129,5 @@ class CaffeLMDB(LMDBData): ...@@ -121,6 +129,5 @@ class CaffeLMDB(LMDBData):
return None return None
return [img.transpose(1, 2, 0), datum.label] return [img.transpose(1, 2, 0), datum.label]
for dp in super(CaffeLMDB, self).get_data(): super(CaffeLMDB, self).__init__(
v = parse(dp[0], dp[1]) lmdb_dir, decoder=decoder, shuffle=shuffle)
if v: yield v
...@@ -118,7 +118,7 @@ class QueueInputTrainer(Trainer): ...@@ -118,7 +118,7 @@ class QueueInputTrainer(Trainer):
# use a smaller queue size for now, to avoid https://github.com/tensorflow/tensorflow/issues/2942 # use a smaller queue size for now, to avoid https://github.com/tensorflow/tensorflow/issues/2942
if input_queue is None: if input_queue is None:
self.input_queue = tf.FIFOQueue( self.input_queue = tf.FIFOQueue(
30, [x.dtype for x in self.input_vars], name='input_queue') 50, [x.dtype for x in self.input_vars], name='input_queue')
else: else:
self.input_queue = input_queue self.input_queue = input_queue
if predict_tower is None: if predict_tower is None:
......
...@@ -92,14 +92,9 @@ class UniformDiscretizerND(Discretizer): ...@@ -92,14 +92,9 @@ class UniformDiscretizerND(Discretizer):
def get_bin(self, v): def get_bin(self, v):
assert len(v) == self.n assert len(v) == self.n
bin_id = [self.discretizers[k].get_bin(v[k]) for k in range(self.n)] bin_id = [self.discretizers[k].get_bin(v[k]) for k in range(self.n)]
return self.get_bin_from_nd_bin_ids(bin_id)
acc, res = 1, 0 def get_nd_bin_ids(self, bin_id):
for k in reversed(list(range(self.n))):
res += bin_id[k] * acc
acc *= self.nr_bins[k]
return res
def _get_bin_id_nd(self, bin_id):
ret = [] ret = []
for k in reversed(list(range(self.n))): for k in reversed(list(range(self.n))):
nr = self.nr_bins[k] nr = self.nr_bins[k]
...@@ -108,8 +103,18 @@ class UniformDiscretizerND(Discretizer): ...@@ -108,8 +103,18 @@ class UniformDiscretizerND(Discretizer):
ret.append(v) ret.append(v)
return list(reversed(ret)) return list(reversed(ret))
def get_bin_from_nd_bin_ids(self, bin_ids):
acc, res = 1, 0
for k in reversed(list(range(self.n))):
res += bin_ids[k] * acc
acc *= self.nr_bins[k]
return res
def get_nr_bin_nd(self):
return self.nr_bins
def get_bin_center(self, bin_id): def get_bin_center(self, bin_id):
bin_id_nd = self._get_bin_id_nd(bin_id) bin_id_nd = self.get_nd_bin_ids(bin_id)
return [self.discretizers[k].get_bin_center(bin_id_nd[k]) for k in range(self.n)] return [self.discretizers[k].get_bin_center(bin_id_nd[k]) for k in range(self.n)]
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment