Commit 4f0e1bd5 authored by Yuxin Wu's avatar Yuxin Wu

predict tower reorganize

parent c98f2351
...@@ -71,6 +71,7 @@ class RawTIMIT(DataFlow): ...@@ -71,6 +71,7 @@ class RawTIMIT(DataFlow):
self.filelists = [k for k in fs.recursive_walk(self.dirname) self.filelists = [k for k in fs.recursive_walk(self.dirname)
if k.endswith('.wav')] if k.endswith('.wav')]
logger.info("Found {} wav files ...".format(len(self.filelists))) logger.info("Found {} wav files ...".format(len(self.filelists)))
assert len(self.filelists), self.filelists
assert label in ['phoneme', 'letter'], label assert label in ['phoneme', 'letter'], label
self.label = label self.label = label
......
...@@ -10,6 +10,7 @@ from six.moves import range ...@@ -10,6 +10,7 @@ from six.moves import range
__all__ = ['TIMITBatch'] __all__ = ['TIMITBatch']
def batch_feature(feats): def batch_feature(feats):
# pad to the longest in the batch
maxlen = max([k.shape[0] for k in feats]) maxlen = max([k.shape[0] for k in feats])
bsize = len(feats) bsize = len(feats)
ret = np.zeros((bsize, maxlen, feats[0].shape[1])) ret = np.zeros((bsize, maxlen, feats[0].shape[1]))
......
...@@ -7,6 +7,7 @@ from abc import abstractmethod, ABCMeta, abstractproperty ...@@ -7,6 +7,7 @@ from abc import abstractmethod, ABCMeta, abstractproperty
import tensorflow as tf import tensorflow as tf
import six import six
from ..utils.naming import *
from ..utils import logger from ..utils import logger
from ..tfutils import get_tensors_by_names, TowerContext from ..tfutils import get_tensors_by_names, TowerContext
...@@ -100,17 +101,17 @@ class OfflinePredictor(OnlinePredictor): ...@@ -100,17 +101,17 @@ class OfflinePredictor(OnlinePredictor):
sess, input_vars, output_vars, config.return_input) sess, input_vars, output_vars, config.return_input)
def build_multi_tower_prediction_graph(model, towers): def build_multi_tower_prediction_graph(build_tower_fn, towers):
""" """
:param build_tower_fn: the function to be called inside each tower, taking tower as the argument
:param towers: a list of gpu relative id. :param towers: a list of gpu relative id.
""" """
input_vars = model.get_input_vars()
for k in towers: for k in towers:
logger.info( logger.info(
"Building graph for predictor tower {}...".format(k)) "Building graph for predictor tower {}...".format(k))
with tf.device('/gpu:{}'.format(k) if k >= 0 else '/cpu:0'), \ with tf.device('/gpu:{}'.format(k) if k >= 0 else '/cpu:0'), \
TowerContext('towerp{}'.format(k)): TowerContext('{}{}'.format(PREDICT_TOWER, k)):
model.build_graph(input_vars) build_tower_fn(k)
tf.get_variable_scope().reuse_variables() tf.get_variable_scope().reuse_variables()
class MultiTowerOfflinePredictor(OnlinePredictor): class MultiTowerOfflinePredictor(OnlinePredictor):
...@@ -119,7 +120,8 @@ class MultiTowerOfflinePredictor(OnlinePredictor): ...@@ -119,7 +120,8 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
self.predictors = [] self.predictors = []
with self.graph.as_default(): with self.graph.as_default():
# TODO backup summary keys? # TODO backup summary keys?
build_multi_tower_prediction_graph(config.model, towers) fn = lambda _: config.model.build_graph(config.model.get_input_vars())
build_multi_tower_prediction_graph(fn, towers)
self.sess = tf.Session(config=config.session_config) self.sess = tf.Session(config=config.session_config)
config.session_init.init(self.sess) config.session_init.init(self.sess)
...@@ -128,7 +130,7 @@ class MultiTowerOfflinePredictor(OnlinePredictor): ...@@ -128,7 +130,7 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
for k in towers: for k in towers:
output_vars = get_tensors_by_names( output_vars = get_tensors_by_names(
['towerp{}/'.format(k) + n \ ['{}{}/'.format(PREDICT_TOWER, k) + n \
for n in config.output_names]) for n in config.output_names])
self.predictors.append(OnlinePredictor( self.predictors.append(OnlinePredictor(
self.sess, input_vars, output_vars, config.return_input)) self.sess, input_vars, output_vars, config.return_input))
...@@ -146,22 +148,22 @@ class DataParallelOfflinePredictor(OnlinePredictor): ...@@ -146,22 +148,22 @@ class DataParallelOfflinePredictor(OnlinePredictor):
with self.graph.as_default(): with self.graph.as_default():
sess = tf.Session(config=config.session_config) sess = tf.Session(config=config.session_config)
input_var_names = [] input_var_names = []
output_vars = []
for k in towers: for k in towers:
input_vars = config.model.get_placeholders(prefix='towerp{}-'.format(k)) towername = PREDICT_TOWER + str(k)
input_vars = config.model.get_placeholders(prefix=towername + '-')
logger.info( logger.info(
"Building graph for predictor tower {}...".format(k)) "Building graph for predictor tower {}...".format(k))
with tf.device('/gpu:{}'.format(k) if k >= 0 else '/cpu:0'), \ with tf.device('/gpu:{}'.format(k) if k >= 0 else '/cpu:0'), \
TowerContext('towerp{}'.format(k)): TowerContext(towername, is_training=False):
config.model.build_graph(input_vars) config.model.build_graph(input_vars)
tf.get_variable_scope().reuse_variables() tf.get_variable_scope().reuse_variables()
input_var_names.extend([k.name for k in input_vars]) input_var_names.extend([k.name for k in input_vars])
output_vars.extend(get_tensors_by_names(
[towername + '/' + n \
for n in config.output_names]))
input_vars = get_tensors_by_names(input_var_names) input_vars = get_tensors_by_names(input_var_names)
config.session_init.init(sess) config.session_init.init(sess)
output_vars = []
for k in towers:
output_vars.extend(get_tensors_by_names(
['towerp{}/'.format(k) + n \
for n in config.output_names]))
super(DataParallelOfflinePredictor, self).__init__( super(DataParallelOfflinePredictor, self).__init__(
sess, input_vars, output_vars, config.return_input) sess, input_vars, output_vars, config.return_input)
...@@ -104,8 +104,8 @@ class SaverRestore(SessionInit): ...@@ -104,8 +104,8 @@ class SaverRestore(SessionInit):
reader = tf.train.NewCheckpointReader(model_path) reader = tf.train.NewCheckpointReader(model_path)
ckpt_vars = reader.get_variable_to_shape_map().keys() ckpt_vars = reader.get_variable_to_shape_map().keys()
for v in ckpt_vars: for v in ckpt_vars:
if v.startswith('towerp'): if v.startswith(PREDICT_TOWER):
logger.warn("Found {} in checkpoint. Anything from prediction tower shouldn't be saved.".format(v.name)) logger.error("Found {} in checkpoint. But anything from prediction tower shouldn't be saved.".format(v.name))
return set(ckpt_vars) return set(ckpt_vars)
def _get_vars_to_restore_multimap(self, vars_available): def _get_vars_to_restore_multimap(self, vars_available):
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
import tensorflow as tf import tensorflow as tf
import re import re
from ..utils.naming import *
__all__ = ['get_current_tower_context', 'TowerContext'] __all__ = ['get_current_tower_context', 'TowerContext']
...@@ -15,7 +16,7 @@ class TowerContext(object): ...@@ -15,7 +16,7 @@ class TowerContext(object):
""" tower_name: 'tower0', 'towerp0', or '' """ """ tower_name: 'tower0', 'towerp0', or '' """
self._name = tower_name self._name = tower_name
if is_training is None: if is_training is None:
is_training = not self._name.startswith('towerp') is_training = not self._name.startswith(PREDICT_TOWER)
self._is_training = is_training self._is_training = is_training
@property @property
...@@ -52,12 +53,13 @@ class TowerContext(object): ...@@ -52,12 +53,13 @@ class TowerContext(object):
def find_tensor_in_main_tower(self, graph, name): def find_tensor_in_main_tower(self, graph, name):
if self.is_main_tower: if self.is_main_tower:
return graph.get_tensor_by_name(name) return graph.get_tensor_by_name(name)
if name.startswith('towerp'): if name.startswith(PREDICT_TOWER):
newname = re.sub('towerp[0-9]+/', '', name) predict_tower_prefix = '{}[0-9]+/'.format(PREDICT_TOWER)
newname = re.sub(predict_tower_prefix, '', name)
try: try:
return graph.get_tensor_by_name(newname) return graph.get_tensor_by_name(newname)
except KeyError: except KeyError:
newname = re.sub('towerp[0-9]+/', 'tower0/', name) newname = re.sub(predict_tower_prefix, 'tower0/', name)
return graph.get_tensor_by_name(newname) return graph.get_tensor_by_name(newname)
def __enter__(self): def __enter__(self):
......
...@@ -25,8 +25,8 @@ def get_savename_from_varname( ...@@ -25,8 +25,8 @@ def get_savename_from_varname(
:returns: the name used to save the variable :returns: the name used to save the variable
""" """
name = varname name = varname
if 'towerp/' in name: if PREDICT_TOWER in name:
logger.error("No variable should be under 'towerp' name scope".format(v.name)) logger.error("No variable under '{}' name scope should be saved!".format(PREDICT_TOWER))
# don't overwrite anything in the current prediction graph # don't overwrite anything in the current prediction graph
return None return None
if 'tower' in name: if 'tower' in name:
......
...@@ -8,7 +8,7 @@ from six.moves import zip ...@@ -8,7 +8,7 @@ from six.moves import zip
from .base import Trainer from .base import Trainer
from ..utils import logger, SUMMARY_BACKUP_KEYS from ..utils import logger, SUMMARY_BACKUP_KEYS, PREDICT_TOWER
from ..tfutils import (get_tensors_by_names, freeze_collection, from ..tfutils import (get_tensors_by_names, freeze_collection,
get_global_step_var, TowerContext) get_global_step_var, TowerContext)
from ..tfutils.summary import summary_moving_average, add_moving_summary from ..tfutils.summary import summary_moving_average, add_moving_summary
...@@ -39,16 +39,17 @@ class PredictorFactory(object): ...@@ -39,16 +39,17 @@ class PredictorFactory(object):
self._build_predict_tower() self._build_predict_tower()
tower = self.towers[tower % len(self.towers)] tower = self.towers[tower % len(self.towers)]
raw_input_vars = get_tensors_by_names(input_names) raw_input_vars = get_tensors_by_names(input_names)
output_names = ['towerp{}/'.format(tower) + n for n in output_names] output_names = ['{}{}/'.format(PREDICT_TOWER, tower) + n for n in output_names]
output_vars = get_tensors_by_names(output_names) output_vars = get_tensors_by_names(output_names)
return OnlinePredictor(self.sess, raw_input_vars, output_vars) return OnlinePredictor(self.sess, raw_input_vars, output_vars)
def _build_predict_tower(self): def _build_predict_tower(self):
tf.get_variable_scope().reuse_variables() tf.get_variable_scope().reuse_variables()
# build_predict_tower might get called anywhere, but 'towerp' should be the outermost name scope # build_predict_tower might get called anywhere, but 'PREDICT_TOWER' should be the outermost name scope
with tf.name_scope(None), \ with tf.name_scope(None), \
freeze_collection(SUMMARY_BACKUP_KEYS): freeze_collection(SUMMARY_BACKUP_KEYS):
build_multi_tower_prediction_graph(self.model, self.towers) fn = lambda _: self.model.build_graph(self.model.get_input_vars())
build_multi_tower_prediction_graph(fn, self.towers)
self.tower_built = True self.tower_built = True
class SimpleTrainer(Trainer): class SimpleTrainer(Trainer):
......
...@@ -21,9 +21,9 @@ def enable_call_trace(): ...@@ -21,9 +21,9 @@ def enable_call_trace():
if caller: if caller:
caller_line_no = caller.f_lineno caller_line_no = caller.f_lineno
caller_filename = caller.f_code.co_filename caller_filename = caller.f_code.co_filename
print 'Call to `%s` on line %s:%s from %s:%s' % \ print('Call to `%s` on line %s:%s from %s:%s' % \
(func_name, func_filename, func_line_no, (func_name, func_filename, func_line_no,
caller_filename, caller_line_no) caller_filename, caller_line_no))
return return
sys.settrace(tracer) sys.settrace(tracer)
...@@ -31,9 +31,9 @@ if __name__ == '__main__': ...@@ -31,9 +31,9 @@ if __name__ == '__main__':
enable_call_trace() enable_call_trace()
def b(a): def b(a):
print 2 print(2)
def a(): def a():
print 1 print(1)
b(1) b(1)
a() a()
...@@ -5,6 +5,9 @@ ...@@ -5,6 +5,9 @@
GLOBAL_STEP_OP_NAME = 'global_step' GLOBAL_STEP_OP_NAME = 'global_step'
GLOBAL_STEP_VAR_NAME = 'global_step:0' GLOBAL_STEP_VAR_NAME = 'global_step:0'
# prefix of predict tower
PREDICT_TOWER = 'towerp'
# extra variables to summarize during training in a moving-average way # extra variables to summarize during training in a moving-average way
MOVING_SUMMARY_VARS_KEY = 'MOVING_SUMMARY_VARIABLES' MOVING_SUMMARY_VARS_KEY = 'MOVING_SUMMARY_VARIABLES'
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment