Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
3b2f7df1
Commit
3b2f7df1
authored
Jun 20, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
load model from meta
parent
03b92aba
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
46 additions
and
36 deletions
+46
-36
tensorpack/models/model_desc.py
tensorpack/models/model_desc.py
+31
-21
tensorpack/predict/common.py
tensorpack/predict/common.py
+5
-6
tensorpack/predict/concurrency.py
tensorpack/predict/concurrency.py
+3
-6
tensorpack/tfutils/argscope.py
tensorpack/tfutils/argscope.py
+5
-2
tensorpack/train/trainer.py
tensorpack/train/trainer.py
+1
-1
tensorpack/utils/naming.py
tensorpack/utils/naming.py
+1
-0
No files found.
tensorpack/models/model_desc.py
View file @
3b2f7df1
...
@@ -7,10 +7,10 @@ from abc import ABCMeta, abstractmethod
...
@@ -7,10 +7,10 @@ from abc import ABCMeta, abstractmethod
import
tensorflow
as
tf
import
tensorflow
as
tf
from
collections
import
namedtuple
from
collections
import
namedtuple
from
..utils
import
logger
from
..utils
import
logger
,
INPUT_VARS_KEY
from
..tfutils
import
*
from
..tfutils
import
*
__all__
=
[
'ModelDesc'
,
'InputVar'
]
__all__
=
[
'ModelDesc'
,
'InputVar'
,
'ModelFromMetaGraph'
]
InputVar
=
namedtuple
(
'InputVar'
,
[
'type'
,
'shape'
,
'name'
])
InputVar
=
namedtuple
(
'InputVar'
,
[
'type'
,
'shape'
,
'name'
])
...
@@ -32,6 +32,8 @@ class ModelDesc(object):
...
@@ -32,6 +32,8 @@ class ModelDesc(object):
ret
=
[]
ret
=
[]
for
v
in
input_vars
:
for
v
in
input_vars
:
ret
.
append
(
tf
.
placeholder
(
v
.
type
,
shape
=
v
.
shape
,
name
=
v
.
name
))
ret
.
append
(
tf
.
placeholder
(
v
.
type
,
shape
=
v
.
shape
,
name
=
v
.
name
))
for
v
in
ret
:
tf
.
add_to_collection
(
INPUT_VARS_KEY
,
v
)
return
ret
return
ret
def
reuse_input_vars
(
self
):
def
reuse_input_vars
(
self
):
...
@@ -57,28 +59,12 @@ class ModelDesc(object):
...
@@ -57,28 +59,12 @@ class ModelDesc(object):
"""
"""
self
.
_build_graph
(
model_inputs
,
is_training
)
self
.
_build_graph
(
model_inputs
,
is_training
)
#
@abstractmethod
@
abstractmethod
def
_build_graph
(
self
,
inputs
,
is_training
):
def
_build_graph
(
self
,
inputs
,
is_training
):
if
self
.
_old_version
():
pass
self
.
model_inputs
=
inputs
self
.
is_training
=
is_training
else
:
raise
NotImplementedError
()
def
_old_version
(
self
):
# for backward-compat only.
import
inspect
args
=
inspect
.
getargspec
(
self
.
_get_cost
)
return
len
(
args
.
args
)
==
3
def
get_cost
(
self
):
def
get_cost
(
self
):
if
self
.
_old_version
():
return
self
.
_get_cost
()
assert
type
(
self
.
is_training
)
==
bool
logger
.
warn
(
"!!!using _get_cost to setup the graph is deprecated in favor of _build_graph"
)
logger
.
warn
(
"See examples for details."
)
return
self
.
_get_cost
(
self
.
model_inputs
,
self
.
is_training
)
else
:
return
self
.
_get_cost
()
def
_get_cost
(
self
,
*
args
):
def
_get_cost
(
self
,
*
args
):
return
self
.
cost
return
self
.
cost
...
@@ -87,3 +73,27 @@ class ModelDesc(object):
...
@@ -87,3 +73,27 @@ class ModelDesc(object):
""" Return a list of GradientProcessor. They will be executed in order"""
""" Return a list of GradientProcessor. They will be executed in order"""
return
[
CheckGradient
()]
#, SummaryGradient()]
return
[
CheckGradient
()]
#, SummaryGradient()]
class
ModelFromMetaGraph
(
ModelDesc
):
"""
Load the whole exact TF graph from a saved meta_graph.
Only useful for inference.
"""
def
__init__
(
self
,
filename
):
tf
.
train
.
import_meta_graph
(
filename
)
all_coll
=
tf
.
get_default_graph
()
.
get_all_collection_keys
()
for
k
in
[
INPUT_VARS_KEY
,
tf
.
GraphKeys
.
TRAINABLE_VARIABLES
,
tf
.
GraphKeys
.
VARIABLES
]:
assert
k
in
all_coll
,
\
"Collection {} not found in metagraph!"
.
format
(
k
)
def
get_input_vars
(
self
):
return
tf
.
get_collection
(
INPUT_VARS_KEY
)
def
_get_input_vars
(
self
):
raise
NotImplementedError
(
"Shouldn't call here"
)
def
_build_graph
(
self
,
_
,
__
):
""" Do nothing. Graph was imported already """
pass
tensorpack/predict/common.py
View file @
3b2f7df1
...
@@ -50,8 +50,9 @@ class PredictConfig(object):
...
@@ -50,8 +50,9 @@ class PredictConfig(object):
"""
"""
def
assert_type
(
v
,
tp
):
def
assert_type
(
v
,
tp
):
assert
isinstance
(
v
,
tp
),
v
.
__class__
assert
isinstance
(
v
,
tp
),
v
.
__class__
self
.
session_config
=
kwargs
.
pop
(
'session_config'
,
# XXX does it work? start with minimal memory, but allow growth.
get_default_sess_config
(
0.3
))
# allow_growth doesn't seem to work very well in TF.
self
.
session_config
=
kwargs
.
pop
(
'session_config'
,
get_default_sess_config
(
0.3
))
self
.
session_init
=
kwargs
.
pop
(
'session_init'
)
self
.
session_init
=
kwargs
.
pop
(
'session_init'
)
self
.
model
=
kwargs
.
pop
(
'model'
)
self
.
model
=
kwargs
.
pop
(
'model'
)
self
.
input_data_mapping
=
kwargs
.
pop
(
'input_data_mapping'
,
None
)
self
.
input_data_mapping
=
kwargs
.
pop
(
'input_data_mapping'
,
None
)
...
@@ -61,7 +62,7 @@ class PredictConfig(object):
...
@@ -61,7 +62,7 @@ class PredictConfig(object):
def
get_predict_func
(
config
):
def
get_predict_func
(
config
):
"""
"""
Produce a simple predictor function
in a newly-created session without any parallelism
.
Produce a simple predictor function
run inside a new session
.
:param config: a `PredictConfig` instance.
:param config: a `PredictConfig` instance.
:returns: A prediction function that takes a list of input values, and return
:returns: A prediction function that takes a list of input values, and return
a list of output values defined in ``config.output_var_names``.
a list of output values defined in ``config.output_var_names``.
...
@@ -77,10 +78,8 @@ def get_predict_func(config):
...
@@ -77,10 +78,8 @@ def get_predict_func(config):
input_map
=
[
input_vars
[
k
]
for
k
in
config
.
input_data_mapping
]
input_map
=
[
input_vars
[
k
]
for
k
in
config
.
input_data_mapping
]
# check output_var_names against output_vars
# check output_var_names against output_vars
output_vars
=
[
tf
.
get_default_graph
()
.
get_tensor_by_name
(
get_op_var_name
(
n
)[
1
])
output_vars
=
get_vars_by_names
(
output_var_names
)
for
n
in
output_var_names
]
# XXX does it work? start with minimal memory, but allow growth
sess
=
tf
.
Session
(
config
=
config
.
session_config
)
sess
=
tf
.
Session
(
config
=
config
.
session_config
)
config
.
session_init
.
init
(
sess
)
config
.
session_init
.
init
(
sess
)
...
...
tensorpack/predict/concurrency.py
View file @
3b2f7df1
...
@@ -105,9 +105,9 @@ class PredictorWorkerThread(threading.Thread):
...
@@ -105,9 +105,9 @@ class PredictorWorkerThread(threading.Thread):
for
k
in
range
(
self
.
nr_input_var
):
for
k
in
range
(
self
.
nr_input_var
):
batched
[
k
]
.
append
(
inp
[
k
])
batched
[
k
]
.
append
(
inp
[
k
])
futures
.
append
(
f
)
futures
.
append
(
f
)
cnt
+=
1
except
queue
.
Empty
:
except
queue
.
Empty
:
break
break
cnt
+=
1
return
batched
,
futures
return
batched
,
futures
#self.xxx = None
#self.xxx = None
while
True
:
while
True
:
...
@@ -116,12 +116,9 @@ class PredictorWorkerThread(threading.Thread):
...
@@ -116,12 +116,9 @@ class PredictorWorkerThread(threading.Thread):
outputs
=
self
.
func
(
batched
)
outputs
=
self
.
func
(
batched
)
# debug, for speed testing
# debug, for speed testing
#if self.xxx is None:
#if self.xxx is None:
#outputs = self.func([batched])
#self.xxx = outputs = self.func([batched])
#self.xxx = outputs
#else:
#else:
#outputs = [None, None]
#outputs = [[self.xxx[0][0]] * len(batched), [self.xxx[1][0]] * len(batched)]
#outputs[0] = [self.xxx[0][0]] * len(batched)
#outputs[1] = [self.xxx[1][0]] * len(batched)
for
idx
,
f
in
enumerate
(
futures
):
for
idx
,
f
in
enumerate
(
futures
):
f
.
set_result
([
k
[
idx
]
for
k
in
outputs
])
f
.
set_result
([
k
[
idx
]
for
k
in
outputs
])
...
...
tensorpack/tfutils/argscope.py
View file @
3b2f7df1
...
@@ -13,8 +13,7 @@ __all__ = ['argscope', 'get_arg_scope']
...
@@ -13,8 +13,7 @@ __all__ = ['argscope', 'get_arg_scope']
_ArgScopeStack
=
[]
_ArgScopeStack
=
[]
@
contextmanager
@
contextmanager
def
argscope
(
layers
,
**
kwargs
):
def
argscope
(
layers
,
**
param
):
param
=
kwargs
if
not
isinstance
(
layers
,
list
):
if
not
isinstance
(
layers
,
list
):
layers
=
[
layers
]
layers
=
[
layers
]
...
@@ -35,6 +34,10 @@ def argscope(layers, **kwargs):
...
@@ -35,6 +34,10 @@ def argscope(layers, **kwargs):
del
_ArgScopeStack
[
-
1
]
del
_ArgScopeStack
[
-
1
]
def
get_arg_scope
():
def
get_arg_scope
():
""" return the current argscope
an argscope is a dict of dict:
dict[layername] = {arg: val}
"""
if
len
(
_ArgScopeStack
)
>
0
:
if
len
(
_ArgScopeStack
)
>
0
:
return
_ArgScopeStack
[
-
1
]
return
_ArgScopeStack
[
-
1
]
else
:
else
:
...
...
tensorpack/train/trainer.py
View file @
3b2f7df1
...
@@ -138,7 +138,7 @@ class QueueInputTrainer(Trainer):
...
@@ -138,7 +138,7 @@ class QueueInputTrainer(Trainer):
inputs
=
self
.
model
.
get_input_vars
()
inputs
=
self
.
model
.
get_input_vars
()
tf
.
get_variable_scope
()
.
reuse_variables
()
tf
.
get_variable_scope
()
.
reuse_variables
()
for
k
in
self
.
predict_tower
:
for
k
in
self
.
predict_tower
:
logger
.
info
(
"Building graph for predict towerp{}..."
.
format
(
k
))
logger
.
info
(
"Building graph for predict tower
p{}..."
.
format
(
k
))
with
tf
.
device
(
'/gpu:{}'
.
format
(
k
)
if
k
>=
0
else
'/cpu:0'
),
\
with
tf
.
device
(
'/gpu:{}'
.
format
(
k
)
if
k
>=
0
else
'/cpu:0'
),
\
tf
.
name_scope
(
'towerp{}'
.
format
(
k
)):
tf
.
name_scope
(
'towerp{}'
.
format
(
k
)):
self
.
model
.
build_graph
(
inputs
,
False
)
self
.
model
.
build_graph
(
inputs
,
False
)
...
...
tensorpack/utils/naming.py
View file @
3b2f7df1
...
@@ -7,6 +7,7 @@ GLOBAL_STEP_VAR_NAME = 'global_step:0'
...
@@ -7,6 +7,7 @@ GLOBAL_STEP_VAR_NAME = 'global_step:0'
# extra variables to summarize during training in a moving-average way
# extra variables to summarize during training in a moving-average way
MOVING_SUMMARY_VARS_KEY
=
'MOVING_SUMMARY_VARIABLES'
MOVING_SUMMARY_VARS_KEY
=
'MOVING_SUMMARY_VARIABLES'
INPUT_VARS_KEY
=
'INPUT_VARIABLES'
# export all upper case variables
# export all upper case variables
all_local_names
=
locals
()
.
keys
()
all_local_names
=
locals
()
.
keys
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment