Shashank Suhas / seminar-breakout / Commits / 4267ea4d

Commit 4267ea4d, authored Dec 26, 2015 by ppwwyyxx

    add dropout

Parent: 5fab58f3
Showing 3 changed files with 36 additions and 20 deletions:
    example_mnist.py             +15 -8
    utils/extension.py           +13 -11
    utils/symbolic_functions.py   +8 -1
--- a/example_mnist.py
+++ b/example_mnist.py
@@ -34,26 +34,30 @@ def get_model(input, label):
         output: variable
         cost: scalar variable
     """
+    keep_prob = tf.placeholder(tf.float32, name='dropout_prob')
     input = tf.reshape(input, [-1, IMAGE_SIZE, IMAGE_SIZE, 1])
-    conv0 = Conv2D('conv0', input, out_channel=20, kernel_shape=5,
+    conv0 = Conv2D('conv0', input, out_channel=32, kernel_shape=5,
                    padding='valid')
     conv0 = tf.nn.relu(conv0)
     pool0 = tf.nn.max_pool(conv0, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1],
                            padding='SAME')
     conv1 = Conv2D('conv1', pool0, out_channel=40, kernel_shape=3,
                    padding='valid')
+    conv1 = tf.nn.relu(conv1)
+    pool1 = tf.nn.max_pool(conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1],
+                           padding='SAME')
-    conv2 = Conv2D('conv2', pool0, out_channel=40, kernel_shape=3,
-                   padding='valid')
-    feature = batch_flatten(conv2)
+    feature = batch_flatten(pool1)
-    fc0 = FullyConnected('fc0', feature, 512)
+    fc0 = FullyConnected('fc0', feature, 1024)
     fc0 = tf.nn.relu(fc0)
+    fc0 = tf.nn.dropout(fc0, keep_prob)
     fc1 = FullyConnected('lr', fc0, out_dim=10)
     prob = tf.nn.softmax(fc1, name='output')
-    logprob = tf.log(prob)
+    logprob = logSoftmax(fc1)
     y = one_hot(label, NUM_CLASS)
     cost = tf.reduce_sum(-y * logprob, 1)
     cost = tf.reduce_mean(cost, name='cost')
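The core of this commit is the pair of changes above: get_model now creates a dropout_prob placeholder, and tf.nn.dropout(fc0, keep_prob) is applied to the first fully-connected layer. A minimal sketch of that pattern, assuming the same TF-0.x-era API this repo is written against (the tensor shape below is illustrative, not from the diff):

    import tensorflow as tf

    # The keep probability is a placeholder, so one graph serves both
    # training (keep_prob < 1) and evaluation (keep_prob == 1).
    keep_prob = tf.placeholder(tf.float32, name='dropout_prob')

    fc = tf.placeholder(tf.float32, [None, 1024])  # stand-in for fc0's activations
    fc_dropped = tf.nn.dropout(fc, keep_prob)      # zeroes each unit with prob 1 - keep_prob
                                                   # and scales survivors by 1 / keep_prob

Because tf.nn.dropout rescales the surviving activations by 1/keep_prob, the expected value of the layer output is unchanged, and feeding 1.0 turns the op into an identity.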
@@ -77,7 +81,7 @@ def main():
     prob, cost = get_model(input_var, label_var)
-    optimizer = tf.train.AdagradOptimizer(0.01)
+    optimizer = tf.train.AdamOptimizer(1e-4)
     train_op = optimizer.minimize(cost)
     for ext in extensions:
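The optimizer swap is a one-line change because both classes implement the same tf.train.Optimizer interface, so the surrounding code is untouched. A sketch of the unchanged call pattern (cost is the scalar returned by get_model):

    optimizer = tf.train.AdamOptimizer(1e-4)   # was tf.train.AdagradOptimizer(0.01)
    train_op = optimizer.minimize(cost)        # identical usage for either optimizer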
@@ -90,11 +94,14 @@ def main():
     sess.run(tf.initialize_all_variables())
     summary_writer = tf.train.SummaryWriter(LOG_DIR, graph_def=sess.graph_def)
+    g = tf.get_default_graph()
+    keep_prob = g.get_tensor_by_name('dropout_prob:0')
     with sess.as_default():
         for epoch in count(1):
             for (img, label) in BatchData(dataset_train, batch_size).get_data():
-                feed = {input_var: img, label_var: label}
+                feed = {input_var: img, label_var: label,
+                        keep_prob: 0.5}
                 _, cost_value = sess.run([train_op, cost], feed_dict=feed)
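get_model creates the placeholder but does not return it, so the training loop recovers it from the default graph by tensor name; 'dropout_prob:0' refers to output 0 of the op named 'dropout_prob'. A sketch of that round trip, assuming the names used in this diff:

    g = tf.get_default_graph()
    keep_prob = g.get_tensor_by_name('dropout_prob:0')

    # Training feeds keep_prob 0.5; the validation extension below feeds 1.0.
    feed = {input_var: img, label_var: label, keep_prob: 0.5}
    _, cost_value = sess.run([train_op, cost], feed_dict=feed)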
--- a/utils/extension.py
+++ b/utils/extension.py
@@ -41,21 +41,22 @@ class OnehotClassificationValidation(PeriodicExtension):
     """
     def __init__(self, ds, prefix, period=1,
-                 input_op_name='input',
-                 label_op_name='label',
-                 output_op_name='output'):
+                 input_var_name='input:0',
+                 label_var_name='label:0',
+                 output_var_name='output:0'):
         super(OnehotClassificationValidation, self).__init__(period)
         self.ds = ds
-        self.input_op_name = input_op_name
-        self.output_op_name = output_op_name
-        self.label_op_name = label_op_name
+        self.input_var_name = input_var_name
+        self.output_var_name = output_var_name
+        self.label_var_name = label_var_name

     def init(self):
         self.graph = tf.get_default_graph()
         with tf.name_scope('validation'):
-            self.input_var = self.graph.get_operation_by_name(self.input_op_name).outputs[0]
-            self.label_var = self.graph.get_operation_by_name(self.label_op_name).outputs[0]
-            self.output_var = self.graph.get_operation_by_name(self.output_op_name).outputs[0]
+            self.input_var = self.graph.get_tensor_by_name(self.input_var_name)
+            self.label_var = self.graph.get_tensor_by_name(self.label_var_name)
+            self.output_var = self.graph.get_tensor_by_name(self.output_var_name)
+            self.dropout_var = self.graph.get_tensor_by_name('dropout_prob:0')
             correct = tf.equal(tf.cast(tf.argmax(self.output_var, 1), tf.int32),
                                self.label_var)
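The rename from *_op_name to *_var_name tracks a real API distinction: graph.get_operation_by_name('input') returns an Operation whose outputs must then be indexed, while graph.get_tensor_by_name('input:0') returns the tensor directly, the ':0' suffix selecting the op's first output. The two lookups are equivalent for a single-output op; a sketch with an illustrative name:

    g = tf.get_default_graph()
    a = g.get_operation_by_name('input').outputs[0]  # old style in this class
    b = g.get_tensor_by_name('input:0')              # new style
    assert a is b                                    # same Tensor object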
@@ -66,8 +67,9 @@ class OnehotClassificationValidation(PeriodicExtension):
         cnt = 0
         cnt_correct = 0
         for (img, label) in self.ds.get_data():
-            # TODO dropout?
-            feed = {self.input_var: img, self.label_var: label}
+            feed = {self.input_var: img,
+                    self.label_var: label,
+                    self.dropout_var: 1.0}
             cnt += img.shape[0]
             cnt_correct += self.nr_correct_var.eval(feed_dict=feed)
         # TODO write to summary?
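Feeding self.dropout_var: 1.0 here resolves the old "# TODO dropout?" comment: with a keep probability of 1.0 every unit is kept and the 1/keep_prob rescaling is a no-op, so validation evaluates the deterministic network. A sketch of that identity (values illustrative):

    x = tf.constant([1.0, 2.0, 3.0])
    y = tf.nn.dropout(x, 1.0)
    # In any session, y evaluates to [1.0, 2.0, 3.0]:
    # dropout at keep_prob == 1.0 is an identity.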
--- a/utils/symbolic_functions.py
+++ b/utils/symbolic_functions.py
@@ -5,7 +5,7 @@
 import tensorflow as tf
 import numpy as np

-__all__ = ['one_hot', 'batch_flatten']
+__all__ = ['one_hot', 'batch_flatten', 'logSoftmax']

 def one_hot(y, num_labels):
     batch_size = tf.size(y)
@@ -20,3 +20,10 @@ def one_hot(y, num_labels):
 def batch_flatten(x):
     total_dim = np.prod(x.get_shape()[1:].as_list())
     return tf.reshape(x, [-1, total_dim])
+
+def logSoftmax(x):
+    z = x - tf.reduce_max(x, 1, keep_dims=True)
+    logprob = z - tf.log(tf.reduce_sum(tf.exp(z), 1, keep_dims=True))
+    return logprob
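logSoftmax replaces the earlier tf.log(tf.nn.softmax(fc1)), which is numerically fragile: for large logits exp() overflows and the softmax denominator becomes inf, while near-zero probabilities give log(0) = -inf. Softmax is invariant to subtracting a per-row constant, so shifting by the row max first keeps every exp() argument at or below zero. A NumPy sketch of the same identity (values illustrative):

    import numpy as np

    def log_softmax(x):
        z = x - x.max(axis=1, keepdims=True)                     # shift: softmax unchanged
        return z - np.log(np.exp(z).sum(axis=1, keepdims=True))  # all exp() inputs <= 0

    logits = np.array([[1000.0, 0.0]])
    # Naive np.log(softmax): exp(1000) overflows, yielding [nan, -inf].
    print(log_softmax(logits))  # [[0.0, -1000.0]], finite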