Shashank Suhas / seminar-breakout · Commits · 4942ef45

Commit 4942ef45, authored Jun 20, 2016 by Yuxin Wu

    improve both sessioninit

parent 3b2f7df1

Showing 5 changed files with 14 additions and 6 deletions (+14 -6):

    tensorpack/callbacks/inference.py   +1 -1
    tensorpack/models/batch_norm.py     +1 -1
    tensorpack/models/model_desc.py     +0 -1
    tensorpack/predict/common.py        +5 -1
    tensorpack/tfutils/sessinit.py      +7 -2
tensorpack/callbacks/inference.py

@@ -121,7 +121,7 @@ class InferenceRunner(Callback):
 class ScalarStats(Inferencer):
     """
-    Write stat and summary of some scalar tensor.
+    Write some scalar tensor to both stat and summary.
     The output of the given Ops must be a scalar.
     The value will be averaged over all data points in the dataset.
     """
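The reworded docstring says ScalarStats averages a scalar tensor over the whole dataset and writes it to both the stat output and the TF summary. A hypothetical usage sketch; the InferenceRunner call shape and the 'cost' tensor name are assumptions for illustration, not taken from this diff:

    # Hypothetical sketch: average a scalar tensor over a validation DataFlow
    # each epoch; ScalarStats writes the result to both stat and summary.
    from tensorpack.callbacks.inference import InferenceRunner, ScalarStats

    callbacks = [
        InferenceRunner(
            val_dataflow,            # assumed: a DataFlow over the validation set
            [ScalarStats('cost')]),  # assumed: 'cost' names a scalar tensor
    ]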
tensorpack/models/batch_norm.py

@@ -42,7 +42,7 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
     assert n_out is not None
     beta = tf.get_variable('beta', [n_out])
     gamma = tf.get_variable('gamma', [n_out],
-                            initializer=tf.constant_initializer(1.0))
+                            initializer=tf.ones_initializer)
     if len(shape) == 2:
         batch_mean, batch_var = tf.nn.moments(x, [0], keep_dims=False)
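For context on the unchanged tf.nn.moments call in this hunk: for a 2-D (batch, n_out) input it yields per-channel batch statistics. A minimal numpy sketch of the same computation; shapes and the gamma=1/beta=0 values are illustrative:

    # What tf.nn.moments(x, [0], keep_dims=False) computes for 2-D input:
    # per-channel mean and (biased) variance over the batch axis.
    import numpy as np

    x = np.random.randn(32, 64).astype('float32')   # (batch, n_out)
    batch_mean = x.mean(axis=0)                     # shape (64,)
    batch_var = x.var(axis=0)                       # biased, like tf.nn.moments

    # BatchNorm then normalizes and applies the learned scale/shift; gamma is
    # ones-initialized after this change, and beta is taken as 0 here only
    # for illustration:
    epsilon = 1e-5
    y = 1.0 * (x - batch_mean) / np.sqrt(batch_var + epsilon) + 0.0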
tensorpack/models/model_desc.py

@@ -96,4 +96,3 @@ class ModelFromMetaGraph(ModelDesc):
     def _build_graph(self, _, __):
         """ Do nothing. Graph was imported already """
-        pass
tensorpack/predict/common.py

@@ -6,6 +6,7 @@ import tensorflow as tf
 from collections import namedtuple
 from six.moves import zip
+from tensorpack.models import ModelDesc
 from ..tfutils import *
 import multiprocessing

@@ -53,8 +54,10 @@ class PredictConfig(object):
         # XXX does it work? start with minimal memory, but allow growth.
         # allow_growth doesn't seem to work very well in TF.
         self.session_config = kwargs.pop('session_config', get_default_sess_config(0.3))
-        self.session_init = kwargs.pop('session_init')
+        self.session_init = kwargs.pop('session_init', JustCurrentSession())
+        assert_type(self.session_init, SessionInit)
         self.model = kwargs.pop('model')
+        assert_type(self.model, ModelDesc)
         self.input_data_mapping = kwargs.pop('input_data_mapping', None)
         self.output_var_names = kwargs.pop('output_var_names')
         self.return_input = kwargs.pop('return_input', False)

@@ -86,4 +89,5 @@ def get_predict_func(config):
     def run_input(dp):
         feed = dict(zip(input_map, dp))
         return sess.run(output_vars, feed_dict=feed)
+    run_input.session = sess
     return run_input
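After this change session_init is optional (defaulting to JustCurrentSession()), both session_init and model are type-checked, and the predict function exposes its session. A hedged usage sketch; MyModel, the 'prob' tensor name, and the checkpoint path are made up for illustration:

    # Sketch only: field names follow this diff; everything else is assumed.
    from tensorpack.predict.common import PredictConfig, get_predict_func
    from tensorpack.tfutils.sessinit import SaverRestore

    config = PredictConfig(
        model=MyModel(),            # any ModelDesc subclass (assert_type-checked)
        output_var_names=['prob'],  # assumed output tensor name
        # session_init may now be omitted; it defaults to JustCurrentSession()
        session_init=SaverRestore('train_log/model-10000'),
    )
    predict = get_predict_func(config)
    outputs = predict(dp)           # dp: list of inputs matching the model's inputs
    sess = predict.session         # exposed by the new `run_input.session = sess`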
tensorpack/tfutils/sessinit.py

@@ -68,7 +68,9 @@ class SaverRestore(SessionInit):
         chkpt_vars = SaverRestore._read_checkpoint_vars(self.path)
         vars_map = SaverRestore._get_vars_to_restore_multimap(chkpt_vars)
         for dic in SaverRestore._produce_restore_dict(vars_map):
-            saver = tf.train.Saver(var_list=dic)
+            # multiple saver under same name scope would cause error:
+            # training/saver.py: assert restore_op.name.endswith("restore_all"), restore_op.name
+            saver = tf.train.Saver(var_list=dic, name=str(id(dic)))
             saver.restore(sess, self.path)

     def set_path(self, model_path):
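The name=str(id(dic)) fix gives each Saver a distinct name, so building several savers in the same name scope no longer trips the restore_all assertion quoted in the new comment. A minimal plain-TF sketch of the pattern; the variables and var_list dicts are illustrative:

    # Plain TF 1.x-era sketch: several Savers, each with a distinct name,
    # as in the commit's tf.train.Saver(var_list=dic, name=str(id(dic))).
    import tensorflow as tf

    a = tf.get_variable('a', shape=[10])
    b = tf.get_variable('b', shape=[10])

    savers = []
    for dic in [{'a': a}, {'b': b}]:   # illustrative var_list dicts
        savers.append(tf.train.Saver(var_list=dic, name=str(id(dic))))
    # each saver.restore(sess, path) now builds ops under its own name scope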
@@ -148,7 +150,10 @@ class ParamRestore(SessionInit):
                 "{}: {}!={}".format(name, varshape, value.shape)
             logger.warn("Param {} is reshaped during loading!".format(name))
             value = value.reshape(varshape)
-        sess.run(var.assign(value))
+        # assign(value) creates ops with values being saved, doubling the size of metagraph
+        # assign(placeholder) works better here
+        p = tf.placeholder(value.dtype, shape=value.shape)
+        sess.run(var.assign(p), feed_dict={p: value})

 def ChainInit(SessionInit):
     """ Init a session by a list of SessionInit instance."""
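The second hunk swaps var.assign(value), which embeds the numpy array in the graph as a constant and can double the metagraph size, for assignment through a placeholder. A self-contained sketch of the trick in plain TF 1.x-era API:

    # The value travels through feed_dict instead of becoming a Const node
    # in the GraphDef, keeping the exported metagraph small.
    import numpy as np
    import tensorflow as tf

    var = tf.get_variable('w', shape=[1000, 1000])
    value = np.random.randn(1000, 1000).astype('float32')

    # sess.run(var.assign(value))  # old way: bakes `value` into the graph

    p = tf.placeholder(value.dtype, shape=value.shape)
    assign_op = var.assign(p)
    with tf.Session() as sess:
        sess.run(assign_op, feed_dict={p: value})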