Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
41bf8ffe
Commit
41bf8ffe
authored
Apr 02, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
add multigpu to cifar10_convnet
parent
b6370d50
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
17 additions
and
8 deletions
+17
-8
examples/cifar10_convnet.py
examples/cifar10_convnet.py
+4
-1
tensorpack/callbacks/param.py
tensorpack/callbacks/param.py
+2
-1
tensorpack/tfutils/common.py
tensorpack/tfutils/common.py
+6
-0
tensorpack/utils/utils.py
tensorpack/utils/utils.py
+5
-6
No files found.
examples/cifar10_convnet.py
View file @
41bf8ffe
...
...
@@ -20,6 +20,8 @@ from tensorpack.dataflow import imgaug
"""
CIFAR10 90
%
validation accuracy after 70k step.
91
%
validation accuracy after 36k step with 3 GPU.
"""
BATCH_SIZE
=
128
...
...
@@ -126,10 +128,11 @@ def get_config():
sess_config
=
get_default_sess_config
()
sess_config
.
gpu_options
.
per_process_gpu_memory_fraction
=
0.5
nr_gpu
=
get_nr_gpu
()
lr
=
tf
.
train
.
exponential_decay
(
learning_rate
=
1e-2
,
global_step
=
get_global_step_var
(),
decay_steps
=
dataset_train
.
size
()
*
30
,
decay_steps
=
dataset_train
.
size
()
*
30
if
nr_gpu
==
1
else
15
,
decay_rate
=
0.5
,
staircase
=
True
,
name
=
'learning_rate'
)
tf
.
scalar_summary
(
'learning_rate'
,
lr
)
...
...
tensorpack/callbacks/param.py
View file @
41bf8ffe
...
...
@@ -8,7 +8,8 @@ from abc import abstractmethod, ABCMeta
import
operator
from
.base
import
Callback
from
..utils
import
logger
,
get_op_var_name
from
..utils
import
logger
from
..tfutils
import
get_op_var_name
__all__
=
[
'HyperParamSetter'
,
'HumanHyperParamSetter'
,
'ScheduledHyperParamSetter'
]
...
...
tensorpack/tfutils/common.py
View file @
41bf8ffe
...
...
@@ -32,3 +32,9 @@ def get_global_step():
tf
.
get_default_session
(),
get_global_step_var
())
def
get_op_var_name
(
name
):
if
name
.
endswith
(
':0'
):
return
name
[:
-
2
],
name
else
:
return
name
,
name
+
':0'
tensorpack/utils/utils.py
View file @
41bf8ffe
...
...
@@ -11,7 +11,7 @@ import numpy as np
from
.
import
logger
__all__
=
[
'timed_operation'
,
'change_env'
,
'get_rng'
,
'memoized'
,
'get_
op_var_name
'
]
'get_
nr_gpu
'
]
#def expand_dim_if_necessary(var, dp):
# """
...
...
@@ -79,8 +79,7 @@ def get_rng(self):
seed
=
(
id
(
self
)
+
os
.
getpid
())
%
4294967295
return
np
.
random
.
RandomState
(
seed
)
def
get_op_var_name
(
name
):
if
name
.
endswith
(
':0'
):
return
name
[:
-
2
],
name
else
:
return
name
,
name
+
':0'
def
get_nr_gpu
():
env
=
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
assert
env
is
not
None
return
len
(
env
.
split
(
','
))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment