Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
ab86361f
Commit
ab86361f
authored
Jan 18, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
use "tensor" instead of "var"
parent
64f97425
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
13 additions
and
12 deletions
+13
-12
tensorpack/callbacks/param.py
tensorpack/callbacks/param.py
+2
-2
tensorpack/tfutils/__init__.py
tensorpack/tfutils/__init__.py
+1
-1
tensorpack/tfutils/common.py
tensorpack/tfutils/common.py
+8
-7
tensorpack/tfutils/sessinit.py
tensorpack/tfutils/sessinit.py
+2
-2
No files found.
tensorpack/callbacks/param.py
View file @
ab86361f
...
@@ -11,7 +11,7 @@ import os
...
@@ -11,7 +11,7 @@ import os
from
.base
import
Callback
from
.base
import
Callback
from
..utils
import
logger
from
..utils
import
logger
from
..tfutils
import
get_op_
va
r_name
from
..tfutils
import
get_op_
tenso
r_name
__all__
=
[
'HyperParam'
,
'GraphVarParam'
,
'ObjAttrParam'
,
__all__
=
[
'HyperParam'
,
'GraphVarParam'
,
'ObjAttrParam'
,
'HyperParamSetter'
,
'HumanHyperParamSetter'
,
'HyperParamSetter'
,
'HumanHyperParamSetter'
,
...
@@ -62,7 +62,7 @@ class GraphVarParam(HyperParam):
...
@@ -62,7 +62,7 @@ class GraphVarParam(HyperParam):
"""
"""
self
.
name
=
name
self
.
name
=
name
self
.
shape
=
shape
self
.
shape
=
shape
self
.
_readable_name
,
self
.
var_name
=
get_op_
va
r_name
(
name
)
self
.
_readable_name
,
self
.
var_name
=
get_op_
tenso
r_name
(
name
)
def
setup_graph
(
self
):
def
setup_graph
(
self
):
""" Will setup the assign operator for that variable. """
""" Will setup the assign operator for that variable. """
...
...
tensorpack/tfutils/__init__.py
View file @
ab86361f
...
@@ -17,8 +17,8 @@ def _global_import(name):
...
@@ -17,8 +17,8 @@ def _global_import(name):
_TO_IMPORT
=
set
([
_TO_IMPORT
=
set
([
'sessinit'
,
'common'
,
'common'
,
'sessinit'
,
'gradproc'
,
'gradproc'
,
'argscope'
,
'argscope'
,
'tower'
'tower'
...
...
tensorpack/tfutils/common.py
View file @
ab86361f
...
@@ -12,10 +12,9 @@ from contextlib import contextmanager
...
@@ -12,10 +12,9 @@ from contextlib import contextmanager
__all__
=
[
'get_default_sess_config'
,
__all__
=
[
'get_default_sess_config'
,
'get_global_step'
,
'get_global_step'
,
'get_global_step_var'
,
'get_global_step_var'
,
'get_op_var_name'
,
'get_op_tensor_name'
,
'get_op_tensor_name'
,
'get_vars_by_names'
,
'get_tensors_by_names'
,
'get_tensors_by_names'
,
'get_op_or_tensor_by_name'
,
'backup_collection'
,
'backup_collection'
,
'restore_collection'
,
'restore_collection'
,
'clear_collection'
,
'clear_collection'
,
...
@@ -87,9 +86,6 @@ def get_op_tensor_name(name):
...
@@ -87,9 +86,6 @@ def get_op_tensor_name(name):
return
name
,
name
+
':0'
return
name
,
name
+
':0'
get_op_var_name
=
get_op_tensor_name
def
get_tensors_by_names
(
names
):
def
get_tensors_by_names
(
names
):
"""
"""
Get a list of tensors in the default graph by a list of names.
Get a list of tensors in the default graph by a list of names.
...
@@ -100,12 +96,17 @@ def get_tensors_by_names(names):
...
@@ -100,12 +96,17 @@ def get_tensors_by_names(names):
ret
=
[]
ret
=
[]
G
=
tf
.
get_default_graph
()
G
=
tf
.
get_default_graph
()
for
n
in
names
:
for
n
in
names
:
opn
,
varn
=
get_op_
va
r_name
(
n
)
opn
,
varn
=
get_op_
tenso
r_name
(
n
)
ret
.
append
(
G
.
get_tensor_by_name
(
varn
))
ret
.
append
(
G
.
get_tensor_by_name
(
varn
))
return
ret
return
ret
get_vars_by_names
=
get_tensors_by_names
def
get_op_or_tensor_by_name
(
name
):
G
=
tf
.
get_default_graph
()
if
len
(
name
)
>=
3
and
name
[
-
2
]
==
':'
:
return
G
.
get_tensor_by_name
(
name
)
else
:
return
G
.
get_operation_by_name
(
name
)
def
backup_collection
(
keys
):
def
backup_collection
(
keys
):
...
...
tensorpack/tfutils/sessinit.py
View file @
ab86361f
...
@@ -10,7 +10,7 @@ import tensorflow as tf
...
@@ -10,7 +10,7 @@ import tensorflow as tf
import
six
import
six
from
..utils
import
logger
,
PREDICT_TOWER
from
..utils
import
logger
,
PREDICT_TOWER
from
.common
import
get_op_
va
r_name
from
.common
import
get_op_
tenso
r_name
from
.varmanip
import
(
SessionUpdate
,
get_savename_from_varname
,
from
.varmanip
import
(
SessionUpdate
,
get_savename_from_varname
,
is_training_name
,
get_checkpoint_path
)
is_training_name
,
get_checkpoint_path
)
...
@@ -149,7 +149,7 @@ class ParamRestore(SessionInit):
...
@@ -149,7 +149,7 @@ class ParamRestore(SessionInit):
param_dict (dict): a dict of {name: value}
param_dict (dict): a dict of {name: value}
"""
"""
# use varname (with :0) for consistency
# use varname (with :0) for consistency
self
.
prms
=
{
get_op_
va
r_name
(
n
)[
1
]:
v
for
n
,
v
in
six
.
iteritems
(
param_dict
)}
self
.
prms
=
{
get_op_
tenso
r_name
(
n
)[
1
]:
v
for
n
,
v
in
six
.
iteritems
(
param_dict
)}
def
_init
(
self
,
sess
):
def
_init
(
self
,
sess
):
variables
=
tf
.
get_collection
(
tf
.
GraphKeys
()
.
VARIABLES
)
# TODO
variables
=
tf
.
get_collection
(
tf
.
GraphKeys
()
.
VARIABLES
)
# TODO
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment