Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
bd686aab
Commit
bd686aab
authored
Feb 11, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
deprecate _get_input_vars
parent
bbaf8d12
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
47 additions
and
41 deletions
+47
-41
tensorpack/callbacks/inference_runner.py
tensorpack/callbacks/inference_runner.py
+2
-2
tensorpack/dataflow/format.py
tensorpack/dataflow/format.py
+1
-1
tensorpack/models/model_desc.py
tensorpack/models/model_desc.py
+27
-22
tensorpack/predict/base.py
tensorpack/predict/base.py
+1
-1
tensorpack/predict/config.py
tensorpack/predict/config.py
+1
-1
tensorpack/predict/multigpu.py
tensorpack/predict/multigpu.py
+1
-1
tensorpack/train/input_data.py
tensorpack/train/input_data.py
+10
-9
tensorpack/train/trainer.py
tensorpack/train/trainer.py
+2
-2
tensorpack/utils/naming.py
tensorpack/utils/naming.py
+2
-2
No files found.
tensorpack/callbacks/inference_runner.py
View file @
bd686aab
...
...
@@ -97,7 +97,7 @@ class InferenceRunner(Triggerable):
def
_find_input_tensors
(
self
):
if
self
.
input_tensors
is
None
:
input_vars
=
self
.
trainer
.
model
.
get_reuse_placehdrs
()
input_vars
=
self
.
trainer
.
model
.
get_reuse
d
_placehdrs
()
# TODO even if it works here, sparse still is unavailable
# because get_tensor_by_name doesn't work for sparse
...
...
@@ -198,7 +198,7 @@ class FeedfreeInferenceRunner(Triggerable):
self
.
_input_data
.
_setup
(
self
.
trainer
)
# only 1 prediction tower will be used for inference
self
.
_input_tensors
=
self
.
_input_data
.
get_input_tensors
()
model_placehdrs
=
self
.
trainer
.
model
.
get_reuse_placehdrs
()
model_placehdrs
=
self
.
trainer
.
model
.
get_reuse
d
_placehdrs
()
if
self
.
_input_names
is
not
None
:
raise
NotImplementedError
(
"Random code. Not tested."
)
assert
len
(
self
.
_input_names
)
==
len
(
self
.
_input_tensors
),
\
...
...
tensorpack/dataflow/format.py
View file @
bd686aab
...
...
@@ -66,7 +66,7 @@ class LMDBData(RNGDataFlow):
Args:
lmdb_path (str): a directory or a file.
shuffle (bool): shuffle the keys or not.
keys (list
of str
or str): list of str as the keys, used only when shuffle is True.
keys (list
[str]
or str): list of str as the keys, used only when shuffle is True.
It can also be a format string e.g. ``{:0>8d}`` which will be
formatted with the indices from 0 to *total_size - 1*.
...
...
tensorpack/models/model_desc.py
View file @
bd686aab
...
...
@@ -8,17 +8,17 @@ import tensorflow as tf
import
pickle
import
six
from
..utils
import
logger
,
INPUT
_VAR
S_KEY
from
..utils
import
logger
,
INPUTS_KEY
from
..tfutils.gradproc
import
CheckGradient
from
..tfutils.summary
import
add_moving_summary
from
..tfutils.tower
import
get_current_tower_context
__all__
=
[
'
ModelDesc'
,
'InputVar
'
,
'ModelFromMetaGraph'
]
__all__
=
[
'
InputDesc'
,
'InputVar'
,
'ModelDesc
'
,
'ModelFromMetaGraph'
]
# TODO "variable" is not the right name to use for input here.
# TODO "variable" is not a right name to use across this file.
class
Input
Var
(
object
):
class
Input
Desc
(
object
):
""" Store metadata about input placeholders. """
def
__init__
(
self
,
type
,
shape
,
name
,
sparse
=
False
):
"""
...
...
@@ -41,13 +41,16 @@ class InputVar(object):
return
pickle
.
loads
(
buf
)
InputVar
=
InputDesc
@
six
.
add_metaclass
(
ABCMeta
)
class
ModelDesc
(
object
):
""" Base class for a model description """
def
get_
input_va
rs
(
self
):
def
get_
reused_placehd
rs
(
self
):
"""
Create or return (if already created) raw input TF placeholder
var
s in the graph.
Create or return (if already created) raw input TF placeholders in the graph.
Returns:
list[tf.Tensor]: the list of input placeholders in the graph.
...
...
@@ -58,20 +61,21 @@ class ModelDesc(object):
self
.
reuse_input_vars
=
ret
return
ret
# alias
get_reuse_placehdrs
=
get_input_vars
def
get_input_vars
(
self
):
logger
.
warn
(
"[Deprecated] get_input_vars() was renamed to get_reused_placehdrs()!"
)
return
self
.
get_reused_placehdrs
()
def
build_placeholders
(
self
,
prefix
=
''
):
"""
For each
InputVar
, create new placeholders with optional prefix and
For each
input
, create new placeholders with optional prefix and
return them. Useful when building new towers.
Returns:
list[tf.Tensor]: the list of built placeholders.
"""
input_vars
=
self
.
_get_input
_var
s
()
input_vars
=
self
.
_get_inputs
()
for
v
in
input_vars
:
tf
.
add_to_collection
(
INPUT
_VAR
S_KEY
,
v
.
dumps
())
tf
.
add_to_collection
(
INPUTS_KEY
,
v
.
dumps
())
ret
=
[]
for
v
in
input_vars
:
placehdr_f
=
tf
.
placeholder
if
not
v
.
sparse
else
tf
.
sparse_placeholder
...
...
@@ -80,20 +84,21 @@ class ModelDesc(object):
name
=
prefix
+
v
.
name
))
return
ret
def
get_input
_var
s_desc
(
self
):
def
get_inputs_desc
(
self
):
"""
Returns:
list[:class:`Input
Var`]: list of the underlying :class:`InputVar
`.
list[:class:`Input
Desc`]: list of the underlying :class:`InputDesc
`.
"""
return
self
.
_get_input
_var
s
()
return
self
.
_get_inputs
()
def
_get_input
_vars
(
self
):
# keep backward compatibility
def
_get_input
s
(
self
):
# this is a better name than _get_input_vars
"""
:returns: a list of Input
Var
:returns: a list of Input
Desc
"""
return
self
.
_get_inputs
()
logger
.
warn
(
"[Deprecated] _get_input_vars() is renamed to _get_inputs()"
)
return
self
.
_get_input_vars
()
def
_get_input
s
(
self
):
# this is a better name than _get_input_vars
def
_get_input
_vars
(
self
):
# keep backward compatibility
raise
NotImplementedError
()
def
build_graph
(
self
,
model_inputs
):
...
...
@@ -102,7 +107,7 @@ class ModelDesc(object):
Args:
model_inputs (list[tf.Tensor]): a list of inputs, corresponding to
Input
Vars
of this model.
Input
Desc
of this model.
"""
self
.
_build_graph
(
model_inputs
)
...
...
@@ -169,14 +174,14 @@ class ModelFromMetaGraph(ModelDesc):
"""
tf
.
train
.
import_meta_graph
(
filename
)
all_coll
=
tf
.
get_default_graph
()
.
get_all_collection_keys
()
for
k
in
[
INPUT
_VAR
S_KEY
,
tf
.
GraphKeys
.
TRAINABLE_VARIABLES
,
for
k
in
[
INPUTS_KEY
,
tf
.
GraphKeys
.
TRAINABLE_VARIABLES
,
tf
.
GraphKeys
.
GLOBAL_VARIABLES
]:
assert
k
in
all_coll
,
\
"Collection {} not found in metagraph!"
.
format
(
k
)
def
_get_inputs
(
self
):
col
=
tf
.
get_collection
(
INPUT
_VAR
S_KEY
)
col
=
[
Input
Var
.
loads
(
v
)
for
v
in
col
]
col
=
tf
.
get_collection
(
INPUTS_KEY
)
col
=
[
Input
Desc
.
loads
(
v
)
for
v
in
col
]
return
col
def
_build_graph
(
self
,
_
,
__
):
...
...
tensorpack/predict/base.py
View file @
bd686aab
...
...
@@ -123,7 +123,7 @@ class OfflinePredictor(OnlinePredictor):
"""
self
.
graph
=
tf
.
Graph
()
with
self
.
graph
.
as_default
():
input_placehdrs
=
config
.
model
.
get_
input_va
rs
()
input_placehdrs
=
config
.
model
.
get_
reused_placehd
rs
()
with
TowerContext
(
''
,
False
):
config
.
model
.
build_graph
(
input_placehdrs
)
...
...
tensorpack/predict/config.py
View file @
bd686aab
...
...
@@ -47,7 +47,7 @@ class PredictConfig(object):
self
.
input_names
=
input_names
if
self
.
input_names
is
None
:
# neither options is set, assume all inputs
raw_vars
=
self
.
model
.
get_input
_var
s_desc
()
raw_vars
=
self
.
model
.
get_inputs_desc
()
self
.
input_names
=
[
k
.
name
for
k
in
raw_vars
]
self
.
output_names
=
output_names
assert_type
(
self
.
output_names
,
list
)
...
...
tensorpack/predict/multigpu.py
View file @
bd686aab
...
...
@@ -28,7 +28,7 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
with
self
.
graph
.
as_default
():
# TODO backup summary keys?
def
fn
(
_
):
config
.
model
.
build_graph
(
config
.
model
.
get_
input_va
rs
())
config
.
model
.
build_graph
(
config
.
model
.
get_
reused_placehd
rs
())
build_prediction_graph
(
fn
,
towers
)
self
.
sess
=
tf
.
Session
(
config
=
config
.
session_config
)
...
...
tensorpack/train/input_data.py
View file @
bd686aab
...
...
@@ -39,14 +39,14 @@ class FeedInput(InputData):
return
self
.
ds
.
size
()
def
_setup
(
self
,
trainer
):
self
.
input_
vars
=
trainer
.
model
.
get_input_va
rs
()
self
.
input_
placehdrs
=
trainer
.
model
.
get_reused_placehd
rs
()
rds
=
RepeatedData
(
self
.
ds
,
-
1
)
rds
.
reset_state
()
self
.
data_producer
=
rds
.
get_data
()
def
next_feed
(
self
):
data
=
next
(
self
.
data_producer
)
feed
=
dict
(
zip
(
self
.
input_
va
rs
,
data
))
feed
=
dict
(
zip
(
self
.
input_
placehd
rs
,
data
))
self
.
_last_feed
=
feed
return
feed
...
...
@@ -134,7 +134,7 @@ class QueueInput(FeedfreeInput):
return
self
.
ds
.
size
()
def
_setup
(
self
,
trainer
):
self
.
input_placehdrs
=
trainer
.
model
.
get_
input_va
rs
()
self
.
input_placehdrs
=
trainer
.
model
.
get_
reused_placehd
rs
()
assert
len
(
self
.
input_placehdrs
)
>
0
,
\
"QueueInput can only be used with input placeholders!"
if
self
.
queue
is
None
:
...
...
@@ -182,7 +182,7 @@ class BatchQueueInput(FeedfreeInput):
return
self
.
ds
.
size
()
//
self
.
batch_size
def
_setup
(
self
,
trainer
):
self
.
input_placehdrs
=
trainer
.
model
.
get_
input_va
rs
()
self
.
input_placehdrs
=
trainer
.
model
.
get_
reused_placehd
rs
()
assert
len
(
self
.
input_placehdrs
)
>
0
,
\
"QueueInput can only be used with input placeholders!"
...
...
@@ -194,7 +194,7 @@ class BatchQueueInput(FeedfreeInput):
name
=
get_op_tensor_name
(
p
.
name
)[
0
]
+
'-nobatch'
))
# dequeue_many requires fully-defined shapes
shape_err
=
"Use of BatchQueueInput requires input
variable
s to have fully-defined "
shape_err
=
"Use of BatchQueueInput requires inputs to have fully-defined "
"shapes except for the batch dimension"
shapes
=
[]
for
p
in
placehdrs_nobatch
:
...
...
@@ -226,7 +226,7 @@ class BatchQueueInput(FeedfreeInput):
class
DummyConstantInput
(
FeedfreeInput
):
""" Input some constant
variables
. Only for debugging performance issues """
""" Input some constant
tensor
. Only for debugging performance issues """
def
__init__
(
self
,
shapes
):
self
.
shapes
=
shapes
...
...
@@ -238,9 +238,10 @@ class DummyConstantInput(FeedfreeInput):
ret
=
[]
for
idx
,
p
in
enumerate
(
placehdrs
):
with
tf
.
device
(
'/gpu:0'
):
ret
.
append
(
tf
.
get_variable
(
'dummy-'
+
p
.
op
.
name
,
shape
=
self
.
shapes
[
idx
],
dtype
=
p
.
dtype
,
trainable
=
False
,
initializer
=
tf
.
constant_initializer
()))
ret
.
append
(
tf
.
get_variable
(
'dummy-'
+
p
.
op
.
name
,
shape
=
self
.
shapes
[
idx
],
dtype
=
p
.
dtype
,
trainable
=
False
,
initializer
=
tf
.
constant_initializer
()))
return
ret
...
...
tensorpack/train/trainer.py
View file @
bd686aab
...
...
@@ -47,7 +47,7 @@ class PredictorFactory(object):
freeze_collection
(
SUMMARY_BACKUP_KEYS
),
\
tf
.
variable_scope
(
tf
.
get_variable_scope
(),
reuse
=
True
):
def
fn
(
_
):
self
.
model
.
build_graph
(
self
.
model
.
get_
input_va
rs
())
self
.
model
.
build_graph
(
self
.
model
.
get_
reused_placehd
rs
())
build_prediction_graph
(
fn
,
self
.
towers
)
self
.
tower_built
=
True
...
...
@@ -79,7 +79,7 @@ class SimpleTrainer(Trainer):
def
_setup
(
self
):
self
.
_input_method
.
_setup
(
self
)
model
=
self
.
model
self
.
input_vars
=
model
.
get_
input_va
rs
()
self
.
input_vars
=
model
.
get_
reused_placehd
rs
()
with
TowerContext
(
''
,
is_training
=
True
):
model
.
build_graph
(
self
.
input_vars
)
cost_var
=
model
.
get_cost
()
...
...
tensorpack/utils/naming.py
View file @
bd686aab
...
...
@@ -19,8 +19,8 @@ PREDICT_TOWER = 'towerp'
# extra variables to summarize during training in a moving-average way
MOVING_SUMMARY_VARS_KEY
=
'MOVING_SUMMARY_VARIABLES'
#
placeholders for input variable
s
INPUT
_VARS_KEY
=
'INPUT_VARIABLES
'
#
metainfo for input tensor
s
INPUT
S_KEY
=
'INPUTS_METAINFO
'
SUMMARY_BACKUP_KEYS
=
[
tf
.
GraphKeys
.
SUMMARIES
,
MOVING_SUMMARY_VARS_KEY
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment