Commit 53571a78 authored by ppwwyyxx's avatar ppwwyyxx

update

parent 8c57fc1f
......@@ -35,7 +35,7 @@ def get_model(input, label):
cost: scalar variable
"""
# use this dropout variable! it will be set to 1 at test time
keep_prob = tf.placeholder(tf.float32, name='dropout_prob')
keep_prob = tf.placeholder(tf.float32, shape=tuple(), name='dropout_prob')
input = tf.reshape(input, [-1, IMAGE_SIZE, IMAGE_SIZE, 1])
conv0 = Conv2D('conv0', input, out_channel=32, kernel_shape=5,
......@@ -60,19 +60,17 @@ def get_model(input, label):
y = one_hot(label, NUM_CLASS)
cost = tf.nn.softmax_cross_entropy_with_logits(fc1, y)
#logprob = logSoftmax(fc1)
#cost = tf.reduce_sum(-y * logprob, 1)
cost = tf.reduce_sum(cost, name='cost')
tf.scalar_summary(cost.op.name, cost)
return prob, cost
def main():
dataset_train = Mnist('train')
dataset_test = Mnist('test')
dataset_train = BatchData(Mnist('train'), batch_size)
dataset_test = BatchData(Mnist('test'), batch_size, remainder=True)
extensions = [
OnehotClassificationValidation(
BatchData(dataset_test, batch_size, remainder=True),
dataset_test,
prefix='test', period=2),
PeriodicSaver(LOG_DIR, period=2)
]
......@@ -99,7 +97,7 @@ def main():
keep_prob = G.get_tensor_by_name('dropout_prob:0')
with sess.as_default():
for epoch in count(1):
for (img, label) in BatchData(dataset_train, batch_size).get_data():
for (img, label) in dataset_train.get_data():
feed = {input_var: img,
label_var: label,
keep_prob: 0.5}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment