Commit 63b8fb00 authored by Yuxin Wu

make contrib.keras docs build

parent ce3782ad
......@@ -4,9 +4,7 @@
from contextlib import contextmanager
import six
import tensorflow as tf
-import tensorflow.keras.backend as K
+from tensorflow import keras
-from tensorflow.python.keras import metrics as metrics_module
from ..callbacks import Callback, CallbackToHook, InferenceRunner, InferenceRunnerBase, ScalarStats
from ..models.regularize import regularize_cost_from_collection
......@@ -36,11 +34,10 @@ def _check_name(tensor, name):
class KerasModelCaller(object):
"""
Keras model doesn't support variable scope reuse.
-This is a hack to mimic reuse.
+This is a wrapper around keras model to mimic reuse.
"""
def __init__(self, get_model):
self.get_model = get_model
self.cached_model = None
def __call__(self, input_tensors):
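For orientation, a hedged usage sketch of this caller (the `get_model` function and tensors below are hypothetical, not part of this commit): the first call builds and caches the Keras model, later calls reuse the cached model so variables are shared.

# Sketch only: names are illustrative.
caller = KerasModelCaller(get_model)
train_logits = caller([train_images])   # first call: builds the model, stores it in cached_model
test_logits = caller([test_images])     # later calls: reuse the cached model, no new variables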
......@@ -70,7 +67,7 @@ class KerasModelCaller(object):
for n in added_trainable_names:
if n not in new_trainable_names:
logger.warn("Keras created trainable variable '{}' which is actually not trainable. "
"This was automatically corrected by tensorpack.".format(n))
"This was automatically corrected.".format(n))
# Keras models might not use this collection at all (in some versions).
# This is a BC-breaking change of tf.keras: https://github.com/tensorflow/tensorflow/issues/19643
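The "automatic correction" mentioned above amounts to editing the graph's TRAINABLE_VARIABLES collection; a rough sketch of the idea (not the exact code of this module, `var` is hypothetical):

# Sketch: drop a variable that Keras registered as trainable although it should not be.
coll = tf.get_collection_ref(tf.GraphKeys.TRAINABLE_VARIABLES)
if var in coll:
    coll.remove(var)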
......@@ -93,7 +90,7 @@ class KerasModelCaller(object):
with clear_tower0_name_scope():
model = self.cached_model = self.get_model(*input_tensors)
-assert isinstance(model, tf.keras.Model), \
+assert isinstance(model, keras.Model), \
"Your get_model function should return a `tf.keras.Model`!"
outputs = model.outputs
elif reuse:
......@@ -125,7 +122,7 @@ class KerasPhaseCallback(Callback):
def __init__(self, isTrain):
assert isinstance(isTrain, bool), isTrain
self._isTrain = isTrain
-self._learning_phase = K.learning_phase()
+self._learning_phase = keras.backend.learning_phase()
def _setup_graph(self):
logger.info("Using Keras learning phase {} in the graph!".format(
......@@ -149,8 +146,9 @@ def setup_keras_trainer(
"""
Args:
trainer (SingleCostTrainer):
-get_model (input1, input2, ... -> keras.model.Model):
-Takes tensors and returns a Keras model. Will be part of the tower function.
+get_model (input1, input2, ... -> tf.keras.Model):
+A function which takes tensors, builds and returns a Keras model.
+It will be part of the tower function.
input (InputSource):
optimizer (tf.train.Optimizer):
loss, metrics: list of strings
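A hedged sketch of what such a `get_model` function can look like (layers and sizes are made up; the only requirement stated above is that it builds and returns a `tf.keras.Model` from the input tensors):

def get_model(image):
    inp = tf.keras.layers.Input(tensor=image)             # wrap the incoming tensor for the functional API
    x = tf.keras.layers.Conv2D(32, 3, activation='relu')(inp)
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    logits = tf.keras.layers.Dense(10)(x)
    return tf.keras.Model(inputs=inp, outputs=logits)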
......@@ -202,7 +200,7 @@ def setup_keras_trainer(
output_tensor = outputs[oid]
target_tensor = target_tensors[oid] # TODO may not have the same mapping?
with cached_name_scope('keras_metric', top_level=False):
-metric_fn = metrics_module.get(metric_name)
+metric_fn = keras.metrics.get(metric_name)
metric_tensor = metric_fn(target_tensor, output_tensor)
metric_tensor = tf.reduce_mean(metric_tensor, name=metric_name)
_check_name(metric_tensor, metric_name)
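For reference, `keras.metrics.get` resolves a string name to the metric function, which returns a per-example tensor; a small sketch with a concrete name (tensor names are hypothetical):

metric_fn = keras.metrics.get('categorical_accuracy')
acc = tf.reduce_mean(metric_fn(labels, predictions), name='categorical_accuracy')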
......@@ -217,7 +215,7 @@ def setup_keras_trainer(
input,
get_cost,
lambda: optimizer)
-if len(K.learning_phase().consumers()) > 0:
+if len(keras.backend.learning_phase().consumers()) > 0:
# check if learning_phase is used in this model
trainer.register_callback(KerasPhaseCallback(True))
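In graph-mode tf.keras the learning phase is a scalar tensor read by layers such as Dropout and BatchNormalization; if nothing consumes it, the callback is unnecessary. A minimal sketch of feeding it explicitly (`sess` and `train_op` are hypothetical):

phase = keras.backend.learning_phase()
sess.run(train_op, feed_dict={phase: 1})   # 1 = training, 0 = inference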
......@@ -228,7 +226,8 @@ class KerasModel(object):
"""
Args:
get_model (input1, input2, ... -> keras.Model):
-A function which takes tensors and returns a Keras model. Will be part of the tower function.
+A function which takes tensors, builds and returns a Keras model.
+It will be part of the tower function.
inputs_desc ([InputDesc]):
targets_desc ([InputDesc]):
input (InputSource | DataFlow):
......
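To put the pieces together, a hedged end-to-end sketch in the spirit of tensorpack's Keras examples (shapes, names, and the DataFlow are assumptions, not taken from this commit):

M = KerasModel(
    get_model,
    inputs_desc=[InputDesc(tf.float32, [None, 28, 28, 1], 'images')],
    targets_desc=[InputDesc(tf.float32, [None, 10], 'labels')],
    input=QueueInput(train_dataflow))              # train_dataflow: a hypothetical DataFlow
M.compile(
    optimizer=tf.train.AdamOptimizer(1e-3),
    loss='categorical_crossentropy',
    metrics='categorical_accuracy')
M.fit(steps_per_epoch=1000)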