Commit 5fab58f3 authored by ppwwyyxx's avatar ppwwyyxx

fix

parent 88d607d8
......@@ -16,7 +16,7 @@ class Mnist(object):
train_or_test: string either 'train' or 'test'
"""
if dir is None:
dir = os.path.join(os.path.dirname(__file__), 'mnist')
dir = os.path.join(os.path.dirname(__file__), 'mnist_data')
self.dataset = input_data.read_data_sets(dir)
self.train_or_test = train_or_test
......@@ -28,5 +28,7 @@ class Mnist(object):
yield (img, label)
if __name__ == '__main__':
ds = Mnist()
ds.get_data()
ds = Mnist('train')
for (img, label) in ds.get_data():
from IPython import embed; embed()
......@@ -4,10 +4,10 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# HACK protobuf
#import sys
#import os
#sys.path.insert(0, os.path.expanduser('~/.local/lib/python2.7/site-packages'))
# prefer protobuf in user-namespace
import sys
import os
sys.path.insert(0, os.path.expanduser('~/.local/lib/python2.7/site-packages'))
import tensorflow as tf
import numpy as np
......@@ -35,21 +35,23 @@ def get_model(input, label):
cost: scalar variable
"""
input = tf.reshape(input, [-1, IMAGE_SIZE, IMAGE_SIZE, 1])
conv0 = Conv2D('conv0', input, out_channel=5, kernel_shape=3,
conv0 = Conv2D('conv0', input, out_channel=20, kernel_shape=5,
padding='valid')
pool0 = tf.nn.max_pool(conv0, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1],
padding='SAME')
conv1 = Conv2D('conv1', pool0, out_channel=10, kernel_shape=4,
conv1 = Conv2D('conv1', pool0, out_channel=40, kernel_shape=3,
padding='valid')
pool1 = tf.nn.max_pool(conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1],
padding='SAME')
conv2 = Conv2D('conv2', pool0, out_channel=40, kernel_shape=3,
padding='valid')
feature = batch_flatten(pool1)
feature = batch_flatten(conv2)
fc0 = FullyConnected('fc0', feature, 512)
fc0 = tf.nn.relu(fc0)
fc2 = FullyConnected('lr', fc1, out_dim=10)
prob = tf.nn.softmax(fc2, name='output')
fc1 = FullyConnected('lr', fc0, out_dim=10)
prob = tf.nn.softmax(fc1, name='output')
logprob = tf.log(prob)
y = one_hot(label, NUM_CLASS)
......@@ -82,8 +84,9 @@ def main():
ext.init()
summary_op = tf.merge_all_summaries()
sess = tf.Session()
config = tf.ConfigProto()
config.device_count['GPU'] = 1
sess = tf.Session(config=config)
sess.run(tf.initialize_all_variables())
summary_writer = tf.train.SummaryWriter(LOG_DIR, graph_def=sess.graph_def)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment