Commit 34a5a809 authored by Yuxin Wu's avatar Yuxin Wu

dataparallelmultitower

parent 93beba57
......@@ -32,12 +32,19 @@ class ModelDesc(object):
return self.reuse_input_vars()
except KeyError:
pass
ret = self.get_placeholders()
for v in ret:
tf.add_to_collection(INPUT_VARS_KEY, v)
return ret
def get_placeholders(self, prefix=''):
""" build placeholders with optional prefix, for each InputVar"""
input_vars = self._get_input_vars()
ret = []
for v in input_vars:
ret.append(tf.placeholder(v.type, shape=v.shape, name=v.name))
for v in ret:
tf.add_to_collection(INPUT_VARS_KEY, v)
ret.append(tf.placeholder(
v.type, shape=v.shape,
name=prefix + v.name))
return ret
def reuse_input_vars(self):
......
......@@ -12,7 +12,8 @@ from ..tfutils import get_tensors_by_names, TowerContext
__all__ = ['OnlinePredictor', 'OfflinePredictor',
'AsyncPredictorBase',
'MultiTowerOfflinePredictor', 'build_multi_tower_prediction_graph']
'MultiTowerOfflinePredictor', 'build_multi_tower_prediction_graph',
'DataParallelOfflinePredictor']
class PredictorBase(object):
......@@ -128,7 +129,7 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
for k in towers:
output_vars = get_tensors_by_names(
['{}{}/'.format(self.PREFIX, k) + n \
['towerp{}/'.format(k) + n \
for n in config.output_names])
self.predictors.append(OnlinePredictor(
self.sess, input_vars, output_vars, config.return_input))
......@@ -139,3 +140,29 @@ class MultiTowerOfflinePredictor(OnlinePredictor):
def get_predictors(self, n):
return [self.predictors[k % len(self.predictors)] for k in range(n)]
class DataParallelOfflinePredictor(OnlinePredictor):
def __init__(self, config, towers):
self.graph = tf.Graph()
with self.graph.as_default():
sess = tf.Session(config=config.session_config)
input_var_names = []
for k in towers:
input_vars = config.model.get_placeholders(prefix='towerp{}-'.format(k))
logger.info(
"Building graph for predictor tower {}...".format(k))
with tf.device('/gpu:{}'.format(k) if k >= 0 else '/cpu:0'), \
TowerContext('towerp{}'.format(k)):
config.model.build_graph(input_vars)
tf.get_variable_scope().reuse_variables()
input_var_names.extend([k.name for k in input_vars])
input_vars = get_tensors_by_names(input_var_names)
config.session_init.init(sess)
output_vars = []
for k in towers:
output_vars.extend(get_tensors_by_names(
['towerp{}/'.format(k) + n \
for n in config.output_names]))
super(DataParallelOfflinePredictor, self).__init__(
sess, input_vars, output_vars, config.return_input)
......@@ -46,6 +46,7 @@ class MultiProcessPredictWorker(multiprocessing.Process):
from tensorpack.models._common import disable_layer_logging
disable_layer_logging()
self.predictor = OfflinePredictor(self.config)
import sys
if self.idx == 0:
with self.predictor.graph.as_default():
describe_model()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment