Commit 5feb3529 authored by Yuxin Wu's avatar Yuxin Wu

fix BN update issue to improve SeparateGANTrainer performance

parent e5f5da3c
......@@ -5,10 +5,11 @@
import tensorflow as tf
import numpy as np
from tensorpack import (TowerTrainer, StagingInput,
ModelDescBase, DataFlow)
ModelDescBase, DataFlow, argscope, BatchNorm)
from tensorpack.tfutils.tower import TowerContext, TowerFuncWrapper
from tensorpack.graph_builder import DataParallelBuilder, LeastLoadedDeviceSetter
from tensorpack.tfutils.summary import add_moving_summary
from tensorpack.utils import logger
from tensorpack.utils.argtools import memoized_method
from tensorpack.utils.develop import deprecated
......@@ -178,8 +179,14 @@ class SeparateGANTrainer(TowerTrainer):
# Build the graph
self.tower_func = TowerFuncWrapper(model.build_graph, model.get_inputs_desc())
with TowerContext('', is_training=True):
with TowerContext('', is_training=True), \
argscope(BatchNorm, internal_update=True):
# should not hook the updates to both train_op, it will hurt training speed.
self.tower_func(*input.get_input_tensors())
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
if len(update_ops):
logger.warn("Found {} ops in UPDATE_OPS collection!".format(len(update_ops)))
logger.warn("Using SeparateGANTrainer with UPDATE_OPS may hurt your training speed a lot!")
opt = model.get_optimizer()
with tf.name_scope('optimize'):
......
......@@ -75,7 +75,9 @@ class RunOp(Callback):
class RunUpdateOps(RunOp):
"""
Run ops from the collection UPDATE_OPS every step
Run ops from the collection UPDATE_OPS every step.
The ops will be hooked to `trainer.hooked_sess` and run along with
each `sess.run` call.
"""
def __init__(self, collection=tf.GraphKeys.UPDATE_OPS):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment