Commit 5feb3529 authored by Yuxin Wu

fix BN update issue to improve SeparateGANTrainer performance

parent e5f5da3c
@@ -5,10 +5,11 @@
 import tensorflow as tf
 import numpy as np
 from tensorpack import (TowerTrainer, StagingInput,
-                        ModelDescBase, DataFlow)
+                        ModelDescBase, DataFlow, argscope, BatchNorm)
 from tensorpack.tfutils.tower import TowerContext, TowerFuncWrapper
 from tensorpack.graph_builder import DataParallelBuilder, LeastLoadedDeviceSetter
 from tensorpack.tfutils.summary import add_moving_summary
+from tensorpack.utils import logger
 from tensorpack.utils.argtools import memoized_method
 from tensorpack.utils.develop import deprecated
@@ -178,8 +179,14 @@ class SeparateGANTrainer(TowerTrainer):
         # Build the graph
         self.tower_func = TowerFuncWrapper(model.build_graph, model.get_inputs_desc())
-        with TowerContext('', is_training=True):
+        with TowerContext('', is_training=True), \
+                argscope(BatchNorm, internal_update=True):
+            # Should not hook the BN updates to both train_ops; it would hurt training speed.
             self.tower_func(*input.get_input_tensors())
+        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
+        if len(update_ops):
+            logger.warn("Found {} ops in the UPDATE_OPS collection!".format(len(update_ops)))
+            logger.warn("Using SeparateGANTrainer with UPDATE_OPS may hurt your training speed a lot!")
         opt = model.get_optimizer()
         with tf.name_scope('optimize'):
...
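For context, here is a minimal sketch of what the argscope change above does, assuming tensorpack's BatchNorm behaves as used in the diff: with internal_update=True the moving-average updates run inside the layer's forward pass, so nothing is appended to tf.GraphKeys.UPDATE_OPS and the warning added in this commit stays quiet. The placeholder shape and layer name below are made up for illustration only.

import tensorflow as tf
from tensorpack import argscope, BatchNorm
from tensorpack.tfutils.tower import TowerContext

x = tf.placeholder(tf.float32, [None, 32, 32, 3], name='x')   # hypothetical input

with TowerContext('', is_training=True), \
        argscope(BatchNorm, internal_update=True):
    # With internal_update=True, the EMA updates are executed as part of the
    # BatchNorm output op itself instead of being collected in UPDATE_OPS.
    y = BatchNorm('bn', x)

# The collection is empty, so the warning added in this commit would not fire:
print(len(tf.get_collection(tf.GraphKeys.UPDATE_OPS)))   # expected: 0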
@@ -75,7 +75,9 @@ class RunOp(Callback):
 class RunUpdateOps(RunOp):
     """
-    Run ops from the collection UPDATE_OPS every step
+    Run ops from the collection UPDATE_OPS every step.
+    The ops will be hooked to `trainer.hooked_sess` and run along with
+    each `sess.run` call.
     """
     def __init__(self, collection=tf.GraphKeys.UPDATE_OPS):
...
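By contrast, when BatchNorm is built with its default internal_update=False, the EMA update ops land in UPDATE_OPS and something must run them each step, which is what the RunUpdateOps docstring above describes. A rough, simplified sketch of that wiring (not tensorpack's actual implementation):

import tensorflow as tf

# Assume the graph was built without internal_update, so BatchNorm appended
# its moving-average update ops to the UPDATE_OPS collection.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
update_op = tf.group(*update_ops, name='update_ops') if update_ops else tf.no_op()

# "Hooked to trainer.hooked_sess" means the op is fetched together with the
# training op on every step, roughly:
#     sess.run([train_op, update_op])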