Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
8dcf454d
Commit
8dcf454d
authored
Jul 10, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
add build_placeholder method in InputDesc
parent
c7de2013
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
49 additions
and
60 deletions
+49
-60
tensorpack/models/model_desc.py
tensorpack/models/model_desc.py
+48
-53
tensorpack/train/base.py
tensorpack/train/base.py
+1
-4
tensorpack/utils/naming.py
tensorpack/utils/naming.py
+0
-3
No files found.
tensorpack/models/model_desc.py
View file @
8dcf454d
...
...
@@ -4,40 +4,70 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
from
abc
import
ABCMeta
,
abstractmethod
from
collections
import
namedtuple
import
tensorflow
as
tf
import
pickle
import
six
from
..utils
import
logger
from
..utils.naming
import
INPUTS_KEY
from
..utils.argtools
import
memoized
from
.regularize
import
regularize_cost_from_collection
__all__
=
[
'InputDesc'
,
'InputVar'
,
'ModelDesc'
]
class
InputDesc
(
object
):
""" Store metadata about input placeholders. """
def
__init__
(
self
,
type
,
shape
,
name
,
sparse
=
False
):
"""
Args:
type: tf type of the tensor.
shape (list):
name (str):
sparse (bool): whether to use ``tf.sparse_placeholder``.
"""
self
.
type
=
type
self
.
shape
=
shape
self
.
name
=
name
self
.
sparse
=
sparse
class
InputDesc
(
namedtuple
(
'InputDescTuple'
,
[
'type'
,
'shape'
,
'name'
])):
"""
Metadata about an input entry point to the graph.
This metadata can be later used to build placeholders or other types of
input source.
"""
def
dumps
(
self
):
"""
Returns:
str: serialized string
"""
return
pickle
.
dumps
(
self
)
@
staticmethod
def
loads
(
buf
):
"""
Args:
buf (str): serialized string
Returns:
InputDesc:
"""
return
pickle
.
loads
(
buf
)
def
build_placeholder
(
self
,
prefix
=
''
):
"""
Build a tf.placeholder from the metadata, with an optional prefix.
Args:
prefix(str): the name of the placeholder will be ``prefix + self.name``
Returns:
tf.Tensor:
"""
with
tf
.
name_scope
(
None
):
# clear any name scope it might get called in
return
tf
.
placeholder
(
self
.
type
,
shape
=
self
.
shape
,
name
=
prefix
+
self
.
name
)
# TODO cache results from build_placeholder, and skip it in serialization
@
memoized
def
build_placeholder_reuse
(
self
):
"""
Build a tf.placeholder from the metadata, or return an old one.
Returns:
tf.Tensor:
"""
return
self
.
build_placeholder
()
class
InputVar
(
InputDesc
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
...
...
@@ -70,17 +100,12 @@ class ModelDesc(object):
list[tf.Tensor]: the list of built placeholders.
"""
inputs
=
self
.
_get_inputs
()
for
v
in
inputs
:
tf
.
add_to_collection
(
INPUTS_KEY
,
v
.
dumps
())
ret
=
[]
with
tf
.
name_scope
(
None
):
# clear any name scope it might get called in
for
v
in
inputs
:
placehdr_f
=
tf
.
placeholder
if
not
v
.
sparse
else
tf
.
sparse_placeholder
ret
.
append
(
placehdr_f
(
v
.
type
,
shape
=
v
.
shape
,
name
=
prefix
+
v
.
name
))
for
v
in
inputs
:
ret
.
append
(
v
.
build_placeholder
(
prefix
))
return
ret
@
memoized
def
get_inputs_desc
(
self
):
"""
Returns:
...
...
@@ -150,33 +175,3 @@ class ModelDesc(object):
def
_get_gradient_processor
(
self
):
return
[]
class
ModelFromMetaGraph
(
ModelDesc
):
"""
Load the exact TF graph from a saved meta_graph.
Only useful for inference.
"""
# TODO this class may not be functional anymore. don't use
def
__init__
(
self
,
filename
):
"""
Args:
filename (str): file name of the saved meta graph.
"""
tf
.
train
.
import_meta_graph
(
filename
)
all_coll
=
tf
.
get_default_graph
()
.
get_all_collection_keys
()
for
k
in
[
INPUTS_KEY
,
tf
.
GraphKeys
.
TRAINABLE_VARIABLES
,
tf
.
GraphKeys
.
GLOBAL_VARIABLES
]:
if
k
not
in
all_coll
:
logger
.
warn
(
"Collection {} not found in metagraph!"
.
format
(
k
))
def
_get_inputs
(
self
):
col
=
tf
.
get_collection
(
INPUTS_KEY
)
col
=
[
InputDesc
.
loads
(
v
)
for
v
in
col
]
return
col
def
_build_graph
(
self
,
_
,
__
):
""" Do nothing. Graph was imported already """
pass
tensorpack/train/base.py
View file @
8dcf454d
...
...
@@ -38,6 +38,7 @@ class Trainer(object):
config (TrainConfig): the config used in this trainer.
model (ModelDesc)
sess (tf.Session): the current session in use.
hooked_sess (tf.MonitoredSession): the session with hooks.
monitors (Monitors): the monitors. Callbacks can use it for logging.
epoch_num (int): the number of epochs that have finished.
...
...
@@ -107,9 +108,6 @@ class Trainer(object):
""" Abstract method: run one iteration. Subclass should define what is "iteration".
"""
def
_trigger_epoch
(
self
):
pass
def
setup
(
self
):
"""
Setup the trainer and be ready for the main loop.
...
...
@@ -192,7 +190,6 @@ class Trainer(object):
self
.
_epoch_num
,
self
.
global_step
,
time
.
time
()
-
start_time
))
# trigger epoch outside the timing region.
self
.
_trigger_epoch
()
self
.
_callbacks
.
trigger_epoch
()
logger
.
info
(
"Training has finished!"
)
except
(
StopTraining
,
tf
.
errors
.
OutOfRangeError
):
...
...
tensorpack/utils/naming.py
View file @
8dcf454d
...
...
@@ -16,9 +16,6 @@ PREDICT_TOWER = 'towerp'
# extra variables to summarize during training in a moving-average way
MOVING_SUMMARY_OPS_KEY
=
'MOVING_SUMMARY_OPS'
# metainfo for input tensors
INPUTS_KEY
=
'INPUTS_METAINFO'
SUMMARY_BACKUP_KEYS
=
[
tf
.
GraphKeys
.
SUMMARIES
,
MOVING_SUMMARY_OPS_KEY
]
TOWER_FREEZE_KEYS
=
SUMMARY_BACKUP_KEYS
+
[
tf
.
GraphKeys
.
UPDATE_OPS
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment