Commit 83686e04 authored by ppwwyyxx

add convnet

parent 12f2866e
...@@ -23,7 +23,7 @@ class Mnist(object): ...@@ -23,7 +23,7 @@ class Mnist(object):
def get_data(self): def get_data(self):
ds = self.dataset.train if self.train_or_test == 'train' else self.dataset.test ds = self.dataset.train if self.train_or_test == 'train' else self.dataset.test
for k in xrange(ds.num_examples): for k in xrange(ds.num_examples):
img = ds.images[k] img = ds.images[k].reshape((28, 28))
label = ds.labels[k] label = ds.labels[k]
yield (img, label) yield (img, label)
......
...@@ -14,7 +14,6 @@ from dataflow.dataset import Mnist ...@@ -14,7 +14,6 @@ from dataflow.dataset import Mnist
from dataflow import * from dataflow import *
IMAGE_SIZE = 28 IMAGE_SIZE = 28
PIXELS = IMAGE_SIZE * IMAGE_SIZE
NUM_CLASS = 10 NUM_CLASS = 10
batch_size = 128 batch_size = 128
LOG_DIR = 'train_log' LOG_DIR = 'train_log'
...@@ -22,13 +21,18 @@ LOG_DIR = 'train_log' ...@@ -22,13 +21,18 @@ LOG_DIR = 'train_log'
def get_model(input, label): def get_model(input, label):
""" """
Args: Args:
input: bxPIXELS input: bx28x28
label: bx1 integer label: bx1 integer
Returns: Returns:
(output, cost) (output, cost)
output: variable output: variable
cost: scalar variable cost: scalar variable
""" """
input = tf.reshape(input, [-1, 28, 28, 1])
conv = Conv2D('conv0', input, out_channel=20, kernel_shape=3,
padding='same')
input = tf.reshape(input, [-1, 28 * 28])
fc0 = FullyConnected('fc0', input, 200) fc0 = FullyConnected('fc0', input, 200)
fc0 = tf.nn.relu(fc0) fc0 = tf.nn.relu(fc0)
fc1 = FullyConnected('fc1', fc0, out_dim=200) fc1 = FullyConnected('fc1', fc0, out_dim=200)
...@@ -55,7 +59,7 @@ def main(): ...@@ -55,7 +59,7 @@ def main():
] ]
with tf.Graph().as_default(): with tf.Graph().as_default():
input_var = tf.placeholder(tf.float32, shape=(None, PIXELS), name='input') input_var = tf.placeholder(tf.float32, shape=(None, IMAGE_SIZE, IMAGE_SIZE), name='input')
label_var = tf.placeholder(tf.int32, shape=(None,), name='label') label_var = tf.placeholder(tf.int32, shape=(None,), name='label')
prob, cost = get_model(input_var, label_var) prob, cost = get_model(input_var, label_var)
......
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
# File: _common.py # File: _common.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
__all__ = ['layer_register']
import tensorflow as tf import tensorflow as tf
def layer_register(): def layer_register():
...@@ -13,7 +12,7 @@ def layer_register(): ...@@ -13,7 +12,7 @@ def layer_register():
assert isinstance(name, basestring) assert isinstance(name, basestring)
args = args[1:] args = args[1:]
with tf.name_scope(name): with tf.variable_scope(name):
return func(*args, **kwargs) return func(*args, **kwargs)
return inner return inner
return wrapper return wrapper
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: conv2d.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import tensorflow as tf
import math
from ._common import layer_register
__all__ = ['Conv2D']
@layer_register()
def Conv2D(x, out_channel, kernel_shape,
           padding='VALID', stride=None,
           W_init=None, b_init=None):
    """
    2D convolution followed by a bias add.

    Creates (via tf.get_variable) a filter variable 'W' and a bias
    variable 'b' in the current variable scope.

    Args:
        x: input tensor. Its last dimension is used as the number of input
            channels (presumably NHWC layout, the tf.nn.conv2d default --
            TODO confirm against callers).
        out_channel: number of output channels.
        kernel_shape: (h, w) as a list or tuple, or an int for a square kernel.
        padding: 'valid' or 'same' (case-insensitive).
        stride: (h, w) as a list or tuple, or an int used for both
            dimensions. None means a stride of 1 in both dimensions.
        W_init: initializer for 'W'. Defaults to a truncated normal
            initializer with stddev=0.04.
        b_init: initializer for 'b'. Defaults to tf.constant_initializer().
    Returns:
        The result of the convolution with the bias added.
    """
    in_shape = x.get_shape().as_list()
    in_channel = in_shape[-1]

    if isinstance(kernel_shape, int):
        kernel_shape = [kernel_shape, kernel_shape]
    padding = padding.upper()
    # list() so that a tuple (h, w), which the docstring allows, can be
    # concatenated with the channel dimensions instead of raising TypeError.
    filter_shape = list(kernel_shape) + [in_channel, out_channel]

    # tf.nn.conv2d expects a 4-element stride; the outer two entries stay 1.
    if stride is None:
        stride = [1, 1, 1, 1]
    elif isinstance(stride, int):
        stride = [1, stride, stride, 1]
    elif isinstance(stride, (list, tuple)):
        assert len(stride) == 2, "stride must be (h, w) or an int"
        stride = [1] + list(stride) + [1]

    if W_init is None:
        W_init = tf.truncated_normal_initializer(stddev=0.04)
    if b_init is None:
        b_init = tf.constant_initializer()

    W = tf.get_variable('W', filter_shape, initializer=W_init)  # TODO collections
    b = tf.get_variable('b', [out_channel], initializer=b_init)

    conv = tf.nn.conv2d(x, W, stride, padding)
    return tf.nn.bias_add(conv, b)
...@@ -18,11 +18,10 @@ def FullyConnected(x, out_dim, W_init=None, b_init=None): ...@@ -18,11 +18,10 @@ def FullyConnected(x, out_dim, W_init=None, b_init=None):
in_dim = x.get_shape().as_list()[1] in_dim = x.get_shape().as_list()[1]
if W_init is None: if W_init is None:
W_init = lambda shape: tf.truncated_normal( W_init = tf.truncated_normal_initializer(stddev=1.0 / math.sqrt(float(in_dim)))
shape, stddev=1.0 / math.sqrt(float(in_dim)))
if b_init is None: if b_init is None:
b_init = tf.zeros b_init = tf.constant_initializer()
W = tf.Variable(W_init([in_dim, out_dim]), name='W') W = tf.get_variable('W', [in_dim, out_dim], initializer=W_init) # TODO collections
b = tf.Variable(b_init([out_dim]), name='b') b = tf.get_variable('b', [out_dim], initializer=b_init)
return tf.matmul(x, W) + b return tf.matmul(x, W) + b
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment