Commit 744defbe authored by Yuxin Wu's avatar Yuxin Wu

bug fix in dataflow

parent 4f1568dc
...@@ -91,8 +91,9 @@ class Model(ModelDesc): ...@@ -91,8 +91,9 @@ class Model(ModelDesc):
# seqlen is 1 in inference. don't need loop_function # seqlen is 1 in inference. don't need loop_function
outputs, last_state = rnn.rnn(cell, input_list, initial, scope='rnnlm') outputs, last_state = rnn.rnn(cell, input_list, initial, scope='rnnlm')
self.last_state = tf.identity(last_state, 'last_state') self.last_state = tf.identity(last_state, 'last_state')
# seqlen x (Bxrnnsize) # seqlen x (Bxrnnsize)
output = tf.reshape(tf.concat(1, outputs), [-1, param.rnn_size]) # (seqlenxB) x rnnsize output = tf.reshape(tf.concat(1, outputs), [-1, param.rnn_size]) # (Bxseqlen) x rnnsize
logits = FullyConnected('fc', output, param.vocab_size, nl=tf.identity) logits = FullyConnected('fc', output, param.vocab_size, nl=tf.identity)
self.prob = tf.nn.softmax(logits / param.softmax_temprature) self.prob = tf.nn.softmax(logits / param.softmax_temprature)
......
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
from abc import abstractmethod, ABCMeta from abc import abstractmethod, ABCMeta
from ..utils import get_rng from ..utils import get_rng
__all__ = ['DataFlow', 'ProxyDataFlow'] __all__ = ['DataFlow', 'ProxyDataFlow', 'RNGDataFlow']
class DataFlow(object): class DataFlow(object):
""" Base class for all DataFlow """ """ Base class for all DataFlow """
......
...@@ -28,6 +28,11 @@ class TestDataSpeed(ProxyDataFlow): ...@@ -28,6 +28,11 @@ class TestDataSpeed(ProxyDataFlow):
for dp in self.ds.get_data(): for dp in self.ds.get_data():
yield dp yield dp
def start_test(self):
self.ds.reset_state()
for k in self.get_data():
pass
class BatchData(ProxyDataFlow): class BatchData(ProxyDataFlow):
def __init__(self, ds, batch_size, remainder=False): def __init__(self, ds, batch_size, remainder=False):
""" """
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
import six import six
import tensorflow as tf import tensorflow as tf
import numpy as np import numpy as np
from ..utils import logger
__all__ = ['SessionUpdate', 'dump_session_params', 'dump_chkpt_vars'] __all__ = ['SessionUpdate', 'dump_session_params', 'dump_chkpt_vars']
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment