Commit 01cab873 authored by Yuxin Wu's avatar Yuxin Wu

bugfix in char-rnn example (fix #323)

parent 8faea40d
......@@ -49,8 +49,8 @@ class CharRNNData(RNGDataFlow):
print(sorted(self.chars))
self.vocab_size = len(self.chars)
param.vocab_size = self.vocab_size
char2idx = {c: i for i, c in enumerate(self.chars)}
self.whole_seq = np.array([char2idx[c] for c in data], dtype='int32')
self.char2idx = {c: i for i, c in enumerate(self.chars)}
self.whole_seq = np.array([self.char2idx[c] for c in data], dtype='int32')
logger.info("Corpus loaded. Vocab size: {}".format(self.vocab_size))
def size(self):
......@@ -146,7 +146,7 @@ def sample(path, start, length):
# feed the starting sentence
initial = np.zeros((1, param.rnn_size))
for c in start[:-1]:
x = np.array([[ds.lut.get_idx(c)]], dtype='int32')
x = np.array([[ds.char2idx[c]]], dtype='int32')
_, state = pred(x, initial, initial, initial, initial)
def pick(prob):
......@@ -158,9 +158,9 @@ def sample(path, start, length):
ret = start
c = start[-1]
for k in range(length):
x = np.array([[ds.lut.get_idx(c)]], dtype='int32')
x = np.array([[ds.char2idx[c]]], dtype='int32')
prob, state = pred(x, state[0, 0], state[0, 1], state[1, 0], state[1, 1])
c = ds.lut.get_obj(pick(prob[0]))
c = ds.chars[pick(prob[0])]
ret += c
print(ret)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment