Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
12f2866e
Commit
12f2866e
authored
Dec 25, 2015
by
ppwwyyxx
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
extension
parent
db475954
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
74 additions
and
33 deletions
+74
-33
dataflow/batch.py
dataflow/batch.py
+9
-1
example_mnist.py
example_mnist.py
+13
-24
utils/extension.py
utils/extension.py
+52
-8
No files found.
dataflow/batch.py
View file @
12f2866e
...
...
@@ -8,9 +8,15 @@ import numpy as np
__all__
=
[
'BatchData'
]
class
BatchData
(
object
):
def
__init__
(
self
,
ds
,
batch_size
):
def
__init__
(
self
,
ds
,
batch_size
,
remainder
=
False
):
"""
Args:
ds: a dataflow
remainder: whether to return the remaining data smaller than a batch_size
"""
self
.
ds
=
ds
self
.
batch_size
=
batch_size
self
.
remainder
=
remainder
def
get_data
(
self
):
holder
=
[]
...
...
@@ -19,6 +25,8 @@ class BatchData(object):
if
len
(
holder
)
==
self
.
batch_size
:
yield
BatchData
.
aggregate_batch
(
holder
)
holder
=
[]
if
self
.
remainder
and
len
(
holder
)
>
0
:
yield
BatchData
.
aggregate_batch
(
holder
)
@
staticmethod
def
aggregate_batch
(
data_holder
):
...
...
example_mnist.py
View file @
12f2866e
...
...
@@ -44,22 +44,16 @@ def get_model(input, label):
tf
.
scalar_summary
(
cost
.
op
.
name
,
cost
)
return
prob
,
cost
#def get_eval(prob, labels):
#"""
#Args:
#prob: bx10
#labels: b
#Returns:
#scalar float: accuracy
#"""
#correct = tf.nn.in_top_k(prob, labels, 1)
#nr_correct = tf.reduce_sum(tf.cast(correct, tf.int32))
#return tf.cast(nr_correct, tf.float32) / tf.cast(tf.size(labels), tf.float32)
def
main
():
dataset_train
=
Mnist
(
'train'
)
dataset_test
=
Mnist
(
'test'
)
extensions
=
[
OnehotClassificationValidation
(
BatchData
(
dataset_test
,
batch_size
,
remainder
=
True
),
prefix
=
'test'
,
period
=
2
),
PeriodicSaver
(
LOG_DIR
,
period
=
2
)
]
with
tf
.
Graph
()
.
as_default
():
input_var
=
tf
.
placeholder
(
tf
.
float32
,
shape
=
(
None
,
PIXELS
),
name
=
'input'
)
label_var
=
tf
.
placeholder
(
tf
.
int32
,
shape
=
(
None
,),
name
=
'label'
)
...
...
@@ -69,17 +63,14 @@ def main():
optimizer
=
tf
.
train
.
AdagradOptimizer
(
0.01
)
train_op
=
optimizer
.
minimize
(
cost
)
validation_ext
=
OnehotClassificationValidation
(
BatchData
(
dataset_test
,
batch_size
),
'test'
)
validation_ext
.
init
()
for
ext
in
extensions
:
ext
.
init
()
summary_op
=
tf
.
merge_all_summaries
()
saver
=
tf
.
train
.
Saver
()
sess
=
tf
.
Session
()
sess
.
run
(
tf
.
initialize_all_variables
())
summary_writer
=
tf
.
train
.
SummaryWriter
(
LOG_DIR
,
graph_def
=
sess
.
graph_def
)
summary_writer
=
tf
.
train
.
SummaryWriter
(
LOG_DIR
,
graph_def
=
sess
.
graph_def
)
with
sess
.
as_default
():
for
epoch
in
count
(
1
):
...
...
@@ -90,13 +81,11 @@ def main():
_
,
cost_value
=
sess
.
run
([
train_op
,
cost
],
feed_dict
=
feed
)
print
(
'Epoch
%
d: last batch cost =
%.2
f'
%
(
epoch
,
cost_value
))
summary_str
=
sess
.
run
(
summary_op
,
feed_dict
=
feed
)
summary_str
=
summary_op
.
eval
(
feed_dict
=
feed
)
summary_writer
.
add_summary
(
summary_str
,
epoch
)
if
epoch
%
2
==
0
:
saver
.
save
(
sess
,
LOG_DIR
,
global_step
=
epoch
)
validation_ext
.
trigger
()
for
ext
in
extensions
:
ext
.
trigger
()
...
...
utils/extension.py
View file @
12f2866e
...
...
@@ -4,17 +4,47 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import
tensorflow
as
tf
import
sys
import
numpy
as
np
import
os
from
abc
import
abstractmethod
class
OnehotClassificationValidation
(
object
):
class
Extension
(
object
):
def
init
(
self
):
pass
@
abstractmethod
def
trigger
(
self
):
pass
class
PeriodicExtension
(
Extension
):
def
__init__
(
self
,
period
):
self
.
__period
=
period
self
.
epoch_num
=
0
def
init
(
self
):
pass
def
trigger
(
self
):
self
.
epoch_num
+=
1
if
self
.
epoch_num
%
self
.
__period
==
0
:
self
.
_trigger
()
@
abstractmethod
def
_trigger
(
self
):
pass
class
OnehotClassificationValidation
(
PeriodicExtension
):
"""
use with output: bxn probability
and label: (b,) vector
"""
def
__init__
(
self
,
ds
,
prefix
,
period
=
1
,
input_op_name
=
'input'
,
label_op_name
=
'label'
,
output_op_name
=
'output'
):
super
(
OnehotClassificationValidation
,
self
)
.
__init__
(
period
)
self
.
ds
=
ds
self
.
input_op_name
=
input_op_name
self
.
output_op_name
=
output_op_name
...
...
@@ -30,15 +60,29 @@ class OnehotClassificationValidation(object):
correct
=
tf
.
equal
(
tf
.
cast
(
tf
.
argmax
(
self
.
output_var
,
1
),
tf
.
int32
),
self
.
label_var
)
# TODO: add cost
self
.
accuracy_var
=
tf
.
reduce_mean
(
tf
.
cast
(
correct
,
tf
.
floa
t32
))
self
.
nr_correct_var
=
tf
.
reduce_sum
(
tf
.
cast
(
correct
,
tf
.
in
t32
))
def
trigger
(
self
):
scores
=
[]
def
_trigger
(
self
):
cnt
=
0
cnt_correct
=
0
for
(
img
,
label
)
in
self
.
ds
.
get_data
():
# TODO dropout?
feed
=
{
self
.
input_var
:
img
,
self
.
label_var
:
label
}
scores
.
append
(
self
.
accuracy_var
.
eval
(
feed_dict
=
feed
))
acc
=
np
.
array
(
scores
,
dtype
=
'float32'
)
.
mean
()
cnt
+=
img
.
shape
[
0
]
cnt_correct
+=
self
.
nr_correct_var
.
eval
(
feed_dict
=
feed
)
# TODO write to summary?
print
"Accuracy: "
,
acc
print
"Accuracy at epoch {}: {}"
.
format
(
self
.
epoch_num
,
cnt_correct
*
1.0
/
cnt
)
class
PeriodicSaver
(
PeriodicExtension
):
def
__init__
(
self
,
log_dir
,
period
=
1
):
super
(
PeriodicSaver
,
self
)
.
__init__
(
period
)
self
.
path
=
os
.
path
.
join
(
log_dir
,
'model'
)
def
init
(
self
):
self
.
saver
=
tf
.
train
.
Saver
(
max_to_keep
=
99999
)
def
_trigger
(
self
):
self
.
saver
.
save
(
tf
.
get_default_session
(),
self
.
path
,
global_step
=
self
.
epoch_num
,
latest_filename
=
'latest'
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment