Commit 34a5a809 authored by Yuxin Wu

dataparallelmultitower

parent 93beba57
...@@ -32,12 +32,19 @@ class ModelDesc(object): ...@@ -32,12 +32,19 @@ class ModelDesc(object):
return self.reuse_input_vars() return self.reuse_input_vars()
except KeyError: except KeyError:
pass pass
ret = self.get_placeholders()
for v in ret:
tf.add_to_collection(INPUT_VARS_KEY, v)
return ret
def get_placeholders(self, prefix=''):
    """Create one placeholder for every InputVar of this model.

    Args:
        prefix: string prepended to each InputVar's name to form the
            placeholder name. Defaults to no prefix.

    Returns:
        A list of placeholders, in the order returned by
        ``self._get_input_vars()``.
    """
    return [
        tf.placeholder(var.type, shape=var.shape, name=prefix + var.name)
        for var in self._get_input_vars()
    ]
def reuse_input_vars(self): def reuse_input_vars(self):
......
...@@ -12,7 +12,8 @@ from ..tfutils import get_tensors_by_names, TowerContext ...@@ -12,7 +12,8 @@ from ..tfutils import get_tensors_by_names, TowerContext
__all__ = ['OnlinePredictor', 'OfflinePredictor', __all__ = ['OnlinePredictor', 'OfflinePredictor',
'AsyncPredictorBase', 'AsyncPredictorBase',
'MultiTowerOfflinePredictor', 'build_multi_tower_prediction_graph'] 'MultiTowerOfflinePredictor', 'build_multi_tower_prediction_graph',
'DataParallelOfflinePredictor']
class PredictorBase(object): class PredictorBase(object):
...@@ -128,7 +129,7 @@ class MultiTowerOfflinePredictor(OnlinePredictor): ...@@ -128,7 +129,7 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
for k in towers: for k in towers:
output_vars = get_tensors_by_names( output_vars = get_tensors_by_names(
['{}{}/'.format(self.PREFIX, k) + n \ ['towerp{}/'.format(k) + n \
for n in config.output_names]) for n in config.output_names])
self.predictors.append(OnlinePredictor( self.predictors.append(OnlinePredictor(
self.sess, input_vars, output_vars, config.return_input)) self.sess, input_vars, output_vars, config.return_input))
...@@ -139,3 +140,29 @@ class MultiTowerOfflinePredictor(OnlinePredictor): ...@@ -139,3 +140,29 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
def get_predictors(self, n):
    """Return a list of ``n`` predictors, cycling through ``self.predictors``.

    The k-th entry is ``self.predictors[k % len(self.predictors)]``, so the
    available predictors are reused in order when ``n`` exceeds their count.
    """
    pool = self.predictors
    size = len(pool)
    return [pool[i % size] for i in range(n)]
class DataParallelOfflinePredictor(OnlinePredictor):
    """Offline predictor that builds the model once per tower in a new graph.

    Every tower gets its own set of placeholders, named with a
    ``towerp<k>-`` prefix. The placeholders of all towers and the outputs of
    all towers (both in the order of ``towers``) are handed to
    ``OnlinePredictor`` as the input and output tensors.
    """

    def __init__(self, config, towers):
        """
        Args:
            config: prediction config. This constructor reads its
                ``session_config``, ``model``, ``session_init``,
                ``output_names`` and ``return_input`` attributes.
            towers: iterable of tower indices. A non-negative index ``k``
                places the tower on ``/gpu:k``; a negative one on ``/cpu:0``.
        """
        self.graph = tf.Graph()
        with self.graph.as_default():
            sess = tf.Session(config=config.session_config)

            # Build one model replica per tower, remembering the names of
            # the placeholders that feed it.
            placeholder_names = []
            for tower in towers:
                placeholders = config.model.get_placeholders(
                    prefix='towerp{}-'.format(tower))
                logger.info(
                    "Building graph for predictor tower {}...".format(tower))
                device = '/gpu:{}'.format(tower) if tower >= 0 else '/cpu:0'
                with tf.device(device), TowerContext('towerp{}'.format(tower)):
                    config.model.build_graph(placeholders)
                    # Towers built after this one reuse the existing
                    # variables instead of creating new ones.
                    # NOTE(review): nesting reconstructed from an
                    # unindented source -- confirm this call belongs
                    # inside the tower context.
                    tf.get_variable_scope().reuse_variables()
                for placeholder in placeholders:
                    placeholder_names.append(placeholder.name)

            all_inputs = get_tensors_by_names(placeholder_names)
            config.session_init.init(sess)

            # Outputs live under each tower's 'towerp<k>/' name prefix.
            all_outputs = []
            for tower in towers:
                tower_output_names = [
                    'towerp{}/'.format(tower) + name
                    for name in config.output_names
                ]
                all_outputs.extend(get_tensors_by_names(tower_output_names))

            super(DataParallelOfflinePredictor, self).__init__(
                sess, all_inputs, all_outputs, config.return_input)
...@@ -46,6 +46,7 @@ class MultiProcessPredictWorker(multiprocessing.Process): ...@@ -46,6 +46,7 @@ class MultiProcessPredictWorker(multiprocessing.Process):
from tensorpack.models._common import disable_layer_logging from tensorpack.models._common import disable_layer_logging
disable_layer_logging() disable_layer_logging()
self.predictor = OfflinePredictor(self.config) self.predictor = OfflinePredictor(self.config)
import sys
if self.idx == 0: if self.idx == 0:
with self.predictor.graph.as_default(): with self.predictor.graph.as_default():
describe_model() describe_model()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment