Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
4a6b480c
Commit
4a6b480c
authored
Dec 27, 2015
by
ppwwyyxx
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
flag select gpu
parent
838b1df7
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
22 additions
and
15 deletions
+22
-15
example_mnist.py
example_mnist.py
+7
-15
train.py
train.py
+15
-0
No files found.
example_mnist.py
View file @
4a6b480c
...
@@ -41,9 +41,8 @@ def get_model(inputs):
...
@@ -41,9 +41,8 @@ def get_model(inputs):
pool0
=
MaxPooling
(
'pool0'
,
conv0
,
2
)
pool0
=
MaxPooling
(
'pool0'
,
conv0
,
2
)
conv1
=
Conv2D
(
'conv1'
,
pool0
,
out_channel
=
40
,
kernel_shape
=
3
)
conv1
=
Conv2D
(
'conv1'
,
pool0
,
out_channel
=
40
,
kernel_shape
=
3
)
pool1
=
MaxPooling
(
'pool1'
,
conv1
,
2
)
pool1
=
MaxPooling
(
'pool1'
,
conv1
,
2
)
conv2
=
Conv2D
(
'conv2'
,
pool1
,
out_channel
=
32
,
kernel_shape
=
3
)
fc0
=
FullyConnected
(
'fc0'
,
conv2
,
1024
)
fc0
=
FullyConnected
(
'fc0'
,
pool1
,
1024
)
fc0
=
tf
.
nn
.
dropout
(
fc0
,
keep_prob
)
fc0
=
tf
.
nn
.
dropout
(
fc0
,
keep_prob
)
# fc will have activation summary by default. disable this for the output layer
# fc will have activation summary by default. disable this for the output layer
...
@@ -56,16 +55,14 @@ def get_model(inputs):
...
@@ -56,16 +55,14 @@ def get_model(inputs):
cost
=
tf
.
reduce_mean
(
cost
,
name
=
'cross_entropy_loss'
)
cost
=
tf
.
reduce_mean
(
cost
,
name
=
'cross_entropy_loss'
)
tf
.
add_to_collection
(
COST_VARS_KEY
,
cost
)
tf
.
add_to_collection
(
COST_VARS_KEY
,
cost
)
# compute the number of failed samples, for ValidationErro to use at test time
# compute the number of failed samples, for ValidationErro
r
to use at test time
wrong
=
tf
.
not_equal
(
wrong
=
tf
.
not_equal
(
tf
.
cast
(
tf
.
argmax
(
prob
,
1
),
tf
.
int32
),
label
)
tf
.
cast
(
tf
.
argmax
(
prob
,
1
),
tf
.
int32
),
label
)
wrong
=
tf
.
cast
(
wrong
,
tf
.
float32
)
wrong
=
tf
.
cast
(
wrong
,
tf
.
float32
)
nr_wrong
=
tf
.
reduce_sum
(
wrong
,
name
=
'wrong'
)
nr_wrong
=
tf
.
reduce_sum
(
wrong
,
name
=
'wrong'
)
# monitor training error
# monitor training accuracy
tf
.
add_to_collection
(
tf
.
add_to_collection
(
SUMMARY_VARS_KEY
,
SUMMARY_VARS_KEY
,
tf
.
reduce_mean
(
wrong
,
name
=
'train_error'
))
tf
.
sub
(
1.0
,
tf
.
reduce_mean
(
wrong
),
name
=
'train_error'
))
# weight decay on all W of fc layers
# weight decay on all W of fc layers
wd_cost
=
tf
.
mul
(
1e-4
,
wd_cost
=
tf
.
mul
(
1e-4
,
...
@@ -86,6 +83,7 @@ def get_config():
...
@@ -86,6 +83,7 @@ def get_config():
sess_config
=
tf
.
ConfigProto
()
sess_config
=
tf
.
ConfigProto
()
sess_config
.
device_count
[
'GPU'
]
=
1
sess_config
.
device_count
[
'GPU'
]
=
1
sess_config
.
allow_soft_placement
=
True
# prepare model
# prepare model
image_var
=
tf
.
placeholder
(
tf
.
float32
,
shape
=
(
None
,
IMAGE_SIZE
,
IMAGE_SIZE
),
name
=
'input'
)
image_var
=
tf
.
placeholder
(
tf
.
float32
,
shape
=
(
None
,
IMAGE_SIZE
,
IMAGE_SIZE
),
name
=
'input'
)
...
@@ -117,12 +115,6 @@ def get_config():
...
@@ -117,12 +115,6 @@ def get_config():
max_epoch
=
100
,
max_epoch
=
100
,
)
)
def
main
(
argv
=
None
):
with
tf
.
Graph
()
.
as_default
():
from
train
import
prepare
,
start_train
prepare
()
config
=
get_config
()
start_train
(
config
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
tf
.
app
.
run
()
from
train
import
main
main
(
get_config
)
train.py
View file @
4a6b480c
...
@@ -7,6 +7,7 @@ import tensorflow as tf
...
@@ -7,6 +7,7 @@ import tensorflow as tf
from
utils
import
*
from
utils
import
*
from
dataflow
import
DataFlow
from
dataflow
import
DataFlow
from
itertools
import
count
from
itertools
import
count
import
argparse
def
prepare
():
def
prepare
():
keep_prob
=
tf
.
placeholder
(
keep_prob
=
tf
.
placeholder
(
...
@@ -92,3 +93,17 @@ def start_train(config):
...
@@ -92,3 +93,17 @@ def start_train(config):
callbacks
.
trigger_step
(
feed
,
outputs
,
cost
)
callbacks
.
trigger_step
(
feed
,
outputs
,
cost
)
callbacks
.
trigger_epoch
()
callbacks
.
trigger_epoch
()
def
main
(
get_config_func
):
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--gpu'
,
help
=
'GPU(s) to use.'
)
# nargs='*' in multi mode
args
=
parser
.
parse_args
()
device
=
'/cpu:0'
if
args
.
gpu
:
device
=
'/gpu:{}'
.
format
(
args
.
gpu
)
with
tf
.
Graph
()
.
as_default
():
with
tf
.
device
(
device
):
prepare
()
config
=
get_config_func
()
start_train
(
config
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment