Commit 6f6787db authored by Yuxin Wu's avatar Yuxin Wu

add documentation about KerasModel (#1036)

parent f5d1714a
...@@ -16,4 +16,4 @@ API Documentation ...@@ -16,4 +16,4 @@ API Documentation
predict predict
tfutils tfutils
utils utils
contrib
...@@ -93,6 +93,8 @@ class KerasModelCaller(object): ...@@ -93,6 +93,8 @@ class KerasModelCaller(object):
with clear_tower0_name_scope(): with clear_tower0_name_scope():
model = self.cached_model = self.get_model(*input_tensors) model = self.cached_model = self.get_model(*input_tensors)
assert isinstance(model, tf.keras.Model), \
"Your get_model function should return a `tf.keras.Model`!"
outputs = model.outputs outputs = model.outputs
elif reuse: elif reuse:
# use the cached Keras model to mimic reuse # use the cached Keras model to mimic reuse
...@@ -110,11 +112,16 @@ class KerasModelCaller(object): ...@@ -110,11 +112,16 @@ class KerasModelCaller(object):
return outputs return outputs
# Keras needs an extra input if learning_phase is used by the model
# This cb will be used by
# 1. trainer with isTrain=True
# 2. InferenceRunner with isTrain=False, in the form of hooks
class KerasPhaseCallback(Callback): class KerasPhaseCallback(Callback):
"""
Keras needs an extra input if learning_phase is used by the model
This callback will be used:
1. By the trainer with isTrain=True
2. By InferenceRunner with isTrain=False, in the form of hooks
If you use :class:`KerasModel` or :func:`setup_keras_trainer`,
this callback will be automatically added when needed.
"""
def __init__(self, isTrain): def __init__(self, isTrain):
assert isinstance(isTrain, bool), isTrain assert isinstance(isTrain, bool), isTrain
self._isTrain = isTrain self._isTrain = isTrain
...@@ -221,7 +228,7 @@ class KerasModel(object): ...@@ -221,7 +228,7 @@ class KerasModel(object):
""" """
Args: Args:
get_model (input1, input2, ... -> keras.Model): get_model (input1, input2, ... -> keras.Model):
Takes tensors and returns a Keras model. Will be part of the tower function. A function which takes tensors and returns a Keras model. Will be part of the tower function.
inputs_desc ([InputDesc]): inputs_desc ([InputDesc]):
targets_desc ([InputDesc]): targets_desc ([InputDesc]):
input (InputSource | DataFlow): input (InputSource | DataFlow):
...@@ -229,6 +236,7 @@ class KerasModel(object): ...@@ -229,6 +236,7 @@ class KerasModel(object):
GPUs and use them all. GPUs and use them all.
""" """
self.get_model = get_model self.get_model = get_model
assert callable(get_model), get_model
self.inputs_desc = inputs_desc self.inputs_desc = inputs_desc
self.targets_desc = targets_desc self.targets_desc = targets_desc
if trainer is None: if trainer is None:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment