Commit 63b8fb00 authored by Yuxin Wu

make contrib.keras docs build

parent ce3782ad
...@@ -4,9 +4,7 @@ ...@@ -4,9 +4,7 @@
from contextlib import contextmanager from contextlib import contextmanager
import six import six
import tensorflow as tf import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow import keras from tensorflow import keras
from tensorflow.python.keras import metrics as metrics_module
from ..callbacks import Callback, CallbackToHook, InferenceRunner, InferenceRunnerBase, ScalarStats from ..callbacks import Callback, CallbackToHook, InferenceRunner, InferenceRunnerBase, ScalarStats
from ..models.regularize import regularize_cost_from_collection from ..models.regularize import regularize_cost_from_collection
...@@ -36,11 +34,10 @@ def _check_name(tensor, name): ...@@ -36,11 +34,10 @@ def _check_name(tensor, name):
class KerasModelCaller(object): class KerasModelCaller(object):
""" """
Keras model doesn't support variable scope reuse. Keras model doesn't support variable scope reuse.
This is a hack to mimic reuse. This is a wrapper around keras model to mimic reuse.
""" """
def __init__(self, get_model): def __init__(self, get_model):
self.get_model = get_model self.get_model = get_model
self.cached_model = None self.cached_model = None
def __call__(self, input_tensors): def __call__(self, input_tensors):
...@@ -70,7 +67,7 @@ class KerasModelCaller(object): ...@@ -70,7 +67,7 @@ class KerasModelCaller(object):
for n in added_trainable_names: for n in added_trainable_names:
if n not in new_trainable_names: if n not in new_trainable_names:
logger.warn("Keras created trainable variable '{}' which is actually not trainable. " logger.warn("Keras created trainable variable '{}' which is actually not trainable. "
"This was automatically corrected by tensorpack.".format(n)) "This was automatically corrected.".format(n))
# Keras models might not use this collection at all (in some versions). # Keras models might not use this collection at all (in some versions).
# This is a BC-breaking change of tf.keras: https://github.com/tensorflow/tensorflow/issues/19643 # This is a BC-breaking change of tf.keras: https://github.com/tensorflow/tensorflow/issues/19643
...@@ -93,7 +90,7 @@ class KerasModelCaller(object): ...@@ -93,7 +90,7 @@ class KerasModelCaller(object):
with clear_tower0_name_scope(): with clear_tower0_name_scope():
model = self.cached_model = self.get_model(*input_tensors) model = self.cached_model = self.get_model(*input_tensors)
assert isinstance(model, tf.keras.Model), \ assert isinstance(model, keras.Model), \
"Your get_model function should return a `tf.keras.Model`!" "Your get_model function should return a `tf.keras.Model`!"
outputs = model.outputs outputs = model.outputs
elif reuse: elif reuse:
...@@ -125,7 +122,7 @@ class KerasPhaseCallback(Callback): ...@@ -125,7 +122,7 @@ class KerasPhaseCallback(Callback):
def __init__(self, isTrain): def __init__(self, isTrain):
assert isinstance(isTrain, bool), isTrain assert isinstance(isTrain, bool), isTrain
self._isTrain = isTrain self._isTrain = isTrain
self._learning_phase = K.learning_phase() self._learning_phase = keras.backend.learning_phase()
def _setup_graph(self): def _setup_graph(self):
logger.info("Using Keras learning phase {} in the graph!".format( logger.info("Using Keras learning phase {} in the graph!".format(
...@@ -149,8 +146,9 @@ def setup_keras_trainer( ...@@ -149,8 +146,9 @@ def setup_keras_trainer(
""" """
Args: Args:
trainer (SingleCostTrainer): trainer (SingleCostTrainer):
get_model (input1, input2, ... -> keras.model.Model): get_model (input1, input2, ... -> tf.keras.Model):
Takes tensors and returns a Keras model. Will be part of the tower function. A function which takes tensors, builds and returns a Keras model.
It will be part of the tower function.
input (InputSource): input (InputSource):
optimizer (tf.train.Optimizer): optimizer (tf.train.Optimizer):
loss, metrics: list of strings loss, metrics: list of strings
...@@ -202,7 +200,7 @@ def setup_keras_trainer( ...@@ -202,7 +200,7 @@ def setup_keras_trainer(
output_tensor = outputs[oid] output_tensor = outputs[oid]
target_tensor = target_tensors[oid] # TODO may not have the same mapping? target_tensor = target_tensors[oid] # TODO may not have the same mapping?
with cached_name_scope('keras_metric', top_level=False): with cached_name_scope('keras_metric', top_level=False):
metric_fn = metrics_module.get(metric_name) metric_fn = keras.metrics.get(metric_name)
metric_tensor = metric_fn(target_tensor, output_tensor) metric_tensor = metric_fn(target_tensor, output_tensor)
metric_tensor = tf.reduce_mean(metric_tensor, name=metric_name) metric_tensor = tf.reduce_mean(metric_tensor, name=metric_name)
_check_name(metric_tensor, metric_name) _check_name(metric_tensor, metric_name)
...@@ -217,7 +215,7 @@ def setup_keras_trainer( ...@@ -217,7 +215,7 @@ def setup_keras_trainer(
input, input,
get_cost, get_cost,
lambda: optimizer) lambda: optimizer)
if len(K.learning_phase().consumers()) > 0: if len(keras.backend.learning_phase().consumers()) > 0:
# check if learning_phase is used in this model # check if learning_phase is used in this model
trainer.register_callback(KerasPhaseCallback(True)) trainer.register_callback(KerasPhaseCallback(True))
...@@ -228,7 +226,8 @@ class KerasModel(object): ...@@ -228,7 +226,8 @@ class KerasModel(object):
""" """
Args: Args:
get_model (input1, input2, ... -> keras.Model): get_model (input1, input2, ... -> keras.Model):
A function which takes tensors and returns a Keras model. Will be part of the tower function. A function which takes tensors, builds and returns a Keras model.
It will be part of the tower function.
inputs_desc ([InputDesc]): inputs_desc ([InputDesc]):
targets_desc ([InputDesc]): targets_desc ([InputDesc]):
input (InputSource | DataFlow): input (InputSource | DataFlow):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment