Commit 744defbe authored by Yuxin Wu's avatar Yuxin Wu

bug fix in dataflow

parent 4f1568dc
......@@ -91,8 +91,9 @@ class Model(ModelDesc):
# seqlen is 1 in inference. don't need loop_function
outputs, last_state = rnn.rnn(cell, input_list, initial, scope='rnnlm')
self.last_state = tf.identity(last_state, 'last_state')
# seqlen x (Bxrnnsize)
output = tf.reshape(tf.concat(1, outputs), [-1, param.rnn_size]) # (seqlenxB) x rnnsize
output = tf.reshape(tf.concat(1, outputs), [-1, param.rnn_size]) # (Bxseqlen) x rnnsize
logits = FullyConnected('fc', output, param.vocab_size, nl=tf.identity)
self.prob = tf.nn.softmax(logits / param.softmax_temprature)
......
......@@ -7,7 +7,7 @@
from abc import abstractmethod, ABCMeta
from ..utils import get_rng
__all__ = ['DataFlow', 'ProxyDataFlow']
__all__ = ['DataFlow', 'ProxyDataFlow', 'RNGDataFlow']
class DataFlow(object):
""" Base class for all DataFlow """
......
......@@ -28,6 +28,11 @@ class TestDataSpeed(ProxyDataFlow):
for dp in self.ds.get_data():
yield dp
def start_test(self):
self.ds.reset_state()
for k in self.get_data():
pass
class BatchData(ProxyDataFlow):
def __init__(self, ds, batch_size, remainder=False):
"""
......
......@@ -6,6 +6,7 @@
import six
import tensorflow as tf
import numpy as np
from ..utils import logger
__all__ = ['SessionUpdate', 'dump_session_params', 'dump_chkpt_vars']
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment