Commit e1b19c5d authored by ppwwyyxx's avatar ppwwyyxx

add fakedata

parent 20becf84
#!/usr/bin/env python2 #!/usr/bin/env python2
# -*- coding: UTF-8 -*- # -*- coding: UTF-8 -*-
# File: batch.py # File: common.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
import numpy as np import numpy as np
from .base import DataFlow from .base import DataFlow
__all__ = ['BatchData', 'FixedSizeData'] __all__ = ['BatchData', 'FixedSizeData', 'FakeData']
class BatchData(DataFlow): class BatchData(DataFlow):
def __init__(self, ds, batch_size, remainder=False): def __init__(self, ds, batch_size, remainder=False):
...@@ -65,3 +65,16 @@ class FixedSizeData(DataFlow): ...@@ -65,3 +65,16 @@ class FixedSizeData(DataFlow):
if cnt == self._size: if cnt == self._size:
return return
class FakeData(DataFlow):
""" Build fake data of given shapes"""
def __init__(self, shapes, size):
self.shapes = shapes
self._size = size
def size(self):
return self._size
def get_data(self):
for _ in xrange(self._size):
yield tuple((np.random.random(k) for k in self.shapes))
...@@ -44,6 +44,7 @@ class EnqueueThread(threading.Thread): ...@@ -44,6 +44,7 @@ class EnqueueThread(threading.Thread):
except tf.errors.CancelledError as e: except tf.errors.CancelledError as e:
pass pass
except Exception: except Exception:
# TODO close queue.
logger.exception("Exception in EnqueueThread:") logger.exception("Exception in EnqueueThread:")
self.coord.request_stop() self.coord.request_stop()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment