Commit 53571a78 authored by ppwwyyxx's avatar ppwwyyxx

update

parent 8c57fc1f
...@@ -35,7 +35,7 @@ def get_model(input, label): ...@@ -35,7 +35,7 @@ def get_model(input, label):
cost: scalar variable cost: scalar variable
""" """
# use this dropout variable! it will be set to 1 at test time # use this dropout variable! it will be set to 1 at test time
keep_prob = tf.placeholder(tf.float32, name='dropout_prob') keep_prob = tf.placeholder(tf.float32, shape=tuple(), name='dropout_prob')
input = tf.reshape(input, [-1, IMAGE_SIZE, IMAGE_SIZE, 1]) input = tf.reshape(input, [-1, IMAGE_SIZE, IMAGE_SIZE, 1])
conv0 = Conv2D('conv0', input, out_channel=32, kernel_shape=5, conv0 = Conv2D('conv0', input, out_channel=32, kernel_shape=5,
...@@ -60,19 +60,17 @@ def get_model(input, label): ...@@ -60,19 +60,17 @@ def get_model(input, label):
y = one_hot(label, NUM_CLASS) y = one_hot(label, NUM_CLASS)
cost = tf.nn.softmax_cross_entropy_with_logits(fc1, y) cost = tf.nn.softmax_cross_entropy_with_logits(fc1, y)
#logprob = logSoftmax(fc1)
#cost = tf.reduce_sum(-y * logprob, 1)
cost = tf.reduce_sum(cost, name='cost') cost = tf.reduce_sum(cost, name='cost')
tf.scalar_summary(cost.op.name, cost) tf.scalar_summary(cost.op.name, cost)
return prob, cost return prob, cost
def main(): def main():
dataset_train = Mnist('train') dataset_train = BatchData(Mnist('train'), batch_size)
dataset_test = Mnist('test') dataset_test = BatchData(Mnist('test'), batch_size, remainder=True)
extensions = [ extensions = [
OnehotClassificationValidation( OnehotClassificationValidation(
BatchData(dataset_test, batch_size, remainder=True), dataset_test,
prefix='test', period=2), prefix='test', period=2),
PeriodicSaver(LOG_DIR, period=2) PeriodicSaver(LOG_DIR, period=2)
] ]
...@@ -99,7 +97,7 @@ def main(): ...@@ -99,7 +97,7 @@ def main():
keep_prob = G.get_tensor_by_name('dropout_prob:0') keep_prob = G.get_tensor_by_name('dropout_prob:0')
with sess.as_default(): with sess.as_default():
for epoch in count(1): for epoch in count(1):
for (img, label) in BatchData(dataset_train, batch_size).get_data(): for (img, label) in dataset_train.get_data():
feed = {input_var: img, feed = {input_var: img,
label_var: label, label_var: label,
keep_prob: 0.5} keep_prob: 0.5}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment