Commit e072d909 authored by Yuxin Wu's avatar Yuxin Wu

[WIP] reorganize trainer. fix batch_norm

parent 335d6c28
...@@ -8,14 +8,9 @@ import tensorflow as tf ...@@ -8,14 +8,9 @@ import tensorflow as tf
import argparse import argparse
import os import os
from tensorpack.train import TrainConfig, QueueInputTrainer from tensorpack import *
from tensorpack.models import *
from tensorpack.callbacks import *
from tensorpack.utils import *
from tensorpack.tfutils import *
from tensorpack.tfutils.symbolic_functions import * from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import * from tensorpack.tfutils.summary import *
from tensorpack.dataflow import *
""" """
CIFAR10-resnet example. CIFAR10-resnet example.
...@@ -186,11 +181,9 @@ if __name__ == '__main__': ...@@ -186,11 +181,9 @@ if __name__ == '__main__':
if args.gpu: if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
with tf.Graph().as_default(): config = get_config()
with tf.device('/cpu:0'): if args.load:
config = get_config() config.session_init = SaverRestore(args.load)
if args.load: if args.gpu:
config.session_init = SaverRestore(args.load) config.nr_tower = len(args.gpu.split(','))
if args.gpu: SyncMultiGPUTrainer(config).train()
config.nr_tower = len(args.gpu.split(','))
QueueInputTrainer(config).train()
...@@ -5,7 +5,9 @@ ...@@ -5,7 +5,9 @@
import tensorflow as tf import tensorflow as tf
from copy import copy from copy import copy
import re
from ..utils import logger
from ._common import layer_register from ._common import layer_register
__all__ = ['BatchNorm'] __all__ = ['BatchNorm']
...@@ -48,9 +50,28 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5): ...@@ -48,9 +50,28 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
else: else:
batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2], keep_dims=False) batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2], keep_dims=False)
ema = tf.train.ExponentialMovingAverage(decay=decay) emaname = 'EMA'
ema_apply_op = ema.apply([batch_mean, batch_var]) if not batch_mean.name.startswith('towerp'):
ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var) ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
ema_apply_op = ema.apply([batch_mean, batch_var])
ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)
else:
assert not use_local_stat
# have to do this again to get actual name. see issue:
# https://github.com/tensorflow/tensorflow/issues/2740
ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
ema_apply_op = ema.apply([batch_mean, batch_var])
ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)
mean_name = re.sub('towerp[0-9]+/', '', ema_mean.name)
var_name = re.sub('towerp[0-9]+/', '', ema_var.name)
#var_name = batch_var.op.name[prefixlen:] + '/' + emaname + ':0'
#logger.info("In prediction, using {} instead of {} for {}".format(
#mean_name, ema_mean.name, batch_mean.name))
G = tf.get_default_graph()
ema_mean = G.get_tensor_by_name(mean_name)
ema_var = G.get_tensor_by_name(var_name)
if use_local_stat: if use_local_stat:
with tf.control_dependencies([ema_apply_op]): with tf.control_dependencies([ema_apply_op]):
...@@ -58,6 +79,7 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5): ...@@ -58,6 +79,7 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
x, batch_mean, batch_var, beta, gamma, epsilon, 'bn') x, batch_mean, batch_var, beta, gamma, epsilon, 'bn')
else: else:
batch = tf.cast(tf.shape(x)[0], tf.float32) batch = tf.cast(tf.shape(x)[0], tf.float32)
# XXX TODO batch==1?
mean, var = ema_mean, ema_var * batch / (batch - 1) # unbiased variance estimator mean, var = ema_mean, ema_var * batch / (batch - 1) # unbiased variance estimator
return tf.nn.batch_normalization( return tf.nn.batch_normalization(
x, mean, var, beta, gamma, epsilon, 'bn') x, mean, var, beta, gamma, epsilon, 'bn')
...@@ -5,13 +5,19 @@ ...@@ -5,13 +5,19 @@
from ..utils.naming import * from ..utils.naming import *
import tensorflow as tf import tensorflow as tf
from copy import copy
import six
from contextlib import contextmanager
__all__ = ['get_default_sess_config', __all__ = ['get_default_sess_config',
'get_global_step', 'get_global_step',
'get_global_step_var', 'get_global_step_var',
'get_op_var_name', 'get_op_var_name',
'get_vars_by_names' 'get_vars_by_names',
] 'backup_collection',
'restore_collection',
'clear_collection',
'freeze_collection']
def get_default_sess_config(mem_fraction=0.9): def get_default_sess_config(mem_fraction=0.9):
""" """
...@@ -66,3 +72,24 @@ def get_vars_by_names(names): ...@@ -66,3 +72,24 @@ def get_vars_by_names(names):
opn, varn = get_op_var_name(n) opn, varn = get_op_var_name(n)
ret.append(G.get_tensor_by_name(varn)) ret.append(G.get_tensor_by_name(varn))
return ret return ret
def backup_collection(keys):
    """Snapshot the contents of the given graph collections.

    Returns a dict mapping each collection key to a shallow copy of its
    current entry list, suitable for later use with restore_collection.
    """
    return {key: copy(tf.get_collection(key)) for key in keys}
def restore_collection(backup):
    """Replace graph collections with previously saved contents.

    ``backup`` is the dict produced by backup_collection; each listed
    collection is cleared in place and refilled with the saved entries.
    """
    for key, saved in six.iteritems(backup):
        # mutate the live collection list in place so existing references see it
        coll = tf.get_collection_ref(key)
        del coll[:]
        coll.extend(saved)
def clear_collection(keys):
    """Empty each of the given graph collections in place."""
    for key in keys:
        # slice assignment clears the live list without rebinding it
        tf.get_collection_ref(key)[:] = []
@contextmanager
def freeze_collection(keys):
    """Context manager that restores the given graph collections on exit.

    Snapshots the named collections on entry and puts the snapshot back
    afterwards, so anything added to (or removed from) them inside the
    ``with`` block is undone.

    Args:
        keys: iterable of graph-collection keys to protect.
    """
    backup = backup_collection(keys)
    try:
        yield
    finally:
        # Restore even when the managed block raises; without the
        # try/finally an exception would leak the modified collections.
        restore_collection(backup)
...@@ -16,7 +16,6 @@ from ..utils.concurrency import start_proc_mask_signal ...@@ -16,7 +16,6 @@ from ..utils.concurrency import start_proc_mask_signal
from ..callbacks import StatHolder from ..callbacks import StatHolder
from ..tfutils import * from ..tfutils import *
from ..tfutils.summary import create_summary from ..tfutils.summary import create_summary
from ..tfutils.modelutils import describe_model
__all__ = ['Trainer'] __all__ = ['Trainer']
...@@ -141,7 +140,6 @@ class Trainer(object): ...@@ -141,7 +140,6 @@ class Trainer(object):
self.sess.close() self.sess.close()
def init_session_and_coord(self): def init_session_and_coord(self):
describe_model()
self.sess = tf.Session(config=self.config.session_config) self.sess = tf.Session(config=self.config.session_config)
self.coord = tf.train.Coordinator() self.coord = tf.train.Coordinator()
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment