Commit f6ede612 authored by Yuxin Wu

Better BatchNorm (with ema_update option decoupled from training)

parent 4a46b93d
@@ -169,8 +169,8 @@ class SeparateGANTrainer(TowerTrainer):
         # Build the graph
         self.tower_func = TowerFuncWrapper(model.build_graph, model.get_input_signature())
         with TowerContext('', is_training=True), \
-                argscope(BatchNorm, internal_update=True):
-            # should not hook the updates to both train_op, it will hurt training speed.
+                argscope(BatchNorm, ema_update='internal'):
+            # should not hook the EMA updates to both train_op, it will hurt training speed.
             self.tower_func(*input.get_input_tensors())
         update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
         if len(update_ops):
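For context, a minimal sketch of how the renamed option is meant to be used, assuming tensorpack's post-commit API where BatchNorm takes an ema_update argument (reportedly one of 'default', 'collection', 'internal', 'skip') instead of the old internal_update=True/False flag. The build_discriminator function, the Conv2D layer, and the variable names are illustrative, not part of this commit:

    # Sketch only: assumes the ema_update option introduced by this commit.
    from tensorpack.models import BatchNorm, Conv2D
    from tensorpack.tfutils import argscope
    from tensorpack.tfutils.tower import TowerContext

    def build_discriminator(image):
        # 'internal' applies the EMA update ops inside the layer's forward pass,
        # so nothing is added to tf.GraphKeys.UPDATE_OPS and the EMA updates are
        # not hooked onto every train_op (the situation the diff above avoids).
        with TowerContext('', is_training=True), \
                argscope(BatchNorm, ema_update='internal'):
            x = Conv2D('conv0', image, 64, 3)
            x = BatchNorm('bn0', x)
            return x

The point of the commit message is that the EMA update policy is now chosen explicitly per layer (or via argscope) rather than being implied by the training flag.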