Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
00811100
Commit
00811100
authored
May 23, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
check tf version in multigpu.
parent
4a88dfc3
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
10 additions
and
6 deletions
+10
-6
tensorpack/libinfo.py
tensorpack/libinfo.py
+1
-1
tensorpack/tfutils/common.py
tensorpack/tfutils/common.py
+0
-1
tensorpack/train/multigpu.py
tensorpack/train/multigpu.py
+7
-0
tensorpack/train/predict.py
tensorpack/train/predict.py
+2
-4
No files found.
tensorpack/libinfo.py
View file @
00811100
...
@@ -10,4 +10,4 @@ os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1' # issue#9339
...
@@ -10,4 +10,4 @@ os.environ['TF_ENABLE_WINOGRAD_NONFUSED'] = '1' # issue#9339
os
.
environ
[
'TF_AUTOTUNE_THRESHOLD'
]
=
'3'
# use more warm-up
os
.
environ
[
'TF_AUTOTUNE_THRESHOLD'
]
=
'3'
# use more warm-up
os
.
environ
[
'TF_AVGPOOL_USE_CUDNN'
]
=
'1'
# issue#8566
os
.
environ
[
'TF_AVGPOOL_USE_CUDNN'
]
=
'1'
# issue#8566
__version__
=
'0.
1.9
'
__version__
=
'0.
2.0
'
tensorpack/tfutils/common.py
View file @
00811100
...
@@ -10,7 +10,6 @@ from ..utils.argtools import graph_memoized
...
@@ -10,7 +10,6 @@ from ..utils.argtools import graph_memoized
from
..utils.naming
import
GLOBAL_STEP_OP_NAME
from
..utils.naming
import
GLOBAL_STEP_OP_NAME
__all__
=
[
'get_default_sess_config'
,
__all__
=
[
'get_default_sess_config'
,
'get_global_step_value'
,
'get_global_step_value'
,
'get_global_step_var'
,
'get_global_step_var'
,
'get_op_tensor_name'
,
'get_op_tensor_name'
,
...
...
tensorpack/train/multigpu.py
View file @
00811100
...
@@ -27,6 +27,11 @@ __all__ = ['MultiGPUTrainerBase', 'SyncMultiGPUTrainer',
...
@@ -27,6 +27,11 @@ __all__ = ['MultiGPUTrainerBase', 'SyncMultiGPUTrainer',
'SyncMultiGPUTrainerParameterServer'
]
'SyncMultiGPUTrainerParameterServer'
]
def
_check_tf_version
():
ver
=
float
(
'.'
.
join
(
tf
.
VERSION
.
split
(
'.'
)[:
2
]))
assert
ver
>=
1.1
,
"TF version {} is too old to run multi GPU training!"
.
format
(
ver
)
def
apply_prefetch_policy
(
config
,
use_stage
=
True
):
def
apply_prefetch_policy
(
config
,
use_stage
=
True
):
if
config
.
data
is
None
and
config
.
dataflow
is
not
None
:
if
config
.
data
is
None
and
config
.
dataflow
is
not
None
:
config
.
data
=
QueueInput
(
config
.
dataflow
)
config
.
data
=
QueueInput
(
config
.
dataflow
)
...
@@ -55,6 +60,8 @@ class MultiGPUTrainerBase(Trainer):
...
@@ -55,6 +60,8 @@ class MultiGPUTrainerBase(Trainer):
List of outputs of ``func``, evaluated on each tower.
List of outputs of ``func``, evaluated on each tower.
"""
"""
logger
.
info
(
"Training a model of {} tower"
.
format
(
len
(
towers
)))
logger
.
info
(
"Training a model of {} tower"
.
format
(
len
(
towers
)))
if
len
(
towers
)
>
1
:
_check_tf_version
()
ret
=
[]
ret
=
[]
if
devices
is
not
None
:
if
devices
is
not
None
:
...
...
tensorpack/train/predict.py
View file @
00811100
...
@@ -3,7 +3,6 @@
...
@@ -3,7 +3,6 @@
# File: predict.py
# File: predict.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import
tensorflow
as
tf
from
..predict
import
(
OnlinePredictor
,
from
..predict
import
(
OnlinePredictor
,
PredictorTowerBuilder
)
PredictorTowerBuilder
)
...
@@ -34,9 +33,8 @@ class PredictorFactory(object):
...
@@ -34,9 +33,8 @@ class PredictorFactory(object):
an online predictor (which has to be used under a default session)
an online predictor (which has to be used under a default session)
"""
"""
tower
=
self
.
towers
[
tower
]
tower
=
self
.
towers
[
tower
]
with
tf
.
variable_scope
(
tf
.
get_variable_scope
(),
reuse
=
True
):
# just ensure the tower exists. won't rebuild (memoized)
# just ensure the tower exists. won't rebuild
self
.
_tower_builder
.
build
(
tower
)
self
.
_tower_builder
.
build
(
tower
)
placeholder_names
=
set
([
k
.
name
for
k
in
self
.
model
.
get_inputs_desc
()])
placeholder_names
=
set
([
k
.
name
for
k
in
self
.
model
.
get_inputs_desc
()])
get_tensor_fn
=
PredictorTowerBuilder
.
get_tensors_maybe_in_tower
get_tensor_fn
=
PredictorTowerBuilder
.
get_tensors_maybe_in_tower
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment