Commit f87b431b authored by Patrick Wieschollek's avatar Patrick Wieschollek Committed by Yuxin Wu

fix bug in boilerplate (#260)

* fix bug in boilerplate, self.cost and enhance FakeData

* fix len_range -> max_range
parent 47992266
...@@ -13,19 +13,21 @@ All code is in this file is the most minimalistic way to solve a deep-learning p ...@@ -13,19 +13,21 @@ All code is in this file is the most minimalistic way to solve a deep-learning p
""" """
BATCH_SIZE = 16 BATCH_SIZE = 16
SHAPE = 28
CHANNELS = 3
class Model(ModelDesc): class Model(ModelDesc):
def _get_inputs(self): def _get_inputs(self):
return [InputDesc(tf.float32, (None, 28, 28, 1), 'input'), return [InputDesc(tf.float32, (None, SHAPE, SHAPE, CHANNELS), 'input'),
InputDesc(tf.int32, (None,), 'label')] InputDesc(tf.int32, (None,), 'label')]
def _build_graph(self, inputs): def _build_graph(self, inputs):
image, label = inputs image, label = inputs
image = image * 2 - 1 image = image * 2 - 1
self.cost = tf.identity(0, name='total_costs') self.cost = tf.identity(0., name='total_costs')
summary.add_moving_summary(cost) summary.add_moving_summary(self.cost)
def _get_optimizer(self): def _get_optimizer(self):
lr = symbolic_functions.get_scalar_var('learning_rate', 5e-3, summary=True) lr = symbolic_functions.get_scalar_var('learning_rate', 5e-3, summary=True)
...@@ -33,9 +35,9 @@ class Model(ModelDesc): ...@@ -33,9 +35,9 @@ class Model(ModelDesc):
def get_data(subset): def get_data(subset):
ds = None # something.get_data() that yields [[28, 28, 1], [1]] # something that yields [[SHAPE, SHAPE, CHANNELS], [1]]
# ... ds = FakeData([[SHAPE, SHAPE, CHANNELS], [1]], 1000, random=False,
dtype=['float32', 'uint8'], domain=[(0, 255), (0, 10)])
ds = PrefetchDataZMQ(ds, 2) ds = PrefetchDataZMQ(ds, 2)
ds = BatchData(ds, BATCH_SIZE) ds = BatchData(ds, BATCH_SIZE)
return ds return ds
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
import numpy as np import numpy as np
import copy import copy
import six
from six.moves import range from six.moves import range
from .base import DataFlow, RNGDataFlow from .base import DataFlow, RNGDataFlow
...@@ -14,20 +15,22 @@ __all__ = ['FakeData', 'DataFromQueue', 'DataFromList'] ...@@ -14,20 +15,22 @@ __all__ = ['FakeData', 'DataFromQueue', 'DataFromList']
class FakeData(RNGDataFlow): class FakeData(RNGDataFlow):
""" Generate fake data of given shapes""" """ Generate fake data of given shapes"""
def __init__(self, shapes, size=1000, random=True, dtype='float32'): def __init__(self, shapes, size=1000, random=True, dtype='float32', domain=(0, 1)):
""" """
Args: Args:
shapes (list): a list of lists/tuples. Shapes of each component. shapes (list): a list of lists/tuples. Shapes of each component.
size (int): size of this DataFlow. size (int): size of this DataFlow.
random (bool): whether to randomly generate data every iteration. random (bool): whether to randomly generate data every iteration.
Note that merely generating the data could sometimes be time-consuming! Note that merely generating the data could sometimes be time-consuming!
dtype (str): data type. dtype (str): data type as string or a list of data types.
domain (str): domain of values as tuple/list.
""" """
super(FakeData, self).__init__() super(FakeData, self).__init__()
self.shapes = shapes self.shapes = shapes
self._size = int(size) self._size = int(size)
self.random = random self.random = random
self.dtype = dtype self.dtype = [dtype] * len(shapes) if isinstance(dtype, six.string_types) else dtype
self.domain = [domain] * len(shapes) if isinstance(domain, tuple) else domain
def size(self): def size(self):
return self._size return self._size
...@@ -35,11 +38,18 @@ class FakeData(RNGDataFlow): ...@@ -35,11 +38,18 @@ class FakeData(RNGDataFlow):
def get_data(self): def get_data(self):
if self.random: if self.random:
for _ in range(self._size): for _ in range(self._size):
yield [self.rng.rand(*k).astype(self.dtype) for k in self.shapes] val = []
for k in range(len(self.shapes)):
v = self.rng.rand(*self.shapes[k]) * (self.domain[k][1] - self.domain[k][0]) + self.domain[k][0]
val.append(v.astype(self.dtype[k]))
yield val
else: else:
v = [self.rng.rand(*k).astype(self.dtype) for k in self.shapes] val = []
for k in range(len(self.shapes)):
v = self.rng.rand(*self.shapes[k]) * (self.domain[k][1] - self.domain[k][0]) + self.domain[k][0]
val.append(v.astype(self.dtype[k]))
for _ in range(self._size): for _ in range(self._size):
yield copy.deepcopy(v) yield copy.deepcopy(val)
class DataFromQueue(DataFlow): class DataFromQueue(DataFlow):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment