Commit 062790c4 authored by Yuxin Wu

DataParallelInferenceRunner (a complicated implementation). (fix #139)

parent e2b985ca
......@@ -196,7 +196,8 @@ class MySimulatorMaster(SimulatorMaster, Callback):
def get_config():
logger.auto_set_dir()
dirname = os.path.join('train_log', 'train-atari-{}'.format(ENV_NAME))
logger.set_logger_dir(dirname)
M = Model()
name_base = str(uuid.uuid1())[:6]
......
......@@ -15,6 +15,8 @@ from ..tfutils import get_op_tensor_name
__all__ = ['ScalarStats', 'Inferencer',
'ClassificationError', 'BinaryClassificationStats']
# TODO rename get_output_tensors to get_output_names
@six.add_metaclass(ABCMeta)
class Inferencer(object):
......
......@@ -15,14 +15,16 @@ from six.moves import zip
from ..utils import logger, get_tqdm_kwargs
from ..dataflow import DataFlow
from ..tfutils.common import get_op_tensor_name
from ..tfutils.common import get_op_tensor_name, get_tensors_by_names
from ..tfutils.tower import TowerContext
from ..train.input_data import TensorInput, FeedInput
from ..predict import PredictorTowerBuilder
from .base import Triggerable
from .inference import Inferencer
__all__ = ['InferenceRunner', 'FeedfreeInferenceRunner']
__all__ = ['InferenceRunner', 'FeedfreeInferenceRunner',
'DataParallelInferenceRunner']
class InferencerToHook(tf.train.SessionRunHook):
......@@ -98,6 +100,7 @@ class InferenceRunnerBase(Triggerable):
self._input_data.setup(self.trainer.model)
self._setup_input_names()
in_tensors = self._find_input_tensors()
assert isinstance(in_tensors, list), in_tensors
with tf.variable_scope(tf.get_variable_scope(), reuse=True):
def fn(_):
......@@ -115,7 +118,6 @@ class InferenceRunnerBase(Triggerable):
get_tensor_fn = PredictorTowerBuilder.get_tensors_maybe_in_tower
return get_tensor_fn(placeholder_names, names, 0, prefix=self._prefix)
@abstractmethod
def _find_input_tensors(self):
pass
......@@ -230,3 +232,97 @@ class FeedfreeInferenceRunner(InferenceRunnerBase):
idx = placeholder_names.index(name)
ret.append(self._input_tensors[idx])
return InferencerToHook(inf, ret)
class DataParallelInferenceRunner(InferenceRunner):
    """
    An inference runner which builds one predict tower per entry of ``gpus``
    and, on each trigger, feeds one datapoint to every tower per session run.

    Datapoints left over after the last full group (fewer than the number of
    towers) are run one at a time through the first tower only.
    """

    def __init__(self, ds, infs, gpus, input_names=None):
        """
        Args:
            ds: forwarded unchanged to :class:`InferenceRunner`.
            infs: forwarded unchanged to :class:`InferenceRunner`.
            gpus (list): tower ids. One predict tower is built for each of
                them, in this order.
            input_names: forwarded unchanged to :class:`InferenceRunner`.
        """
        super(DataParallelInferenceRunner, self).__init__(ds, infs, input_names)
        self._gpus = gpus

    def _setup_graph(self):
        """
        Build one predict tower per GPU, then look up the feed tensors and
        create both sets of hooks (all-tower and first-tower-only).
        """
        model = self.trainer.model
        self._input_data.setup(model)
        self._setup_input_names()

        # build graph
        def build_tower(k):
            # NOTE(review): the tower name here is computed without
            # self._prefix, while the name lookups below pass self._prefix.
            # These agree only if the prefix is empty -- confirm.
            towername = TowerContext.get_predict_tower_name(k)
            # inputs (placeholders) for this tower only
            input_tensors = model.build_placeholders(
                prefix=towername + '/')
            model.build_graph(input_tensors)

        builder = PredictorTowerBuilder(build_tower, prefix=self._prefix)
        with tf.variable_scope(tf.get_variable_scope(), reuse=True):
            for t in self._gpus:
                builder.build(t)

        # setup feeds and hooks
        self._feed_tensors = self._find_feed_tensors()
        self._hooks_parallel = [self._build_hook_parallel(inf) for inf in self.infs]
        self._hooks = [self._build_hook(inf) for inf in self.infs]

    def _duplicate_names_across_towers(self, names):
        """
        Prefix every name with each tower's name.

        Returns:
            list: tower-major ordering, i.e. all ``names`` for the first
            tower, then all ``names`` for the second tower, and so on.
        """
        ret = []
        for t in self._gpus:
            ret.extend([TowerContext.get_predict_tower_name(t, self._prefix) +
                        '/' + n for n in names])
        return ret

    def _find_feed_tensors(self):
        """
        Returns:
            The tensors for ``self.input_names`` in every tower, tower-major.
        """
        names = self._duplicate_names_across_towers(self.input_names)
        return get_tensors_by_names(names)

    class InferencerToHookDataParallel(InferencerToHook):
        """
        A hook whose fetches are the inferencer's outputs repeated for every
        tower. After each run the results are cut into per-tower groups and
        handed to the inferencer one group at a time.
        """

        def __init__(self, inf, fetches, size):
            """
            Args:
                inf: the inferencer which receives the results.
                fetches (list): the tensors to fetch, tower-major.
                size (int): number of fetches which belong to one tower.
            """
            super(DataParallelInferenceRunner.InferencerToHookDataParallel, self).__init__(inf, fetches)
            # presumably self._fetches is set by InferencerToHook.__init__
            assert len(self._fetches) % size == 0
            self._sz = size

        def after_run(self, _, run_values):
            res = run_values.results
            # one chunk of `_sz` results per tower
            for i in range(0, len(res), self._sz):
                vals = res[i:i + self._sz]
                self._inf.datapoint(vals)

    def _build_hook_parallel(self, inf):
        """
        Build the hook which fetches the outputs of ``inf`` from all towers.
        """
        out_names = inf.get_output_tensors()
        sz = len(out_names)
        out_names = self._duplicate_names_across_towers(out_names)
        fetches = get_tensors_by_names(out_names)
        return DataParallelInferenceRunner.InferencerToHookDataParallel(
            inf, fetches, sz)

    def _build_hook(self, inf):
        """
        Build the hook which fetches the outputs of ``inf`` from the first
        tower only. It is used for the leftover datapoints.
        """
        out_names = inf.get_output_tensors()
        names = [TowerContext.get_predict_tower_name(
            self._gpus[0], self._prefix) + '/' + n for n in out_names]
        fetches = get_tensors_by_names(names)
        return InferencerToHook(inf, fetches)

    def _before_train(self):
        # Two sessions over the same underlying session: one runs the
        # first-tower hooks, the other runs the all-tower hooks.
        self._hooked_sess = HookedSession(self.trainer.sess, self._hooks)
        self._parallel_hooked_sess = HookedSession(self.trainer.sess, self._hooks_parallel)

    def _trigger(self):
        """
        Run all inferencers over the whole input once and summarize them.
        """
        for inf in self.infs:
            inf.before_inference()

        self._input_data.reset_state()
        total = self._input_data.size()
        nr_tower = len(self._gpus)
        with tqdm.tqdm(total=total, **get_tqdm_kwargs()) as pbar:
            while total >= nr_tower:
                # Gather one feed per tower; the concatenation lines up with
                # the tower-major order of self._feed_tensors.
                dps = []
                for _ in self._gpus:
                    dps.extend(self._input_data.next_feed())
                feed = dict(zip(self._feed_tensors, dps))
                self._parallel_hooked_sess.run(fetches=[], feed_dict=feed)
                pbar.update(nr_tower)
                total -= nr_tower
            # take care of the rest
            while total > 0:
                dp = self._input_data.next_feed()
                # the first len(dp) feed tensors belong to the first tower
                feed = dict(zip(self._feed_tensors[:len(dp)], dp))
                self._hooked_sess.run(fetches=[], feed_dict=feed)
                # Count the datapoint as consumed. Without this the loop
                # never terminates when the dataset size is not a multiple
                # of the number of towers.
                pbar.update(1)
                total -= 1
        summary_inferencer(self.trainer, self.infs)
......@@ -199,6 +199,8 @@ def CaffeLMDB(lmdb_path, shuffle=True, keys=None):
log_once("Cannot read key {}".format(k), 'warn')
return None
return [img.transpose(1, 2, 0), datum.label]
logger.warn("Caffe LMDB format doesn't store jpeg-compressed images, \
it's not recommended due to its inferior performance.")
return LMDBDataDecoder(lmdb_data, decoder)
......
......@@ -192,11 +192,14 @@ class PredictorTowerBuilder(object):
TowerContext(towername, is_training=False):
self._fn(tower)
# useful only when the placeholders don't have tower prefix
# note that in DataParallel predictor, placeholders do have tower prefix
@staticmethod
def get_tensors_maybe_in_tower(placeholder_names, names, k, prefix=''):
def get_tensors_maybe_in_tower(placeholder_names, names, tower, prefix=''):
"""
Args:
placeholders (list): A list of __op__ name.
tower (int): relative GPU id.
"""
def maybe_inside_tower(name):
name = get_op_tensor_name(name)[0]
......@@ -204,7 +207,7 @@ class PredictorTowerBuilder(object):
return name
else:
# if the name is not a placeholder, use its name in each tower
return TowerContext.get_predict_tower_name(k, prefix) + '/' + name
return TowerContext.get_predict_tower_name(tower, prefix) + '/' + name
names = list(map(maybe_inside_tower, names))
tensors = get_tensors_by_names(names)
return tensors
......
......@@ -66,8 +66,9 @@ class DataParallelOfflinePredictor(OnlinePredictor):
"""
A data-parallel predictor.
Note that it doesn't split/concat inputs/outputs automatically.
Its input is: ``[input[0] in tower[0], input[1] in tower[0], ..., input[0] in tower[1], input[1] in tower[1], ...]``
And same for the output.
Instead, its inputs are:
``[input[0] in tower[0], input[1] in tower[0], ..., input[0] in tower[1], input[1] in tower[1], ...]``
Similar for the outputs.
"""
def __init__(self, config, towers):
......
......@@ -125,6 +125,7 @@ def add_moving_summary(v, *args, **kwargs):
assert isinstance(x, tf.Tensor), x
assert x.get_shape().ndims == 0, x.get_shape()
# TODO will produce tower0/xxx?
# TODO use zero_debias
with tf.name_scope(None):
averager = tf.train.ExponentialMovingAverage(
decay, num_updates=get_global_step_var(), name='EMA')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment