Shashank Suhas / seminar-breakout / Commits / 28599036

Commit 28599036, authored Dec 27, 2015 by ppwwyyxx
Parent: 5ec865d8

    summary activation, layer with nl

Showing 12 changed files with 144 additions and 65 deletions (+144 -65).
  dataflow/batch.py                                +8    -0
  dataflow/dataset/mnist.py                        +4    -0
  example_mnist.py                                 +25   -21
  layers/_common.py                                +31   -3
  layers/conv2d.py                                 +7    -16
  layers/fc.py                                     +6    -7
  layers/pool.py (new file)                        +26   -0
  train.py                                         +5    -1
  utils/callback.py                                +3    -1
  utils/naming.py                                  +3    -1
  utils/summary.py (renamed from utils/utils.py)   +12   -1
  utils/validation_callback.py                     +14   -14
dataflow/batch.py

@@ -18,6 +18,14 @@ class BatchData(object):
         self.batch_size = batch_size
         self.remainder = remainder
 
+    def size(self):
+        ds_size = self.ds.size()
+        div = ds_size / self.batch_size
+        rem = ds_size % self.batch_size
+        if rem == 0:
+            return div
+        return div + int(self.remainder)
+
     def get_data(self):
         holder = []
         for data in self.ds.get_data():
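
The new size() lets consumers (such as the validation callback) know how many batches one pass yields; with remainder=True the final partial batch counts as one more. A quick sketch of the arithmetic, with hypothetical numbers and the same Python 2 integer division used in the code above:

    ds_size, batch_size = 100, 32      # hypothetical: 100 samples, batches of 32
    div = ds_size / batch_size         # 3 full batches (integer division in Python 2)
    rem = ds_size % batch_size         # 4 leftover samples
    print div + int(True)              # remainder=True  -> 4 batches
    print div + int(False)             # remainder=False -> 3 batches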
dataflow/dataset/mnist.py

@@ -128,6 +128,10 @@ class Mnist(object):
         self.dataset = read_data_sets(dir)
         self.train_or_test = train_or_test
 
+    def size(self):
+        ds = self.dataset.train if self.train_or_test == 'train' else self.dataset.test
+        return ds.num_examples
+
     def get_data(self):
         ds = self.dataset.train if self.train_or_test == 'train' else self.dataset.test
         for k in xrange(ds.num_examples):
example_mnist.py

@@ -40,21 +40,16 @@ def get_model(inputs):
     image = tf.reshape(image, [-1, IMAGE_SIZE, IMAGE_SIZE, 1])
 
     conv0 = Conv2D('conv0', image, out_channel=32, kernel_shape=5,
                    padding='valid')
-    conv0 = tf.nn.relu(conv0)
-    pool0 = tf.nn.max_pool(conv0, ksize=[1, 2, 2, 1],
-                           strides=[1, 2, 2, 1], padding='SAME')
-    conv1 = Conv2D('conv1', pool0, out_channel=40, kernel_shape=3, padding='valid')
-    conv1 = tf.nn.relu(conv1)
-    pool1 = tf.nn.max_pool(conv1, ksize=[1, 2, 2, 1],
-                           strides=[1, 2, 2, 1], padding='SAME')
-    feature = batch_flatten(pool1)
-    fc0 = FullyConnected('fc0', feature, 1024)
-    fc0 = tf.nn.relu(fc0)
+    pool0 = MaxPooling('pool0', conv0, 2)
+    conv1 = Conv2D('conv1', pool0, out_channel=40, kernel_shape=3)
+    pool1 = MaxPooling('pool1', conv1, 2)
+    fc0 = FullyConnected('fc0', pool1, 1024)
     fc0 = tf.nn.dropout(fc0, keep_prob)
 
-    fc1 = FullyConnected('lr', fc0, out_dim=10)
+    # fc will have activation summary by default. disable this for the output layer
+    fc1 = FullyConnected('fc1', fc0, out_dim=10,
+                         summary_activation=False, nl=tf.identity)
     prob = tf.nn.softmax(fc1, name='output')
 
     y = one_hot(label, 10)

@@ -62,16 +57,16 @@ def get_model(inputs):
     cost = tf.reduce_mean(cost, name='cross_entropy_loss')
     tf.add_to_collection(COST_VARS_KEY, cost)
 
-    # compute the number of correctly classified samples, for ValidationAccuracy to use at test time
-    correct = tf.equal(
+    # compute the number of failed samples, for ValidationError to use at test time
+    wrong = tf.not_equal(
         tf.cast(tf.argmax(prob, 1), tf.int32), label)
-    correct = tf.cast(correct, tf.float32)
-    nr_correct = tf.reduce_sum(correct, name='correct')
+    wrong = tf.cast(wrong, tf.float32)
+    nr_wrong = tf.reduce_sum(wrong, name='wrong')
 
     # monitor training accuracy
     tf.add_to_collection(
         SUMMARY_VARS_KEY,
-        1 - tf.reduce_mean(correct, name='train_error'))
+        tf.sub(1.0, tf.reduce_mean(wrong), name='train_error'))
 
     # weight decay on all W of fc layers
     wd_cost = tf.mul(1e-4,

@@ -79,7 +74,7 @@ def get_model(inputs):
                      name='regularize_loss')
     tf.add_to_collection(COST_VARS_KEY, wd_cost)
 
-    return [prob, nr_correct], tf.add_n(tf.get_collection(COST_VARS_KEY), name='cost')
+    return [prob, nr_wrong], tf.add_n(tf.get_collection(COST_VARS_KEY), name='cost')
 
 def main(argv=None):
     BATCH_SIZE = 128

@@ -97,11 +92,20 @@ def main(argv=None):
     output_vars, cost_var = get_model(input_vars)
     add_histogram_summary('.*/W')   # monitor histogram of all W
 
+    global_step_var = tf.Variable(0, trainable=False, name=GLOBAL_STEP_OP_NAME)
+    lr = tf.train.exponential_decay(learning_rate=1e-4,
+                                    global_step=global_step_var,
+                                    decay_steps=dataset_train.size() * 50,
+                                    decay_rate=0.1,
+                                    staircase=True, name='learning_rate')
+    tf.scalar_summary('learning_rate', lr)
+
     config = dict(
         dataset_train=dataset_train,
-        optimizer=tf.train.AdamOptimizer(1e-4),
+        optimizer=tf.train.AdamOptimizer(lr),
         callbacks=[
-            ValidationAccuracy(
+            ValidationError(
                 dataset_test,
                 prefix='test'),
             PeriodicSaver(LOG_DIR, period=1),
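
The staircase schedule introduced in main() keeps the learning rate constant within each decay interval and divides it by 10 every decay_steps global steps; since dataset_train.size() is the number of batches per epoch, that is once every 50 epochs. A sketch of the curve it produces (step counts are illustrative, assuming MNIST's 60000 training samples and BATCH_SIZE = 128, i.e. about 468 batches per epoch):

    def lr_at(step, base=1e-4, decay_rate=0.1, decay_steps=468 * 50):
        # staircase=True floors the exponent, producing discrete drops
        return base * decay_rate ** (step // decay_steps)

    print lr_at(0)           # 1e-4 for roughly the first 50 epochs
    print lr_at(468 * 60)    # 1e-5 after the first drop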
layers/_common.py

@@ -4,15 +4,43 @@
 # Author: Yuxin Wu <ppwwyyxx@gmail.com>
 
 import tensorflow as tf
+from utils.summary import *
 
-def layer_register():
+def layer_register(summary_activation=False):
+    """
+    summary_activation: default behavior of whether to summary the output of this layer
+    """
     def wrapper(func):
         def inner(*args, **kwargs):
             name = args[0]
             assert isinstance(name, basestring)
             args = args[1:]
-            with tf.variable_scope(name):
-                return func(*args, **kwargs)
+            do_summary = kwargs.pop('summary_activation', summary_activation)
+            with tf.variable_scope(name) as scope:
+                ret = func(*args, **kwargs)
+                if do_summary:
+                    ndim = ret.get_shape().ndims
+                    assert ndim >= 2, \
+                        "Summary a scalar with histogram? Maybe use scalar instead. FIXME!"
+                    add_activation_summary(ret, scope.name)
+                return ret
         return inner
     return wrapper
+
+def shape2d(a):
+    """
+    a: a int or tuple/list of length 2
+    """
+    if type(a) == int:
+        return [a, a]
+    if type(a) in [list, tuple]:
+        assert len(a) == 2
+        return list(a)
+    raise RuntimeError("Illegal shape: {}".format(a))
+
+def shape4d(a):
+    # for use with tensorflow
+    return [1] + shape2d(a) + [1]
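
layer_register now carries a per-layer default for activation summaries, and every registered layer silently accepts a summary_activation keyword to override that default call by call; the first positional argument remains the layer name, which becomes both the variable scope and the summary tag. A minimal sketch of the contract (MyLayer and some_input are hypothetical):

    @layer_register(summary_activation=True)    # summarize outputs by default
    def MyLayer(x):
        return tf.nn.relu(x)

    out1 = MyLayer('mylayer', some_input)                          # histogram/sparsity under 'mylayer'
    out2 = MyLayer('plain', some_input, summary_activation=False)  # opts out for this call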
layers/conv2d.py

@@ -5,14 +5,14 @@
 import tensorflow as tf
 import math
-from ._common import layer_register
+from ._common import *
 
 __all__ = ['Conv2D']
 
-@layer_register()
+@layer_register(summary_activation=True)
 def Conv2D(x, out_channel, kernel_shape,
-           padding='VALID', stride=None,
-           W_init=None, b_init=None):
+           padding='VALID', stride=1,
+           W_init=None, b_init=None, nl=tf.nn.relu):
     """
     kernel_shape: (h, w) or a int
     stride: (h, w) or a int

@@ -21,19 +21,10 @@ def Conv2D(x, out_channel, kernel_shape,
     in_shape = x.get_shape().as_list()
     in_channel = in_shape[-1]
-    if type(kernel_shape) == int:
-        kernel_shape = [kernel_shape, kernel_shape]
+    kernel_shape = shape2d(kernel_shape)
     padding = padding.upper()
     filter_shape = kernel_shape + [in_channel, out_channel]
-    if stride is None:
-        stride = [1, 1, 1, 1]
-    elif type(stride) == int:
-        stride = [1, stride, stride, 1]
-    elif type(stride) in [list, tuple]:
-        assert len(stride) == 2
-        stride = [1] + list(stride) + [1]
+    stride = shape4d(stride)
 
     if W_init is None:
         W_init = tf.truncated_normal_initializer(stddev=0.04)

@@ -44,5 +35,5 @@ def Conv2D(x, out_channel, kernel_shape,
     b = tf.get_variable('b', [out_channel], initializer=b_init)
 
     conv = tf.nn.conv2d(x, W, stride, padding)
-    return tf.nn.bias_add(conv, b)
+    return nl(tf.nn.bias_add(conv, b))
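
The hand-rolled stride normalization is replaced by the shape2d/shape4d helpers from _common.py, which turn an int or a length-2 tuple into the NHWC-shaped lists that tf.nn.conv2d and tf.nn.max_pool expect:

    shape2d(3)         # -> [3, 3]
    shape2d((3, 5))    # -> [3, 5]
    shape4d(2)         # -> [1, 2, 2, 1]  (batch and channel dims fixed at 1)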
layers/fc.py

@@ -6,15 +6,14 @@
 from ._common import layer_register
 import tensorflow as tf
+from utils.symbolic_functions import *
 import math
 
 __all__ = ['FullyConnected']
 
-@layer_register()
-def FullyConnected(x, out_dim, W_init=None, b_init=None):
-    """
-    x: matrix of bxn
-    """
+@layer_register(summary_activation=True)
+def FullyConnected(x, out_dim, W_init=None, b_init=None, nl=tf.nn.relu):
+    x = batch_flatten(x)
     in_dim = x.get_shape().as_list()[1]
 
     if W_init is None:

@@ -22,6 +21,6 @@ def FullyConnected(x, out_dim, W_init=None, b_init=None):
     if b_init is None:
         b_init = tf.constant_initializer()
 
+    # TODO collections
     W = tf.get_variable('W', [in_dim, out_dim], initializer=W_init)
     b = tf.get_variable('b', [out_dim], initializer=b_init)
-    return tf.matmul(x, W) + b
+    return nl(tf.matmul(x, W) + b)
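
With nl defaulting to tf.nn.relu, the separate tf.nn.relu calls in example_mnist.py go away; a layer that must stay linear, such as the logits feeding softmax, opts out by passing an identity nonlinearity, exactly as the example now does:

    # logits before softmax: no nonlinearity, no activation summary
    fc1 = FullyConnected('fc1', fc0, out_dim=10,
                         summary_activation=False, nl=tf.identity)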
layers/pool.py (new file, 0 → 100644)

@@ -0,0 +1,26 @@
+#!/usr/bin/env python2
+# -*- coding: UTF-8 -*-
+# File: pool.py
+# Author: Yuxin Wu <ppwwyyxx@gmail.com>
+
+from ._common import *
+import tensorflow as tf
+
+__all__ = ['MaxPooling']
+
+@layer_register()
+def MaxPooling(x, shape, stride=None, padding='VALID'):
+    """
+    shape, stride: int or list/tuple of length 2
+    if stride is None, use shape by default
+    padding: 'VALID' or 'SAME'
+    """
+    padding = padding.upper()
+    shape = shape4d(shape)
+    if stride is None:
+        stride = shape
+    else:
+        stride = shape4d(stride)
+
+    return tf.nn.max_pool(x, ksize=shape, strides=stride, padding=padding)
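
Because stride defaults to the window shape, non-overlapping pooling needs only the window size: the MNIST model's MaxPooling('pool0', conv0, 2) stands in for the former tf.nn.max_pool(conv0, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], ...) lines (modulo the padding default, now 'VALID'). For instance (the second call is hypothetical):

    pool0 = MaxPooling('pool0', conv0, 2)                  # 2x2 window, stride 2, VALID padding
    pool1 = MaxPooling('pool1', conv1, (3, 3), stride=2)   # 3x3 window, stride 2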
train.py

@@ -40,7 +40,11 @@ def start_train(config):
     for v in output_vars:
         G.add_to_collection(OUTPUT_VARS_KEY, v)
 
-    global_step_var = tf.Variable(0, trainable=False, name='global_step')
+    try:
+        global_step_var = G.get_tensor_by_name(GLOBAL_STEP_VAR_NAME)
+    except KeyError:   # not created
+        global_step_var = tf.Variable(0, trainable=False, name=GLOBAL_STEP_OP_NAME)
 
     # add some summary ops to the graph
     averager = tf.train.ExponentialMovingAverage(
utils/callback.py

@@ -63,7 +63,7 @@ class PeriodicSaver(PeriodicCallback):
                         global_step=self.epoch_num, latest_filename='latest')
 
 class SummaryWriter(Callback):
-    def __init__(self, log_dir, histogram_regex=None):
+    def __init__(self, log_dir):
         self.log_dir = log_dir
         self.epoch_num = 0

@@ -100,6 +100,7 @@ class Callbacks(Callback):
     def before_train(self):
         for cb in self.callbacks:
             cb.before_train()
+        self.writer = tf.get_collection(SUMMARY_WRITER_COLLECTION_KEY)[0]
 
     def trigger_step(self, inputs, outputs, cost):
         for cb in self.callbacks:

@@ -108,4 +109,5 @@ class Callbacks(Callback):
     def trigger_epoch(self):
         for cb in self.callbacks:
             cb.trigger_epoch()
+        self.writer.flush()
utils/naming.py

@@ -6,8 +6,10 @@
 DROPOUT_PROB_OP_NAME = 'dropout_prob'
 DROPOUT_PROB_VAR_NAME = 'dropout_prob:0'
+GLOBAL_STEP_OP_NAME = 'global_step'
+GLOBAL_STEP_VAR_NAME = 'global_step:0'
 
 SUMMARY_WRITER_COLLECTION_KEY = 'summary_writer'
+MERGE_SUMMARY_OP_NAME = 'MergeSummary/MergeSummary:0'
 
 INPUT_VARS_KEY = 'INPUT_VARIABLES'
 OUTPUT_VARS_KEY = 'OUTPUT_VARIABLES'
...
utils/
utils
.py
→
utils/
summary
.py
View file @
28599036
...
@@ -5,7 +5,7 @@
...
@@ -5,7 +5,7 @@
import
tensorflow
as
tf
import
tensorflow
as
tf
__all__
=
[
'create_summary'
,
'add_histogram_summary'
]
__all__
=
[
'create_summary'
,
'add_histogram_summary'
,
'add_activation_summary'
]
def
create_summary
(
name
,
v
):
def
create_summary
(
name
,
v
):
"""
"""
...
@@ -19,6 +19,17 @@ def create_summary(name, v):
...
@@ -19,6 +19,17 @@ def create_summary(name, v):
s
.
value
.
add
(
tag
=
name
,
simple_value
=
v
)
s
.
value
.
add
(
tag
=
name
,
simple_value
=
v
)
return
s
return
s
def
add_activation_summary
(
x
,
name
=
None
):
"""
Summary for an activation tensor x.
If name is None, use x.name
"""
if
name
is
None
:
name
=
x
.
name
tf
.
histogram_summary
(
name
+
'/activations'
,
x
)
tf
.
scalar_summary
(
name
+
'/sparsity'
,
tf
.
nn
.
zero_fraction
(
x
))
# TODO avoid repeating activations on multiple GPUs
def
add_histogram_summary
(
regex
):
def
add_histogram_summary
(
regex
):
"""
"""
Add histogram summary for all trainable variables matching the regex
Add histogram summary for all trainable variables matching the regex
...
...
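
Every layer registered with summaries on ends up calling add_activation_summary(ret, scope.name) through the wrapper in layers/_common.py, so each such layer contributes two TensorBoard series. A sketch of the resulting tags for a layer scoped 'conv0' (conv0_output is hypothetical):

    add_activation_summary(conv0_output, 'conv0')
    # -> histogram summary 'conv0/activations'
    # -> scalar summary    'conv0/sparsity'  (fraction of zeros, i.e. dead ReLUs)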
utils/validation_callback.py

@@ -7,24 +7,24 @@ import tensorflow as tf
 from .stat import *
 from .callback import PeriodicCallback, Callback
 from .naming import *
-from .utils import *
+from .summary import *
 
-class ValidationAccuracy(PeriodicCallback):
+class ValidationError(PeriodicCallback):
     """
-    Validate the accuracy for the given correct and cost variable
+    Validate the accuracy for the given wrong and cost variable
     Use under the following setup:
-        correct_var: integer, number of correct samples in this batch
+        wrong_var: integer, number of failed samples in this batch
         ds: batched dataset
     """
     def __init__(self, ds, prefix,
                  period=1,
-                 correct_var_name='correct:0',
+                 wrong_var_name='wrong:0',
                  cost_var_name='cost:0'):
-        super(ValidationAccuracy, self).__init__(period)
+        super(ValidationError, self).__init__(period)
         self.ds = ds
         self.prefix = prefix
-        self.correct_var_name = correct_var_name
+        self.wrong_var_name = wrong_var_name
         self.cost_var_name = cost_var_name
 
     def get_tensor(self, name):

@@ -33,13 +33,13 @@ class ValidationAccuracy(PeriodicCallback):
     def _before_train(self):
         self.input_vars = tf.get_collection(INPUT_VARS_KEY)
         self.dropout_var = self.get_tensor(DROPOUT_PROB_VAR_NAME)
-        self.correct_var = self.get_tensor(self.correct_var_name)
+        self.wrong_var = self.get_tensor(self.wrong_var_name)
         self.cost_var = self.get_tensor(self.cost_var_name)
         self.writer = tf.get_collection(SUMMARY_WRITER_COLLECTION_KEY)[0]
 
     def _trigger(self):
         cnt = 0
-        correct_stat = Accuracy()
+        err_stat = Accuracy()
         cost_sum = 0
         for dp in self.ds.get_data():
             feed = {self.dropout_var: 1.0}

@@ -48,20 +48,20 @@ class ValidationAccuracy(PeriodicCallback):
             batch_size = dp[0].shape[0]   # assume batched input
             cnt += batch_size
-            correct, cost = self.sess.run(
-                [self.correct_var, self.cost_var], feed_dict=feed)
-            correct_stat.feed(correct, batch_size)
+            wrong, cost = self.sess.run(
+                [self.wrong_var, self.cost_var], feed_dict=feed)
+            err_stat.feed(wrong, batch_size)
             # each batch might not have the same size in validation
             cost_sum += cost * batch_size
 
         cost_avg = cost_sum / cnt
         self.writer.add_summary(
             create_summary('{}_error'.format(self.prefix),
-                           1 - correct_stat.accuracy),
+                           err_stat.accuracy),
             self.epoch_num)
         self.writer.add_summary(
             create_summary('{}_cost'.format(self.prefix),
                            cost_avg),
             self.epoch_num)
         print "{} validation after epoch {}: err={}, cost={}".format(
-            self.prefix, self.epoch_num, 1 - correct_stat.accuracy, cost_avg)
+            self.prefix, self.epoch_num, err_stat.accuracy, cost_avg)
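
Note that the Accuracy statistic from .stat is reused unchanged: fed with wrong-counts instead of correct-counts, its .accuracy ratio now reads as an error rate, which is why the 1 - ... inversions above could be dropped. A sketch, assuming Accuracy.feed(count, total) accumulates count/total as its use here suggests:

    err_stat = Accuracy()
    err_stat.feed(5, 100)      # 5 misclassified in a 100-sample batch
    err_stat.feed(15, 100)     # 15 misclassified in the next one
    print err_stat.accuracy    # 0.1, logged as '{prefix}_error'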