Commit f87b431b authored by Patrick Wieschollek's avatar Patrick Wieschollek Committed by Yuxin Wu

fix bug in boilerplate (#260)

* fix bug in boilerplate, self.cost and enhance FakeData

* fix len_range -> max_range
parent 47992266
......@@ -13,19 +13,21 @@ All code is in this file is the most minimalistic way to solve a deep-learning p
"""
BATCH_SIZE = 16
SHAPE = 28
CHANNELS = 3
class Model(ModelDesc):
def _get_inputs(self):
return [InputDesc(tf.float32, (None, 28, 28, 1), 'input'),
return [InputDesc(tf.float32, (None, SHAPE, SHAPE, CHANNELS), 'input'),
InputDesc(tf.int32, (None,), 'label')]
def _build_graph(self, inputs):
image, label = inputs
image = image * 2 - 1
self.cost = tf.identity(0, name='total_costs')
summary.add_moving_summary(cost)
self.cost = tf.identity(0., name='total_costs')
summary.add_moving_summary(self.cost)
def _get_optimizer(self):
lr = symbolic_functions.get_scalar_var('learning_rate', 5e-3, summary=True)
......@@ -33,9 +35,9 @@ class Model(ModelDesc):
def get_data(subset):
ds = None # something.get_data() that yields [[28, 28, 1], [1]]
# ...
# something that yields [[SHAPE, SHAPE, CHANNELS], [1]]
ds = FakeData([[SHAPE, SHAPE, CHANNELS], [1]], 1000, random=False,
dtype=['float32', 'uint8'], domain=[(0, 255), (0, 10)])
ds = PrefetchDataZMQ(ds, 2)
ds = BatchData(ds, BATCH_SIZE)
return ds
......
......@@ -5,6 +5,7 @@
import numpy as np
import copy
import six
from six.moves import range
from .base import DataFlow, RNGDataFlow
......@@ -14,20 +15,22 @@ __all__ = ['FakeData', 'DataFromQueue', 'DataFromList']
class FakeData(RNGDataFlow):
""" Generate fake data of given shapes"""
def __init__(self, shapes, size=1000, random=True, dtype='float32'):
def __init__(self, shapes, size=1000, random=True, dtype='float32', domain=(0, 1)):
"""
Args:
shapes (list): a list of lists/tuples. Shapes of each component.
size (int): size of this DataFlow.
random (bool): whether to randomly generate data every iteration.
Note that merely generating the data could sometimes be time-consuming!
dtype (str): data type.
dtype (str): data type as string or a list of data types.
domain (str): domain of values as tuple/list.
"""
super(FakeData, self).__init__()
self.shapes = shapes
self._size = int(size)
self.random = random
self.dtype = dtype
self.dtype = [dtype] * len(shapes) if isinstance(dtype, six.string_types) else dtype
self.domain = [domain] * len(shapes) if isinstance(domain, tuple) else domain
def size(self):
return self._size
......@@ -35,11 +38,18 @@ class FakeData(RNGDataFlow):
def get_data(self):
if self.random:
for _ in range(self._size):
yield [self.rng.rand(*k).astype(self.dtype) for k in self.shapes]
val = []
for k in range(len(self.shapes)):
v = self.rng.rand(*self.shapes[k]) * (self.domain[k][1] - self.domain[k][0]) + self.domain[k][0]
val.append(v.astype(self.dtype[k]))
yield val
else:
v = [self.rng.rand(*k).astype(self.dtype) for k in self.shapes]
val = []
for k in range(len(self.shapes)):
v = self.rng.rand(*self.shapes[k]) * (self.domain[k][1] - self.domain[k][0]) + self.domain[k][0]
val.append(v.astype(self.dtype[k]))
for _ in range(self._size):
yield copy.deepcopy(v)
yield copy.deepcopy(val)
class DataFromQueue(DataFlow):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment