Commit 4e3849e0 authored by Yuxin Wu

[PTB] really allow any number of layers (fix #567)

parent 4de54e62
......@@ -67,9 +67,10 @@ class Model(ModelDesc):
return tf.get_variable(n, [BATCH, HIDDEN_SIZE],
trainable=False,
initializer=tf.constant_initializer())
self.state = state_var = \
(rnn.LSTMStateTuple(get_v('c0'), get_v('h0')),
rnn.LSTMStateTuple(get_v('c1'), get_v('h1')))
state_var = [rnn.LSTMStateTuple(
get_v('c{}'.format(k)), get_v('h{}'.format(k))) for k in range(NUM_LAYER)]
self.state = state_var = tuple(state_var)
embeddingW = tf.get_variable('embedding', [VOCAB_SIZE, HIDDEN_SIZE], initializer=initializer)
input_feature = tf.nn.embedding_lookup(embeddingW, input) # B x seqlen x hiddensize
......@@ -102,10 +103,11 @@ class Model(ModelDesc):
def reset_lstm_state(self):
s = self.state
z = tf.zeros_like(s[0].c)
return tf.group(s[0].c.assign(z),
s[0].h.assign(z),
s[1].c.assign(z),
s[1].h.assign(z), name='reset_lstm_state')
ops = []
for k in range(NUM_LAYER):
ops.append(s[k].c.assign(z))
ops.append(s[k].h.assign(z))
return tf.group(*ops, name='reset_lstm_state')
def _get_optimizer(self):
lr = tf.get_variable('learning_rate', initializer=1.0, trainable=False)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment