Commit 4b99af0a authored by Patrick Wieschollek's avatar Patrick Wieschollek Committed by Yuxin Wu

Change rnn-cell to fix #103 (#104)

* Change rnn-cell to fix #103
parent c3dc184d
......@@ -40,8 +40,8 @@ class Model(ModelDesc):
feat, labelidx, labelvalue, labelshape, seqlen = input_vars
label = tf.SparseTensor(labelidx, labelvalue, labelshape)
cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=HIDDEN)
cell = tf.nn.rnn_cell.MultiRNNCell([cell] * NLAYER)
cell = tf.contrib.rnn.BasicLSTMCell(num_units=HIDDEN)
cell = tf.contrib.rnn.MultiRNNCell([cell] * NLAYER)
initial = cell.zero_state(tf.shape(feat)[0], tf.float32)
......
......@@ -69,8 +69,8 @@ class Model(ModelDesc):
def _build_graph(self, input_vars):
input, nextinput = input_vars
cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=param.rnn_size)
cell = tf.nn.rnn_cell.MultiRNNCell([cell] * param.num_rnn_layer)
cell = tf.contrib.rnn.BasicLSTMCell(num_units=param.rnn_size)
cell = tf.contrib.rnn.MultiRNNCell([cell] * param.num_rnn_layer)
self.initial = initial = cell.zero_state(tf.shape(input)[0], tf.float32)
......@@ -80,7 +80,7 @@ class Model(ModelDesc):
input_list = tf.unstack(input_feature, axis=1) # seqlen x (Bxrnnsize)
# seqlen is 1 in inference. don't need loop_function
outputs, last_state = tf.nn.rnn(cell, input_list, initial, scope='rnnlm')
outputs, last_state = tf.contrib.rnn.static_rnn(cell, input_list, initial, scope='rnnlm')
self.last_state = tf.identity(last_state, 'last_state')
# seqlen x (Bxrnnsize)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment