Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
dea224ac
Commit
dea224ac
authored
Dec 26, 2015
by
ppwwyyxx
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
sat
parent
53571a78
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
45 additions
and
6 deletions
+45
-6
example_mnist.py
example_mnist.py
+4
-2
utils/extension.py
utils/extension.py
+4
-4
utils/stat.py
utils/stat.py
+37
-0
No files found.
example_mnist.py
View file @
dea224ac
...
...
@@ -60,7 +60,7 @@ def get_model(input, label):
y
=
one_hot
(
label
,
NUM_CLASS
)
cost
=
tf
.
nn
.
softmax_cross_entropy_with_logits
(
fc1
,
y
)
cost
=
tf
.
reduce_
sum
(
cost
,
name
=
'cost'
)
cost
=
tf
.
reduce_
mean
(
cost
,
name
=
'cost'
)
tf
.
scalar_summary
(
cost
.
op
.
name
,
cost
)
return
prob
,
cost
...
...
@@ -97,14 +97,16 @@ def main():
keep_prob
=
G
.
get_tensor_by_name
(
'dropout_prob:0'
)
with
sess
.
as_default
():
for
epoch
in
count
(
1
):
running_cost
=
StatCounter
()
for
(
img
,
label
)
in
dataset_train
.
get_data
():
feed
=
{
input_var
:
img
,
label_var
:
label
,
keep_prob
:
0.5
}
_
,
cost_value
=
sess
.
run
([
train_op
,
cost
],
feed_dict
=
feed
)
running_cost
.
feed
(
cost_value
)
print
(
'Epoch
%
d:
last batch cost =
%.2
f'
%
(
epoch
,
cost_valu
e
))
print
(
'Epoch
%
d:
avg cost =
%.2
f'
%
(
epoch
,
running_cost
.
averag
e
))
summary_str
=
summary_op
.
eval
(
feed_dict
=
feed
)
summary_writer
.
add_summary
(
summary_str
,
epoch
)
...
...
utils/extension.py
View file @
dea224ac
...
...
@@ -65,7 +65,7 @@ class OnehotClassificationValidation(PeriodicExtension):
def
_trigger
(
self
):
cnt
=
0
c
nt_correct
=
0
c
orrect_stat
=
Accuracy
()
sess
=
tf
.
get_default_session
()
cost_sum
=
0
for
(
img
,
label
)
in
self
.
ds
.
get_data
():
...
...
@@ -75,12 +75,12 @@ class OnehotClassificationValidation(PeriodicExtension):
cnt
+=
img
.
shape
[
0
]
correct
,
cost
=
sess
.
run
([
self
.
nr_correct_var
,
self
.
cost_var
],
feed_dict
=
feed
)
c
nt_correct
+=
correct
cost_sum
+=
cost
c
orrect_stat
.
feed
(
correct
,
cnt
)
cost_sum
+=
cost
*
cnt
cost_sum
/=
cnt
# TODO write to summary?
print
"After epoch {}: acc={}, cost={}"
.
format
(
self
.
epoch_num
,
c
nt_correct
*
1.0
/
cnt
,
cost_sum
)
self
.
epoch_num
,
c
orrect_stat
.
accuracy
,
cost_sum
)
class
PeriodicSaver
(
PeriodicExtension
):
...
...
utils/stat.py
0 → 100644
View file @
dea224ac
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: stat.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import
numpy
as
np
class
StatCounter
(
object
):
def
__init__
(
self
):
self
.
values
=
[]
def
feed
(
self
,
v
):
self
.
values
.
append
(
v
)
@
property
def
average
(
self
):
return
np
.
mean
(
self
.
values
)
@
property
def
sum
(
self
):
return
np
.
sum
(
self
.
values
)
class
Accuracy
(
object
):
def
__init__
(
self
):
self
.
tot
=
0
self
.
corr
=
0
def
feed
(
self
,
corr
,
tot
=
1
):
self
.
tot
+=
tot
self
.
corr
+=
corr
@
property
def
accuracy
(
self
):
if
self
.
tot
<
0.001
:
return
0
return
self
.
corr
*
1.0
/
self
.
tot
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment