Commit 4f1568dc authored by Yuxin Wu's avatar Yuxin Wu

typed fake data

parent 509c2c90
......@@ -11,7 +11,7 @@ __all__ = ['FakeData', 'DataFromQueue', 'DataFromList']
class FakeData(RNGDataFlow):
""" Generate fake fixed data of given shapes"""
def __init__(self, shapes, size, random=True):
def __init__(self, shapes, size, random=True, dtype='float32'):
"""
:param shapes: a list of lists/tuples
:param size: size of this DataFlow
......@@ -20,6 +20,7 @@ class FakeData(RNGDataFlow):
self.shapes = shapes
self._size = int(size)
self.random = random
self.dtype = dtype
def size(self):
return self._size
......@@ -27,9 +28,9 @@ class FakeData(RNGDataFlow):
def get_data(self):
if self.random:
for _ in range(self._size):
yield [self.rng.rand(*k).astype('float32') for k in self.shapes]
yield [self.rng.rand(*k).astype(self.dtype) for k in self.shapes]
else:
v = [self.rng.rand(*k).astype('float32') for k in self.shapes]
v = [self.rng.rand(*k).astype(self.dtype) for k in self.shapes]
for _ in range(self._size):
yield v
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment