Commit dea224ac authored by ppwwyyxx's avatar ppwwyyxx

sat

parent 53571a78
......@@ -60,7 +60,7 @@ def get_model(input, label):
y = one_hot(label, NUM_CLASS)
cost = tf.nn.softmax_cross_entropy_with_logits(fc1, y)
cost = tf.reduce_sum(cost, name='cost')
cost = tf.reduce_mean(cost, name='cost')
tf.scalar_summary(cost.op.name, cost)
return prob, cost
......@@ -97,14 +97,16 @@ def main():
keep_prob = G.get_tensor_by_name('dropout_prob:0')
with sess.as_default():
for epoch in count(1):
running_cost = StatCounter()
for (img, label) in dataset_train.get_data():
feed = {input_var: img,
label_var: label,
keep_prob: 0.5}
_, cost_value = sess.run([train_op, cost], feed_dict=feed)
running_cost.feed(cost_value)
print('Epoch %d: last batch cost = %.2f' % (epoch, cost_value))
print('Epoch %d: avg cost = %.2f' % (epoch, running_cost.average))
summary_str = summary_op.eval(feed_dict=feed)
summary_writer.add_summary(summary_str, epoch)
......
......@@ -65,7 +65,7 @@ class OnehotClassificationValidation(PeriodicExtension):
def _trigger(self):
cnt = 0
cnt_correct = 0
correct_stat = Accuracy()
sess = tf.get_default_session()
cost_sum = 0
for (img, label) in self.ds.get_data():
......@@ -75,12 +75,12 @@ class OnehotClassificationValidation(PeriodicExtension):
cnt += img.shape[0]
correct, cost = sess.run([self.nr_correct_var, self.cost_var],
feed_dict=feed)
cnt_correct += correct
cost_sum += cost
correct_stat.feed(correct, cnt)
cost_sum += cost * cnt
cost_sum /= cnt
# TODO write to summary?
print "After epoch {}: acc={}, cost={}".format(
self.epoch_num, cnt_correct * 1.0 / cnt, cost_sum)
self.epoch_num, correct_stat.accuracy, cost_sum)
class PeriodicSaver(PeriodicExtension):
......
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: stat.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import numpy as np
class StatCounter(object):
def __init__(self):
self.values = []
def feed(self, v):
self.values.append(v)
@property
def average(self):
return np.mean(self.values)
@property
def sum(self):
return np.sum(self.values)
class Accuracy(object):
def __init__(self):
self.tot = 0
self.corr = 0
def feed(self, corr, tot=1):
self.tot += tot
self.corr += corr
@property
def accuracy(self):
if self.tot < 0.001:
return 0
return self.corr * 1.0 / self.tot
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment