Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
4f0e1bd5
Commit
4f0e1bd5
authored
Dec 15, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
predict tower reorganize
parent
c98f2351
Changes
9
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
39 additions
and
29 deletions
+39
-29
examples/TIMIT/create-lmdb.py
examples/TIMIT/create-lmdb.py
+1
-0
examples/TIMIT/timitdata.py
examples/TIMIT/timitdata.py
+1
-0
tensorpack/predict/base.py
tensorpack/predict/base.py
+15
-13
tensorpack/tfutils/sessinit.py
tensorpack/tfutils/sessinit.py
+2
-2
tensorpack/tfutils/tower.py
tensorpack/tfutils/tower.py
+6
-4
tensorpack/tfutils/varmanip.py
tensorpack/tfutils/varmanip.py
+2
-2
tensorpack/train/trainer.py
tensorpack/train/trainer.py
+5
-4
tensorpack/utils/debug.py
tensorpack/utils/debug.py
+4
-4
tensorpack/utils/naming.py
tensorpack/utils/naming.py
+3
-0
No files found.
examples/TIMIT/create-lmdb.py
View file @
4f0e1bd5
...
...
@@ -71,6 +71,7 @@ class RawTIMIT(DataFlow):
self
.
filelists
=
[
k
for
k
in
fs
.
recursive_walk
(
self
.
dirname
)
if
k
.
endswith
(
'.wav'
)]
logger
.
info
(
"Found {} wav files ..."
.
format
(
len
(
self
.
filelists
)))
assert
len
(
self
.
filelists
),
self
.
filelists
assert
label
in
[
'phoneme'
,
'letter'
],
label
self
.
label
=
label
...
...
examples/TIMIT/timitdata.py
View file @
4f0e1bd5
...
...
@@ -10,6 +10,7 @@ from six.moves import range
__all__
=
[
'TIMITBatch'
]
def
batch_feature
(
feats
):
# pad to the longest in the batch
maxlen
=
max
([
k
.
shape
[
0
]
for
k
in
feats
])
bsize
=
len
(
feats
)
ret
=
np
.
zeros
((
bsize
,
maxlen
,
feats
[
0
]
.
shape
[
1
]))
...
...
tensorpack/predict/base.py
View file @
4f0e1bd5
...
...
@@ -7,6 +7,7 @@ from abc import abstractmethod, ABCMeta, abstractproperty
import
tensorflow
as
tf
import
six
from
..utils.naming
import
*
from
..utils
import
logger
from
..tfutils
import
get_tensors_by_names
,
TowerContext
...
...
@@ -100,17 +101,17 @@ class OfflinePredictor(OnlinePredictor):
sess
,
input_vars
,
output_vars
,
config
.
return_input
)
def
build_multi_tower_prediction_graph
(
model
,
towers
):
def
build_multi_tower_prediction_graph
(
build_tower_fn
,
towers
):
"""
:param build_tower_fn: the function to be called inside each tower, taking tower as the argument
:param towers: a list of gpu relative id.
"""
input_vars
=
model
.
get_input_vars
()
for
k
in
towers
:
logger
.
info
(
"Building graph for predictor tower {}..."
.
format
(
k
))
with
tf
.
device
(
'/gpu:{}'
.
format
(
k
)
if
k
>=
0
else
'/cpu:0'
),
\
TowerContext
(
'
towerp{}'
.
format
(
k
)):
model
.
build_graph
(
input_vars
)
TowerContext
(
'
{}{}'
.
format
(
PREDICT_TOWER
,
k
)):
build_tower_fn
(
k
)
tf
.
get_variable_scope
()
.
reuse_variables
()
class
MultiTowerOfflinePredictor
(
OnlinePredictor
):
...
...
@@ -119,7 +120,8 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
self
.
predictors
=
[]
with
self
.
graph
.
as_default
():
# TODO backup summary keys?
build_multi_tower_prediction_graph
(
config
.
model
,
towers
)
fn
=
lambda
_
:
config
.
model
.
build_graph
(
config
.
model
.
get_input_vars
())
build_multi_tower_prediction_graph
(
fn
,
towers
)
self
.
sess
=
tf
.
Session
(
config
=
config
.
session_config
)
config
.
session_init
.
init
(
self
.
sess
)
...
...
@@ -128,7 +130,7 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
for
k
in
towers
:
output_vars
=
get_tensors_by_names
(
[
'
towerp{}/'
.
format
(
k
)
+
n
\
[
'
{}{}/'
.
format
(
PREDICT_TOWER
,
k
)
+
n
\
for
n
in
config
.
output_names
])
self
.
predictors
.
append
(
OnlinePredictor
(
self
.
sess
,
input_vars
,
output_vars
,
config
.
return_input
))
...
...
@@ -146,22 +148,22 @@ class DataParallelOfflinePredictor(OnlinePredictor):
with
self
.
graph
.
as_default
():
sess
=
tf
.
Session
(
config
=
config
.
session_config
)
input_var_names
=
[]
output_vars
=
[]
for
k
in
towers
:
input_vars
=
config
.
model
.
get_placeholders
(
prefix
=
'towerp{}-'
.
format
(
k
))
towername
=
PREDICT_TOWER
+
str
(
k
)
input_vars
=
config
.
model
.
get_placeholders
(
prefix
=
towername
+
'-'
)
logger
.
info
(
"Building graph for predictor tower {}..."
.
format
(
k
))
with
tf
.
device
(
'/gpu:{}'
.
format
(
k
)
if
k
>=
0
else
'/cpu:0'
),
\
TowerContext
(
'towerp{}'
.
format
(
k
)
):
TowerContext
(
towername
,
is_training
=
False
):
config
.
model
.
build_graph
(
input_vars
)
tf
.
get_variable_scope
()
.
reuse_variables
()
input_var_names
.
extend
([
k
.
name
for
k
in
input_vars
])
output_vars
.
extend
(
get_tensors_by_names
(
[
towername
+
'/'
+
n
\
for
n
in
config
.
output_names
]))
input_vars
=
get_tensors_by_names
(
input_var_names
)
config
.
session_init
.
init
(
sess
)
output_vars
=
[]
for
k
in
towers
:
output_vars
.
extend
(
get_tensors_by_names
(
[
'towerp{}/'
.
format
(
k
)
+
n
\
for
n
in
config
.
output_names
]))
super
(
DataParallelOfflinePredictor
,
self
)
.
__init__
(
sess
,
input_vars
,
output_vars
,
config
.
return_input
)
tensorpack/tfutils/sessinit.py
View file @
4f0e1bd5
...
...
@@ -104,8 +104,8 @@ class SaverRestore(SessionInit):
reader
=
tf
.
train
.
NewCheckpointReader
(
model_path
)
ckpt_vars
=
reader
.
get_variable_to_shape_map
()
.
keys
()
for
v
in
ckpt_vars
:
if
v
.
startswith
(
'towerp'
):
logger
.
warn
(
"Found {} in checkpoint. A
nything from prediction tower shouldn't be saved."
.
format
(
v
.
name
))
if
v
.
startswith
(
PREDICT_TOWER
):
logger
.
error
(
"Found {} in checkpoint. But a
nything from prediction tower shouldn't be saved."
.
format
(
v
.
name
))
return
set
(
ckpt_vars
)
def
_get_vars_to_restore_multimap
(
self
,
vars_available
):
...
...
tensorpack/tfutils/tower.py
View file @
4f0e1bd5
...
...
@@ -5,6 +5,7 @@
import
tensorflow
as
tf
import
re
from
..utils.naming
import
*
__all__
=
[
'get_current_tower_context'
,
'TowerContext'
]
...
...
@@ -15,7 +16,7 @@ class TowerContext(object):
""" tower_name: 'tower0', 'towerp0', or '' """
self
.
_name
=
tower_name
if
is_training
is
None
:
is_training
=
not
self
.
_name
.
startswith
(
'towerp'
)
is_training
=
not
self
.
_name
.
startswith
(
PREDICT_TOWER
)
self
.
_is_training
=
is_training
@
property
...
...
@@ -52,12 +53,13 @@ class TowerContext(object):
def
find_tensor_in_main_tower
(
self
,
graph
,
name
):
if
self
.
is_main_tower
:
return
graph
.
get_tensor_by_name
(
name
)
if
name
.
startswith
(
'towerp'
):
newname
=
re
.
sub
(
'towerp[0-9]+/'
,
''
,
name
)
if
name
.
startswith
(
PREDICT_TOWER
):
predict_tower_prefix
=
'{}[0-9]+/'
.
format
(
PREDICT_TOWER
)
newname
=
re
.
sub
(
predict_tower_prefix
,
''
,
name
)
try
:
return
graph
.
get_tensor_by_name
(
newname
)
except
KeyError
:
newname
=
re
.
sub
(
'towerp[0-9]+/'
,
'tower0/'
,
name
)
newname
=
re
.
sub
(
predict_tower_prefix
,
'tower0/'
,
name
)
return
graph
.
get_tensor_by_name
(
newname
)
def
__enter__
(
self
):
...
...
tensorpack/tfutils/varmanip.py
View file @
4f0e1bd5
...
...
@@ -25,8 +25,8 @@ def get_savename_from_varname(
:returns: the name used to save the variable
"""
name
=
varname
if
'towerp/'
in
name
:
logger
.
error
(
"No variable
should be under 'towerp' name scope"
.
format
(
v
.
name
))
if
PREDICT_TOWER
in
name
:
logger
.
error
(
"No variable
under '{}' name scope should be saved!"
.
format
(
PREDICT_TOWER
))
# don't overwrite anything in the current prediction graph
return
None
if
'tower'
in
name
:
...
...
tensorpack/train/trainer.py
View file @
4f0e1bd5
...
...
@@ -8,7 +8,7 @@ from six.moves import zip
from
.base
import
Trainer
from
..utils
import
logger
,
SUMMARY_BACKUP_KEYS
from
..utils
import
logger
,
SUMMARY_BACKUP_KEYS
,
PREDICT_TOWER
from
..tfutils
import
(
get_tensors_by_names
,
freeze_collection
,
get_global_step_var
,
TowerContext
)
from
..tfutils.summary
import
summary_moving_average
,
add_moving_summary
...
...
@@ -39,16 +39,17 @@ class PredictorFactory(object):
self
.
_build_predict_tower
()
tower
=
self
.
towers
[
tower
%
len
(
self
.
towers
)]
raw_input_vars
=
get_tensors_by_names
(
input_names
)
output_names
=
[
'
towerp{}/'
.
format
(
tower
)
+
n
for
n
in
output_names
]
output_names
=
[
'
{}{}/'
.
format
(
PREDICT_TOWER
,
tower
)
+
n
for
n
in
output_names
]
output_vars
=
get_tensors_by_names
(
output_names
)
return
OnlinePredictor
(
self
.
sess
,
raw_input_vars
,
output_vars
)
def
_build_predict_tower
(
self
):
tf
.
get_variable_scope
()
.
reuse_variables
()
# build_predict_tower might get called anywhere, but '
towerp
' should be the outermost name scope
# build_predict_tower might get called anywhere, but '
PREDICT_TOWER
' should be the outermost name scope
with
tf
.
name_scope
(
None
),
\
freeze_collection
(
SUMMARY_BACKUP_KEYS
):
build_multi_tower_prediction_graph
(
self
.
model
,
self
.
towers
)
fn
=
lambda
_
:
self
.
model
.
build_graph
(
self
.
model
.
get_input_vars
())
build_multi_tower_prediction_graph
(
fn
,
self
.
towers
)
self
.
tower_built
=
True
class
SimpleTrainer
(
Trainer
):
...
...
tensorpack/utils/debug.py
View file @
4f0e1bd5
...
...
@@ -21,9 +21,9 @@ def enable_call_trace():
if
caller
:
caller_line_no
=
caller
.
f_lineno
caller_filename
=
caller
.
f_code
.
co_filename
print
'Call to `
%
s` on line
%
s:
%
s from
%
s:
%
s'
%
\
print
(
'Call to `
%
s` on line
%
s:
%
s from
%
s:
%
s'
%
\
(
func_name
,
func_filename
,
func_line_no
,
caller_filename
,
caller_line_no
)
caller_filename
,
caller_line_no
)
)
return
sys
.
settrace
(
tracer
)
...
...
@@ -31,9 +31,9 @@ if __name__ == '__main__':
enable_call_trace
()
def
b
(
a
):
print
2
print
(
2
)
def
a
():
print
1
print
(
1
)
b
(
1
)
a
()
tensorpack/utils/naming.py
View file @
4f0e1bd5
...
...
@@ -5,6 +5,9 @@
GLOBAL_STEP_OP_NAME
=
'global_step'
GLOBAL_STEP_VAR_NAME
=
'global_step:0'
# prefix of predict tower
PREDICT_TOWER
=
'towerp'
# extra variables to summarize during training in a moving-average way
MOVING_SUMMARY_VARS_KEY
=
'MOVING_SUMMARY_VARIABLES'
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment