Commit 5fab58f3 authored by ppwwyyxx's avatar ppwwyyxx

fix

parent 88d607d8
...@@ -16,7 +16,7 @@ class Mnist(object): ...@@ -16,7 +16,7 @@ class Mnist(object):
train_or_test: string either 'train' or 'test' train_or_test: string either 'train' or 'test'
""" """
if dir is None: if dir is None:
dir = os.path.join(os.path.dirname(__file__), 'mnist') dir = os.path.join(os.path.dirname(__file__), 'mnist_data')
self.dataset = input_data.read_data_sets(dir) self.dataset = input_data.read_data_sets(dir)
self.train_or_test = train_or_test self.train_or_test = train_or_test
...@@ -28,5 +28,7 @@ class Mnist(object): ...@@ -28,5 +28,7 @@ class Mnist(object):
yield (img, label) yield (img, label)
if __name__ == '__main__': if __name__ == '__main__':
ds = Mnist() ds = Mnist('train')
ds.get_data() for (img, label) in ds.get_data():
from IPython import embed; embed()
...@@ -4,10 +4,10 @@ ...@@ -4,10 +4,10 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com> # Author: Yuxin Wu <ppwwyyxx@gmail.com>
# HACK protobuf # prefer protobuf in user-namespace
#import sys import sys
#import os import os
#sys.path.insert(0, os.path.expanduser('~/.local/lib/python2.7/site-packages')) sys.path.insert(0, os.path.expanduser('~/.local/lib/python2.7/site-packages'))
import tensorflow as tf import tensorflow as tf
import numpy as np import numpy as np
...@@ -35,21 +35,23 @@ def get_model(input, label): ...@@ -35,21 +35,23 @@ def get_model(input, label):
cost: scalar variable cost: scalar variable
""" """
input = tf.reshape(input, [-1, IMAGE_SIZE, IMAGE_SIZE, 1]) input = tf.reshape(input, [-1, IMAGE_SIZE, IMAGE_SIZE, 1])
conv0 = Conv2D('conv0', input, out_channel=5, kernel_shape=3, conv0 = Conv2D('conv0', input, out_channel=20, kernel_shape=5,
padding='valid') padding='valid')
pool0 = tf.nn.max_pool(conv0, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], pool0 = tf.nn.max_pool(conv0, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1],
padding='SAME') padding='SAME')
conv1 = Conv2D('conv1', pool0, out_channel=10, kernel_shape=4, conv1 = Conv2D('conv1', pool0, out_channel=40, kernel_shape=3,
padding='valid') padding='valid')
pool1 = tf.nn.max_pool(conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], pool1 = tf.nn.max_pool(conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1],
padding='SAME') padding='SAME')
conv2 = Conv2D('conv2', pool0, out_channel=40, kernel_shape=3,
padding='valid')
feature = batch_flatten(pool1) feature = batch_flatten(conv2)
fc0 = FullyConnected('fc0', feature, 512) fc0 = FullyConnected('fc0', feature, 512)
fc0 = tf.nn.relu(fc0) fc0 = tf.nn.relu(fc0)
fc2 = FullyConnected('lr', fc1, out_dim=10) fc1 = FullyConnected('lr', fc0, out_dim=10)
prob = tf.nn.softmax(fc2, name='output') prob = tf.nn.softmax(fc1, name='output')
logprob = tf.log(prob) logprob = tf.log(prob)
y = one_hot(label, NUM_CLASS) y = one_hot(label, NUM_CLASS)
...@@ -82,8 +84,9 @@ def main(): ...@@ -82,8 +84,9 @@ def main():
ext.init() ext.init()
summary_op = tf.merge_all_summaries() summary_op = tf.merge_all_summaries()
config = tf.ConfigProto()
sess = tf.Session() config.device_count['GPU'] = 1
sess = tf.Session(config=config)
sess.run(tf.initialize_all_variables()) sess.run(tf.initialize_all_variables())
summary_writer = tf.train.SummaryWriter(LOG_DIR, graph_def=sess.graph_def) summary_writer = tf.train.SummaryWriter(LOG_DIR, graph_def=sess.graph_def)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment