Commit 84790b78 authored by Yuxin Wu's avatar Yuxin Wu

re-organize predict/; fix TF incompatibile change of sparse_softmax_cross_entropy_loss

parent fdc90767
...@@ -47,7 +47,7 @@ class Model(mnist_example.Model): ...@@ -47,7 +47,7 @@ class Model(mnist_example.Model):
wrong = symbolic_functions.prediction_incorrect(logits, label) wrong = symbolic_functions.prediction_incorrect(logits, label)
add_moving_summary(tf.reduce_mean(wrong, name='train_error')) add_moving_summary(tf.reduce_mean(wrong, name='train_error'))
cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, label) cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label)
cost = tf.reduce_mean(cost, name='cross_entropy_loss') cost = tf.reduce_mean(cost, name='cross_entropy_loss')
wd_cost = tf.multiply(1e-5, regularize_cost('fc.*/W', tf.nn.l2_loss), wd_cost = tf.multiply(1e-5, regularize_cost('fc.*/W', tf.nn.l2_loss),
name='regularize_loss') name='regularize_loss')
......
...@@ -146,7 +146,7 @@ class Model(ModelDesc): ...@@ -146,7 +146,7 @@ class Model(ModelDesc):
prob = tf.nn.softmax(logits, name='output') prob = tf.nn.softmax(logits, name='output')
cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, label) cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label)
cost = tf.reduce_mean(cost, name='cross_entropy_loss') cost = tf.reduce_mean(cost, name='cross_entropy_loss')
wrong = prediction_incorrect(logits, label, 1, name='wrong-top1') wrong = prediction_incorrect(logits, label, 1, name='wrong-top1')
......
...@@ -117,7 +117,7 @@ class Model(ModelDesc): ...@@ -117,7 +117,7 @@ class Model(ModelDesc):
# monitor training error # monitor training error
add_moving_summary(tf.reduce_mean(wrong, name='train_error')) add_moving_summary(tf.reduce_mean(wrong, name='train_error'))
cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, label) cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label)
cost = tf.reduce_mean(cost, name='cross_entropy_loss') cost = tf.reduce_mean(cost, name='cross_entropy_loss')
# weight decay on all W of fc layers # weight decay on all W of fc layers
wd_cost = regularize_cost('fc.*/W', l2_regularizer(1e-7)) wd_cost = regularize_cost('fc.*/W', l2_regularizer(1e-7))
......
...@@ -75,7 +75,7 @@ class Model(ModelDesc): ...@@ -75,7 +75,7 @@ class Model(ModelDesc):
br1 = Conv2D('loss1conv', l, 128, 1) br1 = Conv2D('loss1conv', l, 128, 1)
br1 = FullyConnected('loss1fc', br1, 1024, nl=tf.nn.relu) br1 = FullyConnected('loss1fc', br1, 1024, nl=tf.nn.relu)
br1 = FullyConnected('loss1logit', br1, 1000, nl=tf.identity) br1 = FullyConnected('loss1logit', br1, 1000, nl=tf.identity)
loss1 = tf.nn.sparse_softmax_cross_entropy_with_logits(br1, label) loss1 = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=br1, labels=label)
loss1 = tf.reduce_mean(loss1, name='loss1') loss1 = tf.reduce_mean(loss1, name='loss1')
# 14 # 14
...@@ -88,7 +88,7 @@ class Model(ModelDesc): ...@@ -88,7 +88,7 @@ class Model(ModelDesc):
br2 = Conv2D('loss2conv', l, 128, 1) br2 = Conv2D('loss2conv', l, 128, 1)
br2 = FullyConnected('loss2fc', br2, 1024, nl=tf.nn.relu) br2 = FullyConnected('loss2fc', br2, 1024, nl=tf.nn.relu)
br2 = FullyConnected('loss2logit', br2, 1000, nl=tf.identity) br2 = FullyConnected('loss2logit', br2, 1000, nl=tf.identity)
loss2 = tf.nn.sparse_softmax_cross_entropy_with_logits(br2, label) loss2 = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=br2, labels=label)
loss2 = tf.reduce_mean(loss2, name='loss2') loss2 = tf.reduce_mean(loss2, name='loss2')
# 7 # 7
...@@ -98,7 +98,7 @@ class Model(ModelDesc): ...@@ -98,7 +98,7 @@ class Model(ModelDesc):
logits = FullyConnected('linear', l, out_dim=1000, nl=tf.identity) logits = FullyConnected('linear', l, out_dim=1000, nl=tf.identity)
prob = tf.nn.softmax(logits, name='output') prob = tf.nn.softmax(logits, name='output')
loss3 = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, label) loss3 = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label)
loss3 = tf.reduce_mean(loss3, name='loss3') loss3 = tf.reduce_mean(loss3, name='loss3')
cost = tf.add_n([loss3, 0.3 * loss2, 0.3 * loss1], name='weighted_cost') cost = tf.add_n([loss3, 0.3 * loss2, 0.3 * loss1], name='weighted_cost')
......
...@@ -177,10 +177,10 @@ class Model(ModelDesc): ...@@ -177,10 +177,10 @@ class Model(ModelDesc):
l = Dropout('drop', l, 0.8) l = Dropout('drop', l, 0.8)
logits = FullyConnected('linear', l, out_dim=1000, nl=tf.identity) logits = FullyConnected('linear', l, out_dim=1000, nl=tf.identity)
loss1 = tf.nn.sparse_softmax_cross_entropy_with_logits(br1, label) loss1 = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=br1, labels=label)
loss1 = tf.reduce_mean(loss1, name='loss1') loss1 = tf.reduce_mean(loss1, name='loss1')
loss2 = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, label) loss2 = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label)
loss2 = tf.reduce_mean(loss2, name='loss2') loss2 = tf.reduce_mean(loss2, name='loss2')
wrong = prediction_incorrect(logits, label, 1, name='wrong-top1') wrong = prediction_incorrect(logits, label, 1, name='wrong-top1')
......
...@@ -90,7 +90,7 @@ class Model(ModelDesc): ...@@ -90,7 +90,7 @@ class Model(ModelDesc):
logits = FullyConnected('linear', l, out_dim=10, nl=tf.identity) logits = FullyConnected('linear', l, out_dim=10, nl=tf.identity)
prob = tf.nn.softmax(logits, name='output') prob = tf.nn.softmax(logits, name='output')
cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, label) cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label)
cost = tf.reduce_mean(cost, name='cross_entropy_loss') cost = tf.reduce_mean(cost, name='cross_entropy_loss')
wrong = prediction_incorrect(logits, label) wrong = prediction_incorrect(logits, label)
......
...@@ -102,7 +102,7 @@ class Model(ModelDesc): ...@@ -102,7 +102,7 @@ class Model(ModelDesc):
.GlobalAvgPooling('gap') .GlobalAvgPooling('gap')
.FullyConnected('linear', 1000, nl=tf.identity)()) .FullyConnected('linear', 1000, nl=tf.identity)())
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, label) loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label)
loss = tf.reduce_mean(loss, name='xentropy-loss') loss = tf.reduce_mean(loss, name='xentropy-loss')
wrong = prediction_incorrect(logits, label, 1, name='wrong-top1') wrong = prediction_incorrect(logits, label, 1, name='wrong-top1')
......
...@@ -75,7 +75,7 @@ class Model(ModelDesc): ...@@ -75,7 +75,7 @@ class Model(ModelDesc):
.FullyConnected('fct', out_dim=19, nl=tf.identity)()) .FullyConnected('fct', out_dim=19, nl=tf.identity)())
prob = tf.nn.softmax(logits, name='prob') prob = tf.nn.softmax(logits, name='prob')
cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, label) cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label)
cost = tf.reduce_mean(cost, name='cross_entropy_loss') cost = tf.reduce_mean(cost, name='cross_entropy_loss')
wrong = symbolic_functions.prediction_incorrect(logits, label) wrong = symbolic_functions.prediction_incorrect(logits, label)
......
...@@ -89,7 +89,7 @@ class Model(ModelDesc): ...@@ -89,7 +89,7 @@ class Model(ModelDesc):
self.prob = tf.nn.softmax(logits / param.softmax_temprature) self.prob = tf.nn.softmax(logits / param.softmax_temprature)
xent_loss = tf.nn.sparse_softmax_cross_entropy_with_logits( xent_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
logits, symbolic_functions.flatten(nextinput)) logits=logits, labels=symbolic_functions.flatten(nextinput))
self.cost = tf.reduce_mean(xent_loss, name='cost') self.cost = tf.reduce_mean(xent_loss, name='cost')
summary.add_param_summary(('.*/W', ['histogram'])) # monitor histogram of all W summary.add_param_summary(('.*/W', ['histogram'])) # monitor histogram of all W
summary.add_moving_summary(self.cost) summary.add_moving_summary(self.cost)
......
...@@ -58,7 +58,7 @@ class Model(ModelDesc): ...@@ -58,7 +58,7 @@ class Model(ModelDesc):
.FullyConnected('fc1', 512, nl=tf.nn.relu) \ .FullyConnected('fc1', 512, nl=tf.nn.relu) \
.FullyConnected('linear', out_dim=self.cifar_classnum, nl=tf.identity)() .FullyConnected('linear', out_dim=self.cifar_classnum, nl=tf.identity)()
cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, label) cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label)
cost = tf.reduce_mean(cost, name='cross_entropy_loss') cost = tf.reduce_mean(cost, name='cross_entropy_loss')
wrong = symbf.prediction_incorrect(logits, label) wrong = symbf.prediction_incorrect(logits, label)
......
...@@ -80,7 +80,7 @@ class Model(ModelDesc): ...@@ -80,7 +80,7 @@ class Model(ModelDesc):
prob = tf.nn.softmax(logits, name='prob') # a Bx10 with probabilities prob = tf.nn.softmax(logits, name='prob') # a Bx10 with probabilities
# a vector of length B with loss of each sample # a vector of length B with loss of each sample
cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, label) cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label)
cost = tf.reduce_mean(cost, name='cross_entropy_loss') # the average cross-entropy loss cost = tf.reduce_mean(cost, name='cross_entropy_loss') # the average cross-entropy loss
# compute the "incorrect vector", for the callback ClassificationError to use at validation time # compute the "incorrect vector", for the callback ClassificationError to use at validation time
......
...@@ -51,7 +51,7 @@ class Model(ModelDesc): ...@@ -51,7 +51,7 @@ class Model(ModelDesc):
# monitor training error # monitor training error
add_moving_summary(tf.reduce_mean(wrong, name='train_error')) add_moving_summary(tf.reduce_mean(wrong, name='train_error'))
cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, label) cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label)
cost = tf.reduce_mean(cost, name='cross_entropy_loss') cost = tf.reduce_mean(cost, name='cross_entropy_loss')
wd_cost = regularize_cost('fc.*/W', l2_regularizer(0.00001)) wd_cost = regularize_cost('fc.*/W', l2_regularizer(0.00001))
......
...@@ -9,11 +9,13 @@ import six ...@@ -9,11 +9,13 @@ import six
from six.moves import zip, range from six.moves import zip, range
from ..dataflow import DataFlow from ..dataflow import DataFlow
from .base import Callback from ..utils import logger, get_tqdm, PREDICT_TOWER
from .inference import Inferencer
from ..utils import logger, get_tqdm
from ..tfutils.common import get_op_tensor_name from ..tfutils.common import get_op_tensor_name
from ..train.input_data import FeedfreeInput from ..train.input_data import FeedfreeInput
from ..predict import build_prediction_graph
from .base import Callback
from .inference import Inferencer
__all__ = ['InferenceRunner'] __all__ = ['InferenceRunner']
...@@ -142,6 +144,10 @@ class FeedfreeInferenceRunner(Callback): ...@@ -142,6 +144,10 @@ class FeedfreeInferenceRunner(Callback):
IOTensor = namedtuple('IOTensor', ['index', 'isOutput']) IOTensor = namedtuple('IOTensor', ['index', 'isOutput'])
def __init__(self, input, infs, input_names=None): def __init__(self, input, infs, input_names=None):
"""
Args:
input_names (list): must be a subset of the names of InputVar.
"""
assert isinstance(input, FeedfreeInput), input assert isinstance(input, FeedfreeInput), input
self._input_data = input self._input_data = input
if not isinstance(infs, list): if not isinstance(infs, list):
...@@ -154,10 +160,20 @@ class FeedfreeInferenceRunner(Callback): ...@@ -154,10 +160,20 @@ class FeedfreeInferenceRunner(Callback):
assert isinstance(input_names, list) assert isinstance(input_names, list)
self._input_names = input_names self._input_names = input_names
try:
self._size = input.size()
except NotImplementedError:
raise ValueError("Input used in FeedfreeInferencecRunner must have a size!")
def _setup_graph(self): def _setup_graph(self):
self._find_input_tensors() # tensors self._find_input_tensors() # tensors
def fn(_):
self.trainer.model.build_graph(self._input_tensors)
build_prediction_graph(fn, [0])
self._tower_prefix = PREDICT_TOWER + '0'
self._find_output_tensors() self._find_output_tensors()
# TODO build tower
def _find_input_tensors(self): def _find_input_tensors(self):
self._input_data._setup(self.trainer) self._input_data._setup(self.trainer)
...@@ -165,25 +181,32 @@ class FeedfreeInferenceRunner(Callback): ...@@ -165,25 +181,32 @@ class FeedfreeInferenceRunner(Callback):
self._input_tensors = self._input_data.get_input_tensors() self._input_tensors = self._input_data.get_input_tensors()
model_placehdrs = self.trainer.model.get_reuse_placehdrs() model_placehdrs = self.trainer.model.get_reuse_placehdrs()
if self.input_names is not None: if self.input_names is not None:
raise NotImplementedError("Random code. Not tested.")
assert len(self.input_names) == len(self._input_tensors), \ assert len(self.input_names) == len(self._input_tensors), \
"[FeedfreeInferenceRunner] input_names must have the same length as the input data." "[FeedfreeInferenceRunner] input_names must have the same length as the input data."
# XXX incorrect for n, tensor in zip(self.input_names, self._input_tensors):
self._input_tensors = [k for idx, k in enumerate(self._input_tensors) opname, _ = get_op_tensor_name(n)
if model_placehdrs[idx].name in self.input_names] for idx, hdr in enumerate(model_placehdrs):
assert len(self._input_tensors) == len(self.input_names), \ if hdr.name == opname:
"[FeedfreeInferenceRunner] all input_tensors must be defined as InputVar in the Model!" model_placehdrs[idx] = tensor
break
else:
raise ValueError(
"{} doesn't appear in the InputVar of the model!".format(n))
self._input_tensors = model_placehdrs
assert len(self._input_tensors) == len(model_placehdrs), \ assert len(self._input_tensors) == len(model_placehdrs), \
"FeedfreeInput doesn't produce correct number of output tensors" "[FeedfreeInferenceRunner] Unmatched length of input tensors!"
def _find_output_tensors(self): def _find_output_tensors(self):
# TODO doesn't support output an input tensor # TODO doesn't support output an input tensor
# TODO find tensors, not names
dispatcer = OutputTensorDispatcer() dispatcer = OutputTensorDispatcer()
for inf in self.infs: for inf in self.infs:
dispatcer.add_entry(inf.get_output_tensors()) dispatcer.add_entry(inf.get_output_tensors())
all_names = dispatcer.get_all_names() all_names = dispatcer.get_all_names()
IOTensor = InferenceRunner.IOTensor IOTensor = FeedfreeInferenceRunner.IOTensor
self.output_tensors = all_names self.output_tensors = all_names
def find_oid(idxs): def find_oid(idxs):
......
...@@ -13,8 +13,9 @@ from ..tfutils import get_tensors_by_names, TowerContext ...@@ -13,8 +13,9 @@ from ..tfutils import get_tensors_by_names, TowerContext
__all__ = ['PredictorBase', 'AsyncPredictorBase', __all__ = ['PredictorBase', 'AsyncPredictorBase',
'OnlinePredictor', 'OfflinePredictor', 'OnlinePredictor', 'OfflinePredictor',
'MultiTowerOfflinePredictor', 'build_multi_tower_prediction_graph', 'get_predict_func',
'DataParallelOfflinePredictor'] 'build_prediction_graph',
]
@six.add_metaclass(ABCMeta) @six.add_metaclass(ABCMeta)
...@@ -136,7 +137,14 @@ class OfflinePredictor(OnlinePredictor): ...@@ -136,7 +137,14 @@ class OfflinePredictor(OnlinePredictor):
sess, input_vars, output_vars, config.return_input) sess, input_vars, output_vars, config.return_input)
def build_multi_tower_prediction_graph(build_tower_fn, towers): def get_predict_func(config):
"""
Equivalent to ``OfflinePredictor(config)``.
"""
return OfflinePredictor(config)
def build_prediction_graph(build_tower_fn, towers=[0]):
""" """
Args: Args:
build_tower_fn: a function that will be called inside each tower, build_tower_fn: a function that will be called inside each tower,
...@@ -150,81 +158,3 @@ def build_multi_tower_prediction_graph(build_tower_fn, towers): ...@@ -150,81 +158,3 @@ def build_multi_tower_prediction_graph(build_tower_fn, towers):
TowerContext('{}{}'.format(PREDICT_TOWER, k)): TowerContext('{}{}'.format(PREDICT_TOWER, k)):
build_tower_fn(k) build_tower_fn(k)
tf.get_variable_scope().reuse_variables() tf.get_variable_scope().reuse_variables()
class MultiTowerOfflinePredictor(OnlinePredictor):
""" A multi-tower multi-GPU predictor. """
def __init__(self, config, towers):
"""
Args:
config (PredictConfig): the config to use.
towers: a list of relative GPU id.
"""
self.graph = tf.Graph()
self.predictors = []
with self.graph.as_default():
# TODO backup summary keys?
def fn(_):
config.model.build_graph(config.model.get_input_vars())
build_multi_tower_prediction_graph(fn, towers)
self.sess = tf.Session(config=config.session_config)
config.session_init.init(self.sess)
input_vars = get_tensors_by_names(config.input_names)
for k in towers:
output_vars = get_tensors_by_names(
['{}{}/'.format(PREDICT_TOWER, k) + n
for n in config.output_names])
self.predictors.append(OnlinePredictor(
self.sess, input_vars, output_vars, config.return_input))
def _do_call(self, dp):
# use the first tower for compatible PredictorBase interface
return self.predictors[0]._do_call(dp)
def get_predictors(self, n):
"""
Returns:
PredictorBase: the nth predictor on the nth GPU.
"""
return [self.predictors[k % len(self.predictors)] for k in range(n)]
class DataParallelOfflinePredictor(OnlinePredictor):
""" A data-parallel predictor.
It runs different towers in parallel.
"""
def __init__(self, config, towers):
"""
Args:
config (PredictConfig): the config to use.
towers: a list of relative GPU id.
"""
self.graph = tf.Graph()
with self.graph.as_default():
sess = tf.Session(config=config.session_config)
input_var_names = []
output_vars = []
for k in towers:
towername = PREDICT_TOWER + str(k)
input_vars = config.model.build_placeholders(
prefix=towername + '-')
logger.info(
"Building graph for predictor tower {}...".format(k))
with tf.device('/gpu:{}'.format(k) if k >= 0 else '/cpu:0'), \
TowerContext(towername, is_training=False):
config.model.build_graph(input_vars)
tf.get_variable_scope().reuse_variables()
input_var_names.extend([k.name for k in input_vars])
output_vars.extend(get_tensors_by_names(
[towername + '/' + n
for n in config.output_names]))
input_vars = get_tensors_by_names(input_var_names)
config.session_init.init(sess)
super(DataParallelOfflinePredictor, self).__init__(
sess, input_vars, output_vars, config.return_input)
# -*- coding: UTF-8 -*-
# File: common.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import six
from tensorpack.models import ModelDesc
from ..tfutils import get_default_sess_config
from ..tfutils.sessinit import SessionInit, JustCurrentSession
from .base import OfflinePredictor
__all__ = ['PredictConfig', 'get_predict_func']
class PredictConfig(object):
def __init__(self, **kwargs):
"""
Args:
session_init (SessionInit): how to initialize variables of the session.
model (ModelDesc): the model to use.
input_names (list): a list of input tensor names.
output_names (list): a list of names of the output tensors to predict, the
tensors can be any computable tensor in the graph.
return_input: same as in :attr:`PredictorBase.return_input`.
"""
# TODO use the name "tensor" instead of "variable"
def assert_type(v, tp):
assert isinstance(v, tp), v.__class__
# XXX does it work? start with minimal memory, but allow growth.
# allow_growth doesn't seem to work very well in TF.
self.session_config = kwargs.pop('session_config', get_default_sess_config(0.4))
self.session_init = kwargs.pop('session_init', JustCurrentSession())
assert_type(self.session_init, SessionInit)
self.model = kwargs.pop('model')
assert_type(self.model, ModelDesc)
# inputs & outputs
# TODO add deprecated warning later
self.input_names = kwargs.pop('input_names', None)
if self.input_names is None:
self.input_names = kwargs.pop('input_var_names', None)
if self.input_names is not None:
pass
# logger.warn("[Deprecated] input_var_names is deprecated in PredictConfig. Use input_names instead!")
if self.input_names is None:
# neither options is set, assume all inputs
raw_vars = self.model.get_input_vars_desc()
self.input_names = [k.name for k in raw_vars]
self.output_names = kwargs.pop('output_names', None)
if self.output_names is None:
self.output_names = kwargs.pop('output_var_names')
# logger.warn("[Deprecated] output_var_names is deprecated in PredictConfig. Use output_names instead!")
assert len(self.input_names), self.input_names
for v in self.input_names:
assert_type(v, six.string_types)
assert len(self.output_names), self.output_names
self.return_input = kwargs.pop('return_input', False)
assert len(kwargs) == 0, 'Unknown arguments: {}'.format(str(kwargs.keys()))
def get_predict_func(config):
"""
Equivalent to ``OfflinePredictor(config)``.
"""
return OfflinePredictor(config)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: multigpu.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import tensorflow as tf
from ..utils import logger
from ..utils.naming import PREDICT_TOWER
from ..tfutils import get_tensors_by_names, TowerContext
from .base import OnlinePredictor, build_prediction_graph
__all__ = ['MultiTowerOfflinePredictor',
'DataParallelOfflinePredictor']
class MultiTowerOfflinePredictor(OnlinePredictor):
""" A multi-tower multi-GPU predictor. """
def __init__(self, config, towers):
"""
Args:
config (PredictConfig): the config to use.
towers: a list of relative GPU id.
"""
self.graph = tf.Graph()
self.predictors = []
with self.graph.as_default():
# TODO backup summary keys?
def fn(_):
config.model.build_graph(config.model.get_input_vars())
build_prediction_graph(fn, towers)
self.sess = tf.Session(config=config.session_config)
config.session_init.init(self.sess)
input_vars = get_tensors_by_names(config.input_names)
for k in towers:
output_vars = get_tensors_by_names(
['{}{}/'.format(PREDICT_TOWER, k) + n
for n in config.output_names])
self.predictors.append(OnlinePredictor(
self.sess, input_vars, output_vars, config.return_input))
def _do_call(self, dp):
# use the first tower for compatible PredictorBase interface
return self.predictors[0]._do_call(dp)
def get_predictors(self, n):
"""
Returns:
PredictorBase: the nth predictor on the nth GPU.
"""
return [self.predictors[k % len(self.predictors)] for k in range(n)]
class DataParallelOfflinePredictor(OnlinePredictor):
""" A data-parallel predictor.
It runs different towers in parallel.
"""
def __init__(self, config, towers):
"""
Args:
config (PredictConfig): the config to use.
towers: a list of relative GPU id.
"""
self.graph = tf.Graph()
with self.graph.as_default():
sess = tf.Session(config=config.session_config)
input_var_names = []
output_vars = []
for k in towers:
towername = PREDICT_TOWER + str(k)
input_vars = config.model.build_placeholders(
prefix=towername + '-')
logger.info(
"Building graph for predictor tower {}...".format(k))
with tf.device('/gpu:{}'.format(k) if k >= 0 else '/cpu:0'), \
TowerContext(towername, is_training=False):
config.model.build_graph(input_vars)
tf.get_variable_scope().reuse_variables()
input_var_names.extend([k.name for k in input_vars])
output_vars.extend(get_tensors_by_names(
[towername + '/' + n
for n in config.output_names]))
input_vars = get_tensors_by_names(input_var_names)
config.session_init.init(sess)
super(DataParallelOfflinePredictor, self).__init__(
sess, input_vars, output_vars, config.return_input)
...@@ -163,7 +163,8 @@ class Trainer(object): ...@@ -163,7 +163,8 @@ class Trainer(object):
self.config.starting_epoch, self.config.max_epoch + 1): self.config.starting_epoch, self.config.max_epoch + 1):
with timed_operation( with timed_operation(
'Epoch {} (global_step {})'.format( 'Epoch {} (global_step {})'.format(
epoch_num, get_global_step() + self.config.step_per_epoch)): epoch_num, get_global_step() + self.config.step_per_epoch),
log_start=True):
for step in tqdm.trange( for step in tqdm.trange(
self.config.step_per_epoch, self.config.step_per_epoch,
**get_tqdm_kwargs(leave=True)): **get_tqdm_kwargs(leave=True)):
......
...@@ -189,7 +189,7 @@ class TensorInput(FeedfreeInput): ...@@ -189,7 +189,7 @@ class TensorInput(FeedfreeInput):
def size(self): def size(self):
if self._size is None: if self._size is None:
raise ValueError("size of TensorInput is undefined!") raise NotImplementedError("size of TensorInput is undefined!")
return self._size return self._size
def _setup(self, trainer): def _setup(self, trainer):
......
...@@ -10,7 +10,7 @@ from ..utils import SUMMARY_BACKUP_KEYS, PREDICT_TOWER ...@@ -10,7 +10,7 @@ from ..utils import SUMMARY_BACKUP_KEYS, PREDICT_TOWER
from ..tfutils import (get_tensors_by_names, freeze_collection, from ..tfutils import (get_tensors_by_names, freeze_collection,
get_global_step_var, TowerContext) get_global_step_var, TowerContext)
from ..tfutils.summary import summary_moving_average from ..tfutils.summary import summary_moving_average
from ..predict import OnlinePredictor, build_multi_tower_prediction_graph from ..predict import OnlinePredictor, build_prediction_graph
from ..tfutils.gradproc import apply_grad_processors from ..tfutils.gradproc import apply_grad_processors
from .input_data import FeedInput from .input_data import FeedInput
...@@ -49,7 +49,7 @@ class PredictorFactory(object): ...@@ -49,7 +49,7 @@ class PredictorFactory(object):
freeze_collection(SUMMARY_BACKUP_KEYS): freeze_collection(SUMMARY_BACKUP_KEYS):
def fn(_): def fn(_):
self.model.build_graph(self.model.get_input_vars()) self.model.build_graph(self.model.get_input_vars())
build_multi_tower_prediction_graph(fn, self.towers) build_prediction_graph(fn, self.towers)
self.tower_built = True self.tower_built = True
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment