Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
9bf42054
Commit
9bf42054
authored
Dec 27, 2015
by
ppwwyyxx
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
gpu option. before queue
parent
4a6b480c
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
8 additions
and
8 deletions
+8
-8
example_mnist.py
example_mnist.py
+2
-1
train.py
train.py
+6
-7
No files found.
example_mnist.py
View file @
9bf42054
...
@@ -74,7 +74,7 @@ def get_model(inputs):
...
@@ -74,7 +74,7 @@ def get_model(inputs):
def
get_config
():
def
get_config
():
IMAGE_SIZE
=
28
IMAGE_SIZE
=
28
LOG_DIR
=
'train_log'
LOG_DIR
=
os
.
path
.
join
(
'train_log'
,
os
.
path
.
basename
(
__file__
)[:
-
3
])
BATCH_SIZE
=
128
BATCH_SIZE
=
128
logger
.
set_file
(
os
.
path
.
join
(
LOG_DIR
,
'training.log'
))
logger
.
set_file
(
os
.
path
.
join
(
LOG_DIR
,
'training.log'
))
...
@@ -83,6 +83,7 @@ def get_config():
...
@@ -83,6 +83,7 @@ def get_config():
sess_config
=
tf
.
ConfigProto
()
sess_config
=
tf
.
ConfigProto
()
sess_config
.
device_count
[
'GPU'
]
=
1
sess_config
.
device_count
[
'GPU'
]
=
1
sess_config
.
gpu_options
.
allocator_type
=
'BFC'
sess_config
.
allow_soft_placement
=
True
sess_config
.
allow_soft_placement
=
True
# prepare model
# prepare model
...
...
train.py
View file @
9bf42054
...
@@ -93,17 +93,16 @@ def start_train(config):
...
@@ -93,17 +93,16 @@ def start_train(config):
callbacks
.
trigger_step
(
feed
,
outputs
,
cost
)
callbacks
.
trigger_step
(
feed
,
outputs
,
cost
)
callbacks
.
trigger_epoch
()
callbacks
.
trigger_epoch
()
sess
.
close
()
def
main
(
get_config_func
):
def
main
(
get_config_func
):
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--gpu'
,
help
=
'GPU(s) to use.'
)
# nargs='*' in multi mode
parser
.
add_argument
(
'--gpu'
,
help
=
'
comma separated list of
GPU(s) to use.'
)
# nargs='*' in multi mode
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
device
=
'/cpu:0'
if
args
.
gpu
:
if
args
.
gpu
:
device
=
'/gpu:{}'
.
format
(
args
.
gpu
)
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
args
.
gpu
with
tf
.
Graph
()
.
as_default
():
with
tf
.
Graph
()
.
as_default
():
with
tf
.
device
(
device
):
prepare
()
prepare
()
config
=
get_config_func
()
config
=
get_config_func
()
start_train
(
config
)
start_train
(
config
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment