Commit b7ad58ec authored by Yuxin Wu's avatar Yuxin Wu

bugfix for upstream change of MultiRNNCell

parent d493c30b
...@@ -19,6 +19,7 @@ import tensorpack.tfutils.symbolic_functions as symbf ...@@ -19,6 +19,7 @@ import tensorpack.tfutils.symbolic_functions as symbf
import tensorflow as tf import tensorflow as tf
from timitdata import TIMITBatch from timitdata import TIMITBatch
rnn = tf.contrib.rnn
BATCH = 64 BATCH = 64
...@@ -41,8 +42,7 @@ class Model(ModelDesc): ...@@ -41,8 +42,7 @@ class Model(ModelDesc):
feat, labelidx, labelvalue, labelshape, seqlen = inputs feat, labelidx, labelvalue, labelshape, seqlen = inputs
label = tf.SparseTensor(labelidx, labelvalue, labelshape) label = tf.SparseTensor(labelidx, labelvalue, labelshape)
cell = tf.contrib.rnn.BasicLSTMCell(num_units=HIDDEN) cell = rnn.MultiRNNCell([rnn.LSTMBlockCell(num_units=HIDDEN) for _ in range(NLAYER)])
cell = tf.contrib.rnn.MultiRNNCell([cell] * NLAYER)
initial = cell.zero_state(tf.shape(feat)[0], tf.float32) initial = cell.zero_state(tf.shape(feat)[0], tf.float32)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment