Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
377fba1f
Commit
377fba1f
authored
Sep 14, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
clean-up some deprecation
parent
73b63247
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
9 additions
and
23 deletions
+9
-23
docs/conf.py
docs/conf.py
+0
-2
tensorpack/tfutils/sessinit.py
tensorpack/tfutils/sessinit.py
+8
-15
tensorpack/train/distributed.py
tensorpack/train/distributed.py
+1
-6
No files found.
docs/conf.py
View file @
377fba1f
...
@@ -353,7 +353,6 @@ def process_signature(app, what, name, obj, options, signature,
...
@@ -353,7 +353,6 @@ def process_signature(app, what, name, obj, options, signature,
def
autodoc_skip_member
(
app
,
what
,
name
,
obj
,
skip
,
options
):
def
autodoc_skip_member
(
app
,
what
,
name
,
obj
,
skip
,
options
):
if
name
in
[
if
name
in
[
'DistributedReplicatedTrainer'
,
'SingleCostFeedfreeTrainer'
,
'SingleCostFeedfreeTrainer'
,
'SimpleFeedfreeTrainer'
,
'SimpleFeedfreeTrainer'
,
'FeedfreeTrainerBase'
,
'FeedfreeTrainerBase'
,
...
@@ -369,7 +368,6 @@ def autodoc_skip_member(app, what, name, obj, skip, options):
...
@@ -369,7 +368,6 @@ def autodoc_skip_member(app, what, name, obj, skip, options):
'GaussianDeform'
,
'GaussianDeform'
,
'dump_chkpt_vars'
,
'dump_chkpt_vars'
,
'VisualQA'
,
'VisualQA'
,
'ParamRestore'
,
'huber_loss'
'huber_loss'
]:
]:
return
True
return
True
...
...
tensorpack/tfutils/sessinit.py
View file @
377fba1f
...
@@ -8,14 +8,12 @@ import tensorflow as tf
...
@@ -8,14 +8,12 @@ import tensorflow as tf
import
six
import
six
from
..utils
import
logger
from
..utils
import
logger
from
..utils.develop
import
deprecated
from
.common
import
get_op_tensor_name
from
.common
import
get_op_tensor_name
from
.varmanip
import
(
SessionUpdate
,
get_savename_from_varname
,
from
.varmanip
import
(
SessionUpdate
,
get_savename_from_varname
,
is_training_name
,
get_checkpoint_path
)
is_training_name
,
get_checkpoint_path
)
__all__
=
[
'SessionInit'
,
'ChainInit'
,
__all__
=
[
'SessionInit'
,
'ChainInit'
,
'SaverRestore'
,
'SaverRestoreRelaxed'
,
'SaverRestore'
,
'SaverRestoreRelaxed'
,
'DictRestore'
,
'ParamRestore'
,
'DictRestore'
,
'JustCurrentSession'
,
'get_model_loader'
,
'TryResumeTraining'
]
'JustCurrentSession'
,
'get_model_loader'
,
'TryResumeTraining'
]
...
@@ -191,24 +189,24 @@ class DictRestore(SessionInit):
...
@@ -191,24 +189,24 @@ class DictRestore(SessionInit):
Restore variables from a dictionary.
Restore variables from a dictionary.
"""
"""
def
__init__
(
self
,
param
_dict
):
def
__init__
(
self
,
variable
_dict
):
"""
"""
Args:
Args:
param
_dict (dict): a dict of {name: value}
variable
_dict (dict): a dict of {name: value}
"""
"""
assert
isinstance
(
param_dict
,
dict
),
type
(
param
_dict
)
assert
isinstance
(
variable_dict
,
dict
),
type
(
variable
_dict
)
# use varname (with :0) for consistency
# use varname (with :0) for consistency
self
.
prms
=
{
get_op_tensor_name
(
n
)[
1
]:
v
for
n
,
v
in
six
.
iteritems
(
param
_dict
)}
self
.
_prms
=
{
get_op_tensor_name
(
n
)[
1
]:
v
for
n
,
v
in
six
.
iteritems
(
variable
_dict
)}
def
_run_init
(
self
,
sess
):
def
_run_init
(
self
,
sess
):
variables
=
tf
.
get_collection
(
tf
.
GraphKeys
.
GLOBAL_VARIABLES
)
variables
=
tf
.
get_collection
(
tf
.
GraphKeys
.
GLOBAL_VARIABLES
)
variable_names
=
set
([
k
.
name
for
k
in
variables
])
variable_names
=
set
([
k
.
name
for
k
in
variables
])
param_names
=
set
(
six
.
iterkeys
(
self
.
prms
))
param_names
=
set
(
six
.
iterkeys
(
self
.
_
prms
))
intersect
=
variable_names
&
param_names
intersect
=
variable_names
&
param_names
logger
.
info
(
"
Params to restore
: {}"
.
format
(
', '
.
join
(
map
(
str
,
intersect
))))
logger
.
info
(
"
Variables to restore from dict
: {}"
.
format
(
', '
.
join
(
map
(
str
,
intersect
))))
mismatch
=
MismatchLogger
(
'graph'
,
'dict'
)
mismatch
=
MismatchLogger
(
'graph'
,
'dict'
)
for
k
in
sorted
(
variable_names
-
param_names
):
for
k
in
sorted
(
variable_names
-
param_names
):
...
@@ -222,12 +220,7 @@ class DictRestore(SessionInit):
...
@@ -222,12 +220,7 @@ class DictRestore(SessionInit):
upd
=
SessionUpdate
(
sess
,
[
v
for
v
in
variables
if
v
.
name
in
intersect
])
upd
=
SessionUpdate
(
sess
,
[
v
for
v
in
variables
if
v
.
name
in
intersect
])
logger
.
info
(
"Restoring from dict ..."
)
logger
.
info
(
"Restoring from dict ..."
)
upd
.
update
({
name
:
value
for
name
,
value
in
six
.
iteritems
(
self
.
prms
)
if
name
in
intersect
})
upd
.
update
({
name
:
value
for
name
,
value
in
six
.
iteritems
(
self
.
_prms
)
if
name
in
intersect
})
@
deprecated
(
"Use `DictRestore` instead!"
,
"2017-09-01"
)
def
ParamRestore
(
d
):
return
DictRestore
(
d
)
class
ChainInit
(
SessionInit
):
class
ChainInit
(
SessionInit
):
...
...
tensorpack/train/distributed.py
View file @
377fba1f
...
@@ -16,7 +16,7 @@ from .multigpu import MultiGPUTrainerBase
...
@@ -16,7 +16,7 @@ from .multigpu import MultiGPUTrainerBase
from
.utility
import
override_to_local_variable
from
.utility
import
override_to_local_variable
__all__
=
[
'Distributed
ReplicatedTrainer'
,
'Distributed
TrainerReplicated'
]
__all__
=
[
'DistributedTrainerReplicated'
]
class
DistributedTrainerReplicated
(
MultiGPUTrainerBase
):
class
DistributedTrainerReplicated
(
MultiGPUTrainerBase
):
...
@@ -336,8 +336,3 @@ class DistributedTrainerReplicated(MultiGPUTrainerBase):
...
@@ -336,8 +336,3 @@ class DistributedTrainerReplicated(MultiGPUTrainerBase):
@
property
@
property
def
vs_name_for_predictor
(
self
):
def
vs_name_for_predictor
(
self
):
return
"tower0"
return
"tower0"
def
DistributedReplicatedTrainer
(
*
args
,
**
kwargs
):
logger
.
warn
(
"DistributedReplicatedTrainer was renamed to DistributedTrainerReplicated!"
)
return
DistributedTrainerReplicated
(
*
args
,
**
kwargs
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment