Commit c920872c authored by Yuxin Wu's avatar Yuxin Wu

fix char-rnn for changes in TF API

parent ff9e5555
...@@ -73,8 +73,8 @@ class Model(ModelDesc): ...@@ -73,8 +73,8 @@ class Model(ModelDesc):
def _build_graph(self, inputs): def _build_graph(self, inputs):
input, nextinput = inputs input, nextinput = inputs
cell = rnn.BasicLSTMCell(num_units=param.rnn_size) cell = rnn.MultiRNNCell([rnn.BasicLSTMCell(num_units=param.rnn_size)
cell = rnn.MultiRNNCell([cell] * param.num_rnn_layer) for _ in range(param.num_rnn_layer)])
def get_v(n): def get_v(n):
ret = tf.get_variable(n + '_unused', [param.batch_size, param.rnn_size], ret = tf.get_variable(n + '_unused', [param.batch_size, param.rnn_size],
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment