Commit 6f6787db authored by Yuxin Wu's avatar Yuxin Wu

add documentation about KerasModel (#1036)

parent f5d1714a
......@@ -16,4 +16,4 @@ API Documentation
predict
tfutils
utils
contrib
......@@ -93,6 +93,8 @@ class KerasModelCaller(object):
with clear_tower0_name_scope():
model = self.cached_model = self.get_model(*input_tensors)
assert isinstance(model, tf.keras.Model), \
"Your get_model function should return a `tf.keras.Model`!"
outputs = model.outputs
elif reuse:
# use the cached Keras model to mimic reuse
......@@ -110,11 +112,16 @@ class KerasModelCaller(object):
return outputs
# Keras needs an extra input if learning_phase is used by the model
# This cb will be used by
# 1. trainer with isTrain=True
# 2. InferenceRunner with isTrain=False, in the form of hooks
class KerasPhaseCallback(Callback):
"""
Keras needs an extra input if learning_phase is used by the model
This callback will be used:
1. By the trainer with isTrain=True
2. By InferenceRunner with isTrain=False, in the form of hooks
If you use :class:`KerasModel` or :func:`setup_keras_trainer`,
this callback will be automatically added when needed.
"""
def __init__(self, isTrain):
assert isinstance(isTrain, bool), isTrain
self._isTrain = isTrain
......@@ -221,7 +228,7 @@ class KerasModel(object):
"""
Args:
get_model (input1, input2, ... -> keras.Model):
Takes tensors and returns a Keras model. Will be part of the tower function.
A function which takes tensors and returns a Keras model. Will be part of the tower function.
inputs_desc ([InputDesc]):
targets_desc ([InputDesc]):
input (InputSource | DataFlow):
......@@ -229,6 +236,7 @@ class KerasModel(object):
GPUs and use them all.
"""
self.get_model = get_model
assert callable(get_model), get_model
self.inputs_desc = inputs_desc
self.targets_desc = targets_desc
if trainer is None:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment