Commit 4f1568dc authored by Yuxin Wu's avatar Yuxin Wu

typed fake data

parent 509c2c90
...@@ -11,7 +11,7 @@ __all__ = ['FakeData', 'DataFromQueue', 'DataFromList'] ...@@ -11,7 +11,7 @@ __all__ = ['FakeData', 'DataFromQueue', 'DataFromList']
class FakeData(RNGDataFlow): class FakeData(RNGDataFlow):
""" Generate fake fixed data of given shapes""" """ Generate fake fixed data of given shapes"""
def __init__(self, shapes, size, random=True): def __init__(self, shapes, size, random=True, dtype='float32'):
""" """
:param shapes: a list of lists/tuples :param shapes: a list of lists/tuples
:param size: size of this DataFlow :param size: size of this DataFlow
...@@ -20,6 +20,7 @@ class FakeData(RNGDataFlow): ...@@ -20,6 +20,7 @@ class FakeData(RNGDataFlow):
self.shapes = shapes self.shapes = shapes
self._size = int(size) self._size = int(size)
self.random = random self.random = random
self.dtype = dtype
def size(self): def size(self):
return self._size return self._size
...@@ -27,9 +28,9 @@ class FakeData(RNGDataFlow): ...@@ -27,9 +28,9 @@ class FakeData(RNGDataFlow):
def get_data(self): def get_data(self):
if self.random: if self.random:
for _ in range(self._size): for _ in range(self._size):
yield [self.rng.rand(*k).astype('float32') for k in self.shapes] yield [self.rng.rand(*k).astype(self.dtype) for k in self.shapes]
else: else:
v = [self.rng.rand(*k).astype('float32') for k in self.shapes] v = [self.rng.rand(*k).astype(self.dtype) for k in self.shapes]
for _ in range(self._size): for _ in range(self._size):
yield v yield v
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment