Commit 0201f2df authored by Bohumír Zámečník's avatar Bohumír Zámečník Committed by Yuxin Wu

Refactor KerasModelCaller to prevent using an unassigned variable. (#758)

fix #756 
parent 9bcf561c
...@@ -56,26 +56,12 @@ class KerasModelCaller(object): ...@@ -56,26 +56,12 @@ class KerasModelCaller(object):
old_trainable_names = set([x.name for x in tf.trainable_variables()]) old_trainable_names = set([x.name for x in tf.trainable_variables()])
trainable_backup = backup_collection([tf.GraphKeys.TRAINABLE_VARIABLES]) trainable_backup = backup_collection([tf.GraphKeys.TRAINABLE_VARIABLES])
try:
if self.cached_model is None: def post_process_model(model):
assert not reuse
M = self.cached_model = self.get_model(*input_tensors)
return M.outputs
elif reuse:
# use the cached Keras model to mimic reuse
# NOTE: ctx.is_training won't be useful inside model,
# because inference will always use the cached Keras model
M = self.cached_model
return M.call(input_tensors)
else:
# create new Keras model if not reuse
M = self.get_model(*input_tensors)
return M.outputs
finally:
added_trainable_names = set([x.name for x in tf.trainable_variables()]) added_trainable_names = set([x.name for x in tf.trainable_variables()])
restore_collection(trainable_backup) restore_collection(trainable_backup)
for v in M.weights: for v in model.weights:
# In Keras, the collection is not respected and could contain non-trainable vars. # In Keras, the collection is not respected and could contain non-trainable vars.
# We put M.weights into the collection instead. # We put M.weights into the collection instead.
if v.name not in old_trainable_names and v.name in added_trainable_names: if v.name not in old_trainable_names and v.name in added_trainable_names:
...@@ -87,6 +73,25 @@ class KerasModelCaller(object): ...@@ -87,6 +73,25 @@ class KerasModelCaller(object):
logger.warn("Keras created trainable variable '{}' which is actually not trainable. " logger.warn("Keras created trainable variable '{}' which is actually not trainable. "
"This was automatically corrected by tensorpack.".format(n)) "This was automatically corrected by tensorpack.".format(n))
if self.cached_model is None:
assert not reuse
model = self.cached_model = self.get_model(*input_tensors)
outputs = model.outputs
elif reuse:
# use the cached Keras model to mimic reuse
# NOTE: ctx.is_training won't be useful inside model,
# because inference will always use the cached Keras model
model = self.cached_model
outputs = model.call(input_tensors)
else:
# create new Keras model if not reuse
model = self.get_model(*input_tensors)
outputs = model.outputs
post_process_model(model)
return outputs
# Keras needs an extra input if learning_phase is used by the model # Keras needs an extra input if learning_phase is used by the model
# This cb will be used by # This cb will be used by
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment