Commit 627ad534 authored by Yuxin Wu's avatar Yuxin Wu

support BN updates in Keras models

parent 035f597d
......@@ -84,6 +84,11 @@ def setup_keras_trainer(
target_tensors=target_tensors,
metrics=metrics)
# BN updates
if ctx.is_training:
for u in M.updates:
tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, u)
add_moving_summary(tf.identity(M.total_loss, name='total_loss'))
assert len(M.metrics) == len(M.metrics_tensors)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment