Commit f1ee1833 authored by Yuxin Wu

[Keras] minor improvements (#160)

parent e8674dca
@@ -7,22 +7,33 @@ import six
from tensorflow import keras
from tensorflow.python.keras import metrics as metrics_module
from ..models.regularize import regularize_cost_from_collection
from ..graph_builder import InputDesc
from ..tfutils.tower import get_current_tower_context
# from ..tfutils.collection import freeze_collection # TODO freeze UPDATE_OPS in replicated
from ..train import Trainer, SimpleTrainer, SyncMultiGPUTrainerParameterServer, DistributedTrainerBase
from ..callbacks import (
    Callback, InferenceRunner, CallbackToHook,
    ScalarStats)
from ..tfutils.common import get_op_tensor_name
from ..tfutils.tower import get_current_tower_context
from ..tfutils.scope_utils import cached_name_scope
# from ..tfutils.collection import freeze_collection # TODO freeze UPDATE_OPS in replicated
from ..tfutils.summary import add_moving_summary
from ..utils.gpu import get_nr_gpu
from ..train import Trainer, SimpleTrainer, SyncMultiGPUTrainerParameterServer
__all__ = ['KerasPhaseCallback', 'setup_keras_trainer', 'KerasModel']
TOTAL_LOSS_NAME = 'total_loss'
def _check_name(tensor, name):
    tensorname = get_op_tensor_name(tensor.name)[0]
    assert tensorname.split('/')[-1] == name, \
        "{} does not match {}, you may have name conflict somewhere!".format(tensor.name, name)
class KerasModelCaller(object):
"""
Keras model doesn't support vs reuse.
@@ -46,6 +57,21 @@ class KerasModelCaller(object):
        M = self.get_model(input_tensors)
        return M.outputs
    def call_virtual(self):
        class NoneTensorProxy(object):
            def __getitem__(self, index):
                return None

            def __len__(self):
                raise NotImplementedError(
                    "Do not call `len(inputs)` because it's only a virtual object "
                    "for the moment! Use `inputs[index]` directly!")

        G_tmp = tf.Graph()  # we need a model instance to know metadata about inputs/outputs
        with G_tmp.as_default():
            return self.get_model(NoneTensorProxy())
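# A hedged sketch (illustrative shapes, not from this commit) of a user
# `get_model` that works both with real input tensors and with the
# NoneTensorProxy: `keras.layers.Input(tensor=None, ...)` simply creates a
# fresh placeholder, so the virtual call can still build a Model whose
# input/output metadata can be inspected.
from tensorflow import keras

def get_model(inputs):
    x = keras.layers.Input(tensor=inputs[0], shape=(28, 28, 1))
    y = keras.layers.Flatten()(x)
    y = keras.layers.Dense(10, activation='softmax')(y)
    return keras.models.Model(inputs=x, outputs=y)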
# Keras needs an extra input if learning_phase is used by the model
# This callback will be used by
@@ -58,9 +84,9 @@ class KerasPhaseCallback(Callback):
        self._learning_phase = keras.backend.learning_phase()

    def _setup_graph(self):
        # HACK
        cbs = self.trainer._callbacks.cbs
        for cb in cbs:
            # XXX HACK
            if isinstance(cb, InferenceRunner):
                h = CallbackToHook(KerasPhaseCallback(False))
                cb.register_hook(h)
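# A hedged illustration (assumed example, not from this commit) of the
# learning_phase mechanism this callback feeds: phase-dependent layers such
# as Dropout all read the same scalar placeholder, which the callback sets
# to 1 during training and 0 during inference.
import tensorflow as tf
from tensorflow import keras

lp = keras.backend.learning_phase()
y = keras.layers.Dropout(0.5)(tf.ones([1, 4]))  # behavior depends on lp
with tf.Session() as sess:
    print(sess.run(y, feed_dict={lp: 1}))  # training: units randomly zeroed
    print(sess.run(y, feed_dict={lp: 0}))  # inference: identity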
@@ -72,7 +98,7 @@ class KerasPhaseCallback(Callback):
def setup_keras_trainer(
        trainer, get_model, input,
        optimizer, loss, metrics=None):
        optimizer, loss, metrics):
    """
    Args:
        trainer (SingleCostTrainer):
@@ -82,18 +108,18 @@ def setup_keras_trainer(
        loss, metrics: list of strings
    """
    assert isinstance(optimizer, tf.train.Optimizer), optimizer
    assert isinstance(loss, list), loss
    assert len(loss) >= 1, "No loss was given!"
    assert isinstance(metrics, list), metrics

    model_caller = KerasModelCaller(get_model)
    M_tmp = model_caller.call_virtual()
    G_tmp = tf.Graph()  # we need the model instance to know metadata about inputs/outputs
    with G_tmp.as_default():
        M_tmp = get_model([None])  # TODO use a proxy with Nones
    inputs_desc = [InputDesc(t.dtype, t.shape.as_list(), 'input{}'.format(i))
                   for i, t in enumerate(M_tmp.inputs)]
    outputs_desc = [InputDesc(t.dtype, t.shape.as_list(), 'output{}'.format(i))
                    for i, t in enumerate(M_tmp.outputs)]
    nr_inputs = len(inputs_desc)
    del G_tmp, M_tmp
    model_caller = KerasModelCaller(get_model)

    def get_cost(*inputs):
        assert len(inputs) == len(inputs_desc) + len(outputs_desc), \
@@ -112,19 +138,22 @@ def setup_keras_trainer(
        assert len(outputs) == len(loss), \
            "len({}) != len({})".format(str(outputs), str(loss))
        # TODO more losses
        with tf.name_scope('keras_loss'):
            loss_fn = keras.losses.get(loss[0])
            loss_opt = loss_fn(target_tensors[0], outputs[0])
            loss_opt = tf.reduce_mean(loss_opt, name=loss[0])
        loss_tensors = []
        for idx, loss_name in enumerate(loss):
            with cached_name_scope('keras_loss', top_level=False):
                loss_fn = keras.losses.get(loss_name)
                curr_loss = loss_fn(target_tensors[idx], outputs[idx])
            curr_loss = tf.reduce_mean(curr_loss, name=loss_name)
            _check_name(curr_loss, loss_name)
            loss_tensors.append(curr_loss)

        loss_reg = regularize_cost_from_collection()
        if loss_reg is not None:
            total_loss = tf.add(loss_opt, loss_reg, name='total_loss')
            add_moving_summary(loss_opt, loss_reg, total_loss)
            total_loss = tf.add_n(loss_tensors + [loss_reg], name=TOTAL_LOSS_NAME)
            add_moving_summary(loss_reg, total_loss, *loss_tensors)
        else:
            add_moving_summary(loss_opt)
            total_loss = tf.identity(loss_opt, name='total_loss')
            add_moving_summary(*loss_tensors)
            total_loss = tf.add_n(loss_tensors, name=TOTAL_LOSS_NAME)
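# A standalone sketch (hypothetical shapes and loss names) of the new
# multi-loss aggregation above: one reduce_mean'd tensor per model output,
# summed with add_n instead of the old single-loss tf.add/tf.identity.
import tensorflow as tf
from tensorflow import keras

targets = [tf.zeros([8, 10]), tf.zeros([8, 1])]
outputs = [tf.zeros([8, 10]), tf.zeros([8, 1])]
loss_tensors = []
for name, t, o in zip(['mean_squared_error', 'mean_absolute_error'], targets, outputs):
    fn = keras.losses.get(name)  # string -> Keras loss function
    loss_tensors.append(tf.reduce_mean(fn(t, o), name=name))
total_loss = tf.add_n(loss_tensors, name='total_loss')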
        if metrics and (ctx.is_main_training_tower or not ctx.is_training):
            # for list: one metric for each output
@@ -132,10 +161,11 @@ def setup_keras_trainer(
            for oid, metric_name in enumerate(metrics):
                output_tensor = outputs[oid]
                target_tensor = target_tensors[oid]  # TODO may not have the same mapping?
                with tf.name_scope('keras_metric'):  # TODO ns reuse
                with cached_name_scope('keras_metric', top_level=False):
                    metric_fn = metrics_module.get(metric_name)
                    metric_tensor = metric_fn(target_tensor, output_tensor)
                metric_tensor = tf.reduce_mean(metric_tensor, name=metric_name)
                _check_name(metric_tensor, metric_name)
                # check name conflict here
                metric_tensors.append(metric_tensor)
            add_moving_summary(*metric_tensors)
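# Parallel sketch for the metric path (hypothetical values): the same
# string-to-function lookup, via metrics_module.get, followed by reduce_mean.
import tensorflow as tf
from tensorflow.python.keras import metrics as metrics_module

acc_fn = metrics_module.get('categorical_accuracy')
acc = tf.reduce_mean(
    acc_fn(tf.one_hot([1], 10), tf.one_hot([1], 10)),
    name='categorical_accuracy')  # evaluates to 1.0 here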
@@ -168,6 +198,7 @@ class KerasModel(object):
        else:
            trainer = SyncMultiGPUTrainerParameterServer(nr_gpu)
        assert isinstance(trainer, Trainer), trainer
        assert not isinstance(trainer, DistributedTrainerBase)

        self.input = input
        self.trainer = trainer
@@ -185,7 +216,7 @@ class KerasModel(object):
        if isinstance(metrics, six.string_types):
            metrics = [metrics]
        self._stats_to_inference = loss + metrics
        self._stats_to_inference = loss + metrics + [TOTAL_LOSS_NAME]
        setup_keras_trainer(
            self.trainer, get_model=self.get_model,
            input=self.input,
@@ -201,7 +232,6 @@ class KerasModel(object):
"""
callbacks = kwargs.pop('callbacks', [])
if validation_data is not None:
callbacks.append(
InferenceRunner(
validation_data, ScalarStats(self._stats_to_inference + ['total_loss'])))
callbacks.append(InferenceRunner(
validation_data, ScalarStats(self._stats_to_inference)))
self.trainer.train_with_defaults(callbacks=callbacks, **kwargs)
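# An end-to-end usage sketch of the KerasModel interface (signatures assumed
# from this file; the dataflow names and hyperparameters are hypothetical):
import tensorflow as tf
from tensorpack import QueueInput
from tensorpack.contrib.keras import KerasModel

M = KerasModel(get_model, QueueInput(train_dataflow))
M.compile(
    optimizer=tf.train.AdamOptimizer(1e-3),
    loss='categorical_crossentropy',
    metrics='categorical_accuracy')
M.fit(validation_data=test_dataflow, steps_per_epoch=100)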
@@ -85,15 +85,18 @@ def _get_cached_ns(name):
@contextmanager
def cached_name_scope(name):
def cached_name_scope(name, top_level=True):
"""
Return a context which either opens and caches a new top-level name scope,
Return a context which either opens and caches a new name scope,
or reenter an existing one.
Note:
The name scope will always be top-level. It will not be nested under
any existing name scope of the caller.
Args:
top_level(bool): if True, the name scope will always be top-level.
It will not be nested under any existing name scope of the caller.
"""
if not top_level:
current_ns = tf.get_default_graph().get_name_scope()
name = current_ns + '/' + name
ns = _get_cached_ns(name)
with tf.name_scope(ns):
yield ns
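# Usage sketch (assumed import path and semantics): with top_level=False the
# cached scope nests under the caller's current scope, and re-entering it
# reuses the exact same scope instead of creating a numbered 'stats_1'.
import tensorflow as tf
from tensorpack.tfutils.scope_utils import cached_name_scope

with tf.name_scope('outer'):
    with cached_name_scope('stats', top_level=False):
        a = tf.constant(1, name='a')   # -> outer/stats/a
    with cached_name_scope('stats', top_level=False):
        b = tf.constant(2, name='b')   # -> outer/stats/b (scope re-entered)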