Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
5fab58f3
Commit
5fab58f3
authored
Dec 25, 2015
by
ppwwyyxx
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
fix
parent
88d607d8
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
19 additions
and
14 deletions
+19
-14
dataflow/dataset/mnist.py
dataflow/dataset/mnist.py
+5
-3
example_mnist.py
example_mnist.py
+14
-11
No files found.
dataflow/dataset/mnist.py
View file @
5fab58f3
...
...
@@ -16,7 +16,7 @@ class Mnist(object):
train_or_test: string either 'train' or 'test'
"""
if
dir
is
None
:
dir
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'mnist'
)
dir
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'mnist
_data
'
)
self
.
dataset
=
input_data
.
read_data_sets
(
dir
)
self
.
train_or_test
=
train_or_test
...
...
@@ -28,5 +28,7 @@ class Mnist(object):
yield
(
img
,
label
)
if
__name__
==
'__main__'
:
ds
=
Mnist
()
ds
.
get_data
()
ds
=
Mnist
(
'train'
)
for
(
img
,
label
)
in
ds
.
get_data
():
from
IPython
import
embed
;
embed
()
example_mnist.py
View file @
5fab58f3
...
...
@@ -4,10 +4,10 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
#
HACK protobuf
#
import sys
#
import os
#
sys.path.insert(0, os.path.expanduser('~/.local/lib/python2.7/site-packages'))
#
prefer protobuf in user-namespace
import
sys
import
os
sys
.
path
.
insert
(
0
,
os
.
path
.
expanduser
(
'~/.local/lib/python2.7/site-packages'
))
import
tensorflow
as
tf
import
numpy
as
np
...
...
@@ -35,21 +35,23 @@ def get_model(input, label):
cost: scalar variable
"""
input
=
tf
.
reshape
(
input
,
[
-
1
,
IMAGE_SIZE
,
IMAGE_SIZE
,
1
])
conv0
=
Conv2D
(
'conv0'
,
input
,
out_channel
=
5
,
kernel_shape
=
3
,
conv0
=
Conv2D
(
'conv0'
,
input
,
out_channel
=
20
,
kernel_shape
=
5
,
padding
=
'valid'
)
pool0
=
tf
.
nn
.
max_pool
(
conv0
,
ksize
=
[
1
,
2
,
2
,
1
],
strides
=
[
1
,
2
,
2
,
1
],
padding
=
'SAME'
)
conv1
=
Conv2D
(
'conv1'
,
pool0
,
out_channel
=
10
,
kernel_shape
=
4
,
conv1
=
Conv2D
(
'conv1'
,
pool0
,
out_channel
=
40
,
kernel_shape
=
3
,
padding
=
'valid'
)
pool1
=
tf
.
nn
.
max_pool
(
conv1
,
ksize
=
[
1
,
2
,
2
,
1
],
strides
=
[
1
,
2
,
2
,
1
],
padding
=
'SAME'
)
conv2
=
Conv2D
(
'conv2'
,
pool0
,
out_channel
=
40
,
kernel_shape
=
3
,
padding
=
'valid'
)
feature
=
batch_flatten
(
pool1
)
feature
=
batch_flatten
(
conv2
)
fc0
=
FullyConnected
(
'fc0'
,
feature
,
512
)
fc0
=
tf
.
nn
.
relu
(
fc0
)
fc
2
=
FullyConnected
(
'lr'
,
fc1
,
out_dim
=
10
)
prob
=
tf
.
nn
.
softmax
(
fc
2
,
name
=
'output'
)
fc
1
=
FullyConnected
(
'lr'
,
fc0
,
out_dim
=
10
)
prob
=
tf
.
nn
.
softmax
(
fc
1
,
name
=
'output'
)
logprob
=
tf
.
log
(
prob
)
y
=
one_hot
(
label
,
NUM_CLASS
)
...
...
@@ -82,8 +84,9 @@ def main():
ext
.
init
()
summary_op
=
tf
.
merge_all_summaries
()
sess
=
tf
.
Session
()
config
=
tf
.
ConfigProto
()
config
.
device_count
[
'GPU'
]
=
1
sess
=
tf
.
Session
(
config
=
config
)
sess
.
run
(
tf
.
initialize_all_variables
())
summary_writer
=
tf
.
train
.
SummaryWriter
(
LOG_DIR
,
graph_def
=
sess
.
graph_def
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment