Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
8dcf454d
Commit
8dcf454d
authored
Jul 10, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
add build_placeholder method in InputDesc
parent
c7de2013
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
49 additions
and
60 deletions
+49
-60
tensorpack/models/model_desc.py
tensorpack/models/model_desc.py
+48
-53
tensorpack/train/base.py
tensorpack/train/base.py
+1
-4
tensorpack/utils/naming.py
tensorpack/utils/naming.py
+0
-3
No files found.
tensorpack/models/model_desc.py
View file @
8dcf454d
...
@@ -4,40 +4,70 @@
...
@@ -4,40 +4,70 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
from
abc
import
ABCMeta
,
abstractmethod
from
abc
import
ABCMeta
,
abstractmethod
from
collections
import
namedtuple
import
tensorflow
as
tf
import
tensorflow
as
tf
import
pickle
import
pickle
import
six
import
six
from
..utils
import
logger
from
..utils
import
logger
from
..utils.naming
import
INPUTS_KEY
from
..utils.argtools
import
memoized
from
..utils.argtools
import
memoized
from
.regularize
import
regularize_cost_from_collection
from
.regularize
import
regularize_cost_from_collection
__all__
=
[
'InputDesc'
,
'InputVar'
,
'ModelDesc'
]
__all__
=
[
'InputDesc'
,
'InputVar'
,
'ModelDesc'
]
class
InputDesc
(
object
):
class
InputDesc
(
""" Store metadata about input placeholders. """
namedtuple
(
'InputDescTuple'
,
[
'type'
,
'shape'
,
'name'
])):
def
__init__
(
self
,
type
,
shape
,
name
,
sparse
=
False
):
"""
"""
Args:
Metadata about an input entry point to the graph.
type: tf type of the tensor.
This metadata can be later used to build placeholders or other types of
shape (list):
input source.
name (str):
sparse (bool): whether to use ``tf.sparse_placeholder``.
"""
"""
self
.
type
=
type
self
.
shape
=
shape
self
.
name
=
name
self
.
sparse
=
sparse
def
dumps
(
self
):
def
dumps
(
self
):
"""
Returns:
str: serialized string
"""
return
pickle
.
dumps
(
self
)
return
pickle
.
dumps
(
self
)
@
staticmethod
@
staticmethod
def
loads
(
buf
):
def
loads
(
buf
):
"""
Args:
buf (str): serialized string
Returns:
InputDesc:
"""
return
pickle
.
loads
(
buf
)
return
pickle
.
loads
(
buf
)
def
build_placeholder
(
self
,
prefix
=
''
):
"""
Build a tf.placeholder from the metadata, with an optional prefix.
Args:
prefix(str): the name of the placeholder will be ``prefix + self.name``
Returns:
tf.Tensor:
"""
with
tf
.
name_scope
(
None
):
# clear any name scope it might get called in
return
tf
.
placeholder
(
self
.
type
,
shape
=
self
.
shape
,
name
=
prefix
+
self
.
name
)
# TODO cache results from build_placeholder, and skip it in serialization
@
memoized
def
build_placeholder_reuse
(
self
):
"""
Build a tf.placeholder from the metadata, or return an old one.
Returns:
tf.Tensor:
"""
return
self
.
build_placeholder
()
class
InputVar
(
InputDesc
):
class
InputVar
(
InputDesc
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
args
,
**
kwargs
):
...
@@ -70,17 +100,12 @@ class ModelDesc(object):
...
@@ -70,17 +100,12 @@ class ModelDesc(object):
list[tf.Tensor]: the list of built placeholders.
list[tf.Tensor]: the list of built placeholders.
"""
"""
inputs
=
self
.
_get_inputs
()
inputs
=
self
.
_get_inputs
()
for
v
in
inputs
:
tf
.
add_to_collection
(
INPUTS_KEY
,
v
.
dumps
())
ret
=
[]
ret
=
[]
with
tf
.
name_scope
(
None
):
# clear any name scope it might get called in
for
v
in
inputs
:
for
v
in
inputs
:
placehdr_f
=
tf
.
placeholder
if
not
v
.
sparse
else
tf
.
sparse_placeholder
ret
.
append
(
v
.
build_placeholder
(
prefix
))
ret
.
append
(
placehdr_f
(
v
.
type
,
shape
=
v
.
shape
,
name
=
prefix
+
v
.
name
))
return
ret
return
ret
@
memoized
def
get_inputs_desc
(
self
):
def
get_inputs_desc
(
self
):
"""
"""
Returns:
Returns:
...
@@ -150,33 +175,3 @@ class ModelDesc(object):
...
@@ -150,33 +175,3 @@ class ModelDesc(object):
def
_get_gradient_processor
(
self
):
def
_get_gradient_processor
(
self
):
return
[]
return
[]
class
ModelFromMetaGraph
(
ModelDesc
):
"""
Load the exact TF graph from a saved meta_graph.
Only useful for inference.
"""
# TODO this class may not be functional anymore. don't use
def
__init__
(
self
,
filename
):
"""
Args:
filename (str): file name of the saved meta graph.
"""
tf
.
train
.
import_meta_graph
(
filename
)
all_coll
=
tf
.
get_default_graph
()
.
get_all_collection_keys
()
for
k
in
[
INPUTS_KEY
,
tf
.
GraphKeys
.
TRAINABLE_VARIABLES
,
tf
.
GraphKeys
.
GLOBAL_VARIABLES
]:
if
k
not
in
all_coll
:
logger
.
warn
(
"Collection {} not found in metagraph!"
.
format
(
k
))
def
_get_inputs
(
self
):
col
=
tf
.
get_collection
(
INPUTS_KEY
)
col
=
[
InputDesc
.
loads
(
v
)
for
v
in
col
]
return
col
def
_build_graph
(
self
,
_
,
__
):
""" Do nothing. Graph was imported already """
pass
tensorpack/train/base.py
View file @
8dcf454d
...
@@ -38,6 +38,7 @@ class Trainer(object):
...
@@ -38,6 +38,7 @@ class Trainer(object):
config (TrainConfig): the config used in this trainer.
config (TrainConfig): the config used in this trainer.
model (ModelDesc)
model (ModelDesc)
sess (tf.Session): the current session in use.
sess (tf.Session): the current session in use.
hooked_sess (tf.MonitoredSession): the session with hooks.
monitors (Monitors): the monitors. Callbacks can use it for logging.
monitors (Monitors): the monitors. Callbacks can use it for logging.
epoch_num (int): the number of epochs that have finished.
epoch_num (int): the number of epochs that have finished.
...
@@ -107,9 +108,6 @@ class Trainer(object):
...
@@ -107,9 +108,6 @@ class Trainer(object):
""" Abstract method: run one iteration. Subclass should define what is "iteration".
""" Abstract method: run one iteration. Subclass should define what is "iteration".
"""
"""
def
_trigger_epoch
(
self
):
pass
def
setup
(
self
):
def
setup
(
self
):
"""
"""
Setup the trainer and be ready for the main loop.
Setup the trainer and be ready for the main loop.
...
@@ -192,7 +190,6 @@ class Trainer(object):
...
@@ -192,7 +190,6 @@ class Trainer(object):
self
.
_epoch_num
,
self
.
global_step
,
time
.
time
()
-
start_time
))
self
.
_epoch_num
,
self
.
global_step
,
time
.
time
()
-
start_time
))
# trigger epoch outside the timing region.
# trigger epoch outside the timing region.
self
.
_trigger_epoch
()
self
.
_callbacks
.
trigger_epoch
()
self
.
_callbacks
.
trigger_epoch
()
logger
.
info
(
"Training has finished!"
)
logger
.
info
(
"Training has finished!"
)
except
(
StopTraining
,
tf
.
errors
.
OutOfRangeError
):
except
(
StopTraining
,
tf
.
errors
.
OutOfRangeError
):
...
...
tensorpack/utils/naming.py
View file @
8dcf454d
...
@@ -16,9 +16,6 @@ PREDICT_TOWER = 'towerp'
...
@@ -16,9 +16,6 @@ PREDICT_TOWER = 'towerp'
# extra variables to summarize during training in a moving-average way
# extra variables to summarize during training in a moving-average way
MOVING_SUMMARY_OPS_KEY
=
'MOVING_SUMMARY_OPS'
MOVING_SUMMARY_OPS_KEY
=
'MOVING_SUMMARY_OPS'
# metainfo for input tensors
INPUTS_KEY
=
'INPUTS_METAINFO'
SUMMARY_BACKUP_KEYS
=
[
tf
.
GraphKeys
.
SUMMARIES
,
MOVING_SUMMARY_OPS_KEY
]
SUMMARY_BACKUP_KEYS
=
[
tf
.
GraphKeys
.
SUMMARIES
,
MOVING_SUMMARY_OPS_KEY
]
TOWER_FREEZE_KEYS
=
SUMMARY_BACKUP_KEYS
+
[
tf
.
GraphKeys
.
UPDATE_OPS
]
TOWER_FREEZE_KEYS
=
SUMMARY_BACKUP_KEYS
+
[
tf
.
GraphKeys
.
UPDATE_OPS
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment