Commit 7e91eb48 authored by Yuxin Wu

fix py2/3 compatibility in char-rnn

parent 475d0d28
...@@ -25,7 +25,7 @@ multiprocess Python program to get a cgroup dedicated for the task. ...@@ -25,7 +25,7 @@ multiprocess Python program to get a cgroup dedicated for the task.
Models are available for the following atari environments (click to watch videos of my agent): Models are available for the following atari environments (click to watch videos of my agent):
+ [AirRaid](https://gym.openai.com/evaluations/eval_zIeNk5MxSGOmvGEUxrZDUw) (this one is flickering, don't know why) + [AirRaid](https://gym.openai.com/evaluations/eval_zIeNk5MxSGOmvGEUxrZDUw) (this one is flickering due to [gym settings](https://github.com/openai/gym/issues/378))
+ [Alien](https://gym.openai.com/evaluations/eval_8NR1IvjTQkSIT6En4xSMA) + [Alien](https://gym.openai.com/evaluations/eval_8NR1IvjTQkSIT6En4xSMA)
+ [Amidar](https://gym.openai.com/evaluations/eval_HwEazbHtTYGpCialv9uPhA) + [Amidar](https://gym.openai.com/evaluations/eval_HwEazbHtTYGpCialv9uPhA)
+ [Assault](https://gym.openai.com/evaluations/eval_tCiHwy5QrSdFVucSbBV6Q) + [Assault](https://gym.openai.com/evaluations/eval_tCiHwy5QrSdFVucSbBV6Q)
......
...@@ -38,11 +38,15 @@ class CharRNNData(RNGDataFlow): ...@@ -38,11 +38,15 @@ class CharRNNData(RNGDataFlow):
logger.info("Loading corpus...") logger.info("Loading corpus...")
# preprocess data # preprocess data
with open(input_file) as f: with open(input_file, 'rb') as f:
data = f.read() data = f.read()
if six.PY2:
data = bytearray(data)
data = [chr(c) for c in data if c < 128] # TODO this is Py2 only
counter = Counter(data) counter = Counter(data)
char_cnt = sorted(counter.items(), key=operator.itemgetter(1), reverse=True) char_cnt = sorted(counter.items(), key=operator.itemgetter(1), reverse=True)
self.chars = [x[0] for x in char_cnt] self.chars = [x[0] for x in char_cnt]
print(sorted(self.chars))
self.vocab_size = len(self.chars) self.vocab_size = len(self.chars)
param.vocab_size = self.vocab_size param.vocab_size = self.vocab_size
self.lut = LookUpTable(self.chars) self.lut = LookUpTable(self.chars)
......
...@@ -33,7 +33,7 @@ class StartProcOrThread(Callback): ...@@ -33,7 +33,7 @@ class StartProcOrThread(Callback):
def _before_train(self): def _before_train(self):
logger.info("Starting " + logger.info("Starting " +
', '.join([k.name for k in self._procs_threads])) ', '.join([k.name for k in self._procs_threads]) + ' ...')
# avoid sigint get handled by other processes # avoid sigint get handled by other processes
start_proc_mask_signal(self._procs_threads) start_proc_mask_signal(self._procs_threads)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment