Commit 062790c4 authored by Yuxin Wu's avatar Yuxin Wu

DataParallelInferenceRunner (a complicated implementation). (fix #139)

parent e2b985ca
...@@ -196,7 +196,8 @@ class MySimulatorMaster(SimulatorMaster, Callback): ...@@ -196,7 +196,8 @@ class MySimulatorMaster(SimulatorMaster, Callback):
def get_config(): def get_config():
logger.auto_set_dir() dirname = os.path.join('train_log', 'train-atari-{}'.format(ENV_NAME))
logger.set_logger_dir(dirname)
M = Model() M = Model()
name_base = str(uuid.uuid1())[:6] name_base = str(uuid.uuid1())[:6]
......
...@@ -15,6 +15,8 @@ from ..tfutils import get_op_tensor_name ...@@ -15,6 +15,8 @@ from ..tfutils import get_op_tensor_name
__all__ = ['ScalarStats', 'Inferencer', __all__ = ['ScalarStats', 'Inferencer',
'ClassificationError', 'BinaryClassificationStats'] 'ClassificationError', 'BinaryClassificationStats']
# TODO rename get_output_tensors to get_output_names
@six.add_metaclass(ABCMeta) @six.add_metaclass(ABCMeta)
class Inferencer(object): class Inferencer(object):
......
...@@ -15,14 +15,16 @@ from six.moves import zip ...@@ -15,14 +15,16 @@ from six.moves import zip
from ..utils import logger, get_tqdm_kwargs from ..utils import logger, get_tqdm_kwargs
from ..dataflow import DataFlow from ..dataflow import DataFlow
from ..tfutils.common import get_op_tensor_name from ..tfutils.common import get_op_tensor_name, get_tensors_by_names
from ..tfutils.tower import TowerContext
from ..train.input_data import TensorInput, FeedInput from ..train.input_data import TensorInput, FeedInput
from ..predict import PredictorTowerBuilder from ..predict import PredictorTowerBuilder
from .base import Triggerable from .base import Triggerable
from .inference import Inferencer from .inference import Inferencer
__all__ = ['InferenceRunner', 'FeedfreeInferenceRunner'] __all__ = ['InferenceRunner', 'FeedfreeInferenceRunner',
'DataParallelInferenceRunner']
class InferencerToHook(tf.train.SessionRunHook): class InferencerToHook(tf.train.SessionRunHook):
...@@ -98,6 +100,7 @@ class InferenceRunnerBase(Triggerable): ...@@ -98,6 +100,7 @@ class InferenceRunnerBase(Triggerable):
self._input_data.setup(self.trainer.model) self._input_data.setup(self.trainer.model)
self._setup_input_names() self._setup_input_names()
in_tensors = self._find_input_tensors() in_tensors = self._find_input_tensors()
assert isinstance(in_tensors, list), in_tensors
with tf.variable_scope(tf.get_variable_scope(), reuse=True): with tf.variable_scope(tf.get_variable_scope(), reuse=True):
def fn(_): def fn(_):
...@@ -115,7 +118,6 @@ class InferenceRunnerBase(Triggerable): ...@@ -115,7 +118,6 @@ class InferenceRunnerBase(Triggerable):
get_tensor_fn = PredictorTowerBuilder.get_tensors_maybe_in_tower get_tensor_fn = PredictorTowerBuilder.get_tensors_maybe_in_tower
return get_tensor_fn(placeholder_names, names, 0, prefix=self._prefix) return get_tensor_fn(placeholder_names, names, 0, prefix=self._prefix)
@abstractmethod
def _find_input_tensors(self): def _find_input_tensors(self):
pass pass
...@@ -230,3 +232,97 @@ class FeedfreeInferenceRunner(InferenceRunnerBase): ...@@ -230,3 +232,97 @@ class FeedfreeInferenceRunner(InferenceRunnerBase):
idx = placeholder_names.index(name) idx = placeholder_names.index(name)
ret.append(self._input_tensors[idx]) ret.append(self._input_tensors[idx])
return InferencerToHook(inf, ret) return InferencerToHook(inf, ret)
class DataParallelInferenceRunner(InferenceRunner):
    """
    Run inference in a data-parallel fashion: build one prediction tower per
    GPU, and feed a different datapoint to each tower in every session run.

    Note that unlike :class:`InferenceRunner`, each tower here gets its own
    placeholders (prefixed by the tower name), so inputs are fed per-tower.
    """

    def __init__(self, ds, infs, gpus, input_names=None):
        """
        Args:
            ds (DataFlow): the DataFlow to run inference on.
            infs (list[Inferencer]): the inferencers to run.
            gpus (list[int]): relative ids of the GPUs to build towers on.
            input_names (list[str], optional): same meaning as in
                :class:`InferenceRunner`.
        """
        super(DataParallelInferenceRunner, self).__init__(ds, infs, input_names)
        self._gpus = gpus

    def _setup_graph(self):
        model = self.trainer.model
        self._input_data.setup(model)
        self._setup_input_names()

        # Build one prediction tower per GPU. Each tower creates its own
        # set of placeholders prefixed with the tower name, so the feeds of
        # different towers do not collide.
        def build_tower(k):
            towername = TowerContext.get_predict_tower_name(k)
            # inputs (placeholders) for this tower only
            input_tensors = model.build_placeholders(prefix=towername + '/')
            model.build_graph(input_tensors)

        builder = PredictorTowerBuilder(build_tower, prefix=self._prefix)
        # reuse=True: the towers share variables with the training graph
        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
            for t in self._gpus:
                builder.build(t)

        # setup feeds and hooks:
        # _hooks_parallel fetch outputs from all towers (used for full batches
        # of nr_tower datapoints); _hooks fetch from the first tower only
        # (used for the leftover datapoints when size % nr_tower != 0).
        self._feed_tensors = self._find_feed_tensors()
        self._hooks_parallel = [self._build_hook_parallel(inf) for inf in self.infs]
        self._hooks = [self._build_hook(inf) for inf in self.infs]

    def _duplicate_names_across_towers(self, names):
        """Prefix each name with every tower's name, ordered tower-major."""
        ret = []
        for t in self._gpus:
            ret.extend([TowerContext.get_predict_tower_name(t, self._prefix) +
                        '/' + n for n in names])
        return ret

    def _find_feed_tensors(self):
        # one copy of every input placeholder per tower, tower-major order
        names = self._duplicate_names_across_towers(self.input_names)
        return get_tensors_by_names(names)

    class InferencerToHookDataParallel(InferencerToHook):
        """A hook whose fetches are the concatenation of the same
        ``size`` output tensors from every tower; after each run it feeds
        each tower's slice of results to the inferencer separately."""

        def __init__(self, inf, fetches, size):
            """
            Args:
                inf (Inferencer): the inferencer to drive.
                fetches (list[tf.Tensor]): all towers' outputs, tower-major.
                size (int): number of output tensors per tower.
            """
            super(DataParallelInferenceRunner.InferencerToHookDataParallel,
                  self).__init__(inf, fetches)
            assert len(self._fetches) % size == 0
            self._sz = size

        def after_run(self, _, run_values):
            res = run_values.results
            # split the flat result list back into one datapoint per tower
            for i in range(0, len(res), self._sz):
                vals = res[i:i + self._sz]
                self._inf.datapoint(vals)

    def _build_hook_parallel(self, inf):
        out_names = inf.get_output_tensors()
        sz = len(out_names)
        out_names = self._duplicate_names_across_towers(out_names)
        fetches = get_tensors_by_names(out_names)
        return DataParallelInferenceRunner.InferencerToHookDataParallel(
            inf, fetches, sz)

    def _build_hook(self, inf):
        # single-tower hook: fetch only from the first tower
        out_names = inf.get_output_tensors()
        names = [TowerContext.get_predict_tower_name(
            self._gpus[0], self._prefix) + '/' + n for n in out_names]
        fetches = get_tensors_by_names(names)
        return InferencerToHook(inf, fetches)

    def _before_train(self):
        self._hooked_sess = HookedSession(self.trainer.sess, self._hooks)
        self._parallel_hooked_sess = HookedSession(self.trainer.sess, self._hooks_parallel)

    def _trigger(self):
        for inf in self.infs:
            inf.before_inference()

        self._input_data.reset_state()
        total = self._input_data.size()
        nr_tower = len(self._gpus)
        with tqdm.tqdm(total=total, **get_tqdm_kwargs()) as pbar:
            # run full batches of nr_tower datapoints, one per tower
            while total >= nr_tower:
                dps = []
                for k in self._gpus:
                    dps.extend(self._input_data.next_feed())
                feed = dict(zip(self._feed_tensors, dps))
                self._parallel_hooked_sess.run(fetches=[], feed_dict=feed)
                pbar.update(nr_tower)
                total -= nr_tower
            # take care of the rest: one datapoint at a time on the first
            # tower. BUGFIX: `total` was never decremented (and the progress
            # bar never updated) here, making this an infinite loop whenever
            # size % nr_tower != 0.
            while total > 0:
                dp = self._input_data.next_feed()
                feed = dict(zip(self._feed_tensors[:len(dp)], dp))
                self._hooked_sess.run(fetches=[], feed_dict=feed)
                pbar.update(1)
                total -= 1
        summary_inferencer(self.trainer, self.infs)
...@@ -199,6 +199,8 @@ def CaffeLMDB(lmdb_path, shuffle=True, keys=None): ...@@ -199,6 +199,8 @@ def CaffeLMDB(lmdb_path, shuffle=True, keys=None):
log_once("Cannot read key {}".format(k), 'warn') log_once("Cannot read key {}".format(k), 'warn')
return None return None
return [img.transpose(1, 2, 0), datum.label] return [img.transpose(1, 2, 0), datum.label]
logger.warn("Caffe LMDB format doesn't store jpeg-compressed images, \
it's not recommended due to its inferior performance.")
return LMDBDataDecoder(lmdb_data, decoder) return LMDBDataDecoder(lmdb_data, decoder)
......
...@@ -192,11 +192,14 @@ class PredictorTowerBuilder(object): ...@@ -192,11 +192,14 @@ class PredictorTowerBuilder(object):
TowerContext(towername, is_training=False): TowerContext(towername, is_training=False):
self._fn(tower) self._fn(tower)
# useful only when the placeholders don't have tower prefix
# note that in DataParallel predictor, placeholders do have tower prefix
@staticmethod @staticmethod
def get_tensors_maybe_in_tower(placeholder_names, names, k, prefix=''): def get_tensors_maybe_in_tower(placeholder_names, names, tower, prefix=''):
""" """
Args: Args:
placeholders (list): A list of __op__ name. placeholders (list): A list of __op__ name.
tower (int): relative GPU id.
""" """
def maybe_inside_tower(name): def maybe_inside_tower(name):
name = get_op_tensor_name(name)[0] name = get_op_tensor_name(name)[0]
...@@ -204,7 +207,7 @@ class PredictorTowerBuilder(object): ...@@ -204,7 +207,7 @@ class PredictorTowerBuilder(object):
return name return name
else: else:
# if the name is not a placeholder, use it's name in each tower # if the name is not a placeholder, use it's name in each tower
return TowerContext.get_predict_tower_name(k, prefix) + '/' + name return TowerContext.get_predict_tower_name(tower, prefix) + '/' + name
names = list(map(maybe_inside_tower, names)) names = list(map(maybe_inside_tower, names))
tensors = get_tensors_by_names(names) tensors = get_tensors_by_names(names)
return tensors return tensors
......
...@@ -66,8 +66,9 @@ class DataParallelOfflinePredictor(OnlinePredictor): ...@@ -66,8 +66,9 @@ class DataParallelOfflinePredictor(OnlinePredictor):
""" """
A data-parallel predictor. A data-parallel predictor.
Note that it doesn't split/concat inputs/outputs automatically. Note that it doesn't split/concat inputs/outputs automatically.
Its input is: ``[input[0] in tower[0], input[1] in tower[0], ..., input[0] in tower[1], input[1] in tower[1], ...]`` Instead, its inputs are:
And same for the output. ``[input[0] in tower[0], input[1] in tower[0], ..., input[0] in tower[1], input[1] in tower[1], ...]``
Similar for the outputs.
""" """
def __init__(self, config, towers): def __init__(self, config, towers):
......
...@@ -125,6 +125,7 @@ def add_moving_summary(v, *args, **kwargs): ...@@ -125,6 +125,7 @@ def add_moving_summary(v, *args, **kwargs):
assert isinstance(x, tf.Tensor), x assert isinstance(x, tf.Tensor), x
assert x.get_shape().ndims == 0, x.get_shape() assert x.get_shape().ndims == 0, x.get_shape()
# TODO will produce tower0/xxx? # TODO will produce tower0/xxx?
# TODO use zero_debias
with tf.name_scope(None): with tf.name_scope(None):
averager = tf.train.ExponentialMovingAverage( averager = tf.train.ExponentialMovingAverage(
decay, num_updates=get_global_step_var(), name='EMA') decay, num_updates=get_global_step_var(), name='EMA')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment