Commit e42378ca authored by Yuxin Wu's avatar Yuxin Wu

testdataspeed

parent 49cd9aec
......@@ -13,7 +13,20 @@ from ..utils import *
__all__ = ['BatchData', 'FixedSizeData', 'MapData',
'RepeatedData', 'MapDataComponent', 'RandomChooseData',
'RandomMixData', 'JoinData', 'ConcatData', 'SelectComponent',
'LocallyShuffleData']
'LocallyShuffleData', 'TestDataSpeed']
class TestDataSpeed(ProxyDataFlow):
def __init__(self, ds, size=1000):
super(TestDataSpeed, self).__init__(ds)
self.test_size = size
def get_data(self):
from tqdm import tqdm
with tqdm(range(self.test_size), **get_tqdm_kwargs()) as pbar:
for dp in self.ds.get_data():
pbar.update()
for dp in self.ds.get_data():
yield dp
class BatchData(ProxyDataFlow):
def __init__(self, ds, batch_size, remainder=False):
......
......@@ -104,7 +104,7 @@ def build_multi_tower_prediction_graph(model, towers, prefix='towerp'):
model._build_graph(input_vars, False)
tf.get_variable_scope().reuse_variables()
def MultiTowerOfflinePredictor(OnlinePredictor):
class MultiTowerOfflinePredictor(OnlinePredictor):
PREFIX = 'towerp'
def __init__(self, config, towers):
self.graph = tf.Graph()
......
......@@ -112,6 +112,7 @@ class EnqueueThread(threading.Thread):
def run(self):
self.dataflow.reset_state()
with self.sess.as_default():
try:
while True:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment