Commit b7ad58ec authored by Yuxin Wu's avatar Yuxin Wu

bugfix for upstream change of MultiRNNCell

parent d493c30b
......@@ -19,6 +19,7 @@ import tensorpack.tfutils.symbolic_functions as symbf
import tensorflow as tf
from timitdata import TIMITBatch
rnn = tf.contrib.rnn
BATCH = 64
......@@ -41,8 +42,7 @@ class Model(ModelDesc):
feat, labelidx, labelvalue, labelshape, seqlen = inputs
label = tf.SparseTensor(labelidx, labelvalue, labelshape)
cell = tf.contrib.rnn.BasicLSTMCell(num_units=HIDDEN)
cell = tf.contrib.rnn.MultiRNNCell([cell] * NLAYER)
cell = rnn.MultiRNNCell([rnn.LSTMBlockCell(num_units=HIDDEN) for _ in range(NLAYER)])
initial = cell.zero_state(tf.shape(feat)[0], tf.float32)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment