Commit fb5a99e0 authored by Yuxin Wu

small change

parent 32e70abf
@@ -16,7 +16,7 @@ from imagenet_utils import (
     ImageNetModel, get_imagenet_dataflow, fbresnet_augmentor)
 
 
-def GroupNorm(x, group):
+def GroupNorm(x, group, gamma_initializer=tf.constant_initializer(1.)):
     """
     https://arxiv.org/abs/1803.08494
     """
@@ -39,7 +39,7 @@ def GroupNorm(x, group):
     beta = tf.get_variable('beta', [chan], initializer=tf.constant_initializer())
     beta = tf.reshape(beta, new_shape)
 
-    gamma = tf.get_variable('gamma', [chan], initializer=tf.constant_initializer(1.0))
+    gamma = tf.get_variable('gamma', [chan], initializer=gamma_initializer)
     gamma = tf.reshape(gamma, new_shape)
 
     out = tf.nn.batch_normalization(x, mean, var, beta, gamma, 1e-5, name='output')
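The default value of the new gamma_initializer keyword keeps the existing behaviour (the scale starts at 1) while letting callers override it. Below is a minimal, self-contained sketch of why that can matter; it is not code from this repository, it assumes TF 1.x graph mode and NCHW inputs, and it uses a simplified simple_group_norm helper rather than the GroupNorm above. A common use (an assumption here, the commit message does not state it) is zero-initializing gamma in the last normalization of a residual branch so the block starts close to an identity mapping.

import numpy as np
import tensorflow as tf

def simple_group_norm(x, group, gamma_initializer=tf.constant_initializer(1.)):
    # x is NCHW; normalize over (channels within a group, H, W).
    _, c, h, w = x.get_shape().as_list()
    x = tf.reshape(x, [-1, group, c // group, h, w])
    mean, var = tf.nn.moments(x, [2, 3, 4], keep_dims=True)
    beta = tf.get_variable('beta', [1, group, c // group, 1, 1],
                           initializer=tf.constant_initializer())
    gamma = tf.get_variable('gamma', [1, group, c // group, 1, 1],
                            initializer=gamma_initializer)
    out = tf.nn.batch_normalization(x, mean, var, beta, gamma, 1e-5)
    return tf.reshape(out, [-1, c, h, w])

inp = tf.placeholder(tf.float32, [None, 32, 8, 8])
with tf.variable_scope('gn_last'):
    # Zero-init gamma: the normalized output is scaled to exactly zero at
    # initialization, so a residual branch ending in this layer initially
    # contributes nothing to the block's output.
    out = simple_group_norm(inp, group=8, gamma_initializer=tf.zeros_initializer())

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    x = np.random.rand(2, 32, 8, 8).astype('float32')
    print(sess.run(out, {inp: x}).max())  # 0.0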
@@ -64,12 +64,11 @@ def get_config(model, fake=False):
     logger.info("Running on {} towers. Batch size per tower: {}".format(nr_tower, batch))
 
     if fake:
-        dataset_train = FakeData(
-            [[batch, 224, 224, 3], [batch]], 1000, random=False, dtype='uint8')
+        data = QueueInput(FakeData(
+            [[batch, 224, 224, 3], [batch]], 1000, random=False, dtype='uint8'))
         callbacks = []
     else:
-        dataset_train = get_data('train', batch)
-        dataset_val = get_data('val', batch)
+        data = QueueInput(get_data('train', batch))
 
         START_LR = 0.1
         BASE_LR = START_LR * (args.batch / 256.0)
@@ -87,6 +86,7 @@ def get_config(model, fake=False):
         infs = [ClassificationError('wrong-top1', 'val-error-top1'),
                 ClassificationError('wrong-top5', 'val-error-top5')]
+        dataset_val = get_data('val', batch)
         if nr_tower == 1:
             # single-GPU inference with queue prefetch
             callbacks.append(InferenceRunner(QueueInput(dataset_val), infs))
@@ -97,7 +97,7 @@ def get_config(model, fake=False):
 
     return TrainConfig(
         model=model,
-        dataflow=dataset_train,
+        data=data,
         callbacks=callbacks,
         steps_per_epoch=100 if args.fake else 1280000 // args.batch,
         max_epoch=105,
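For context on the three hunks above: QueueInput is tensorpack's InputSource that prefetches datapoints from a DataFlow into a TensorFlow queue, and TrainConfig takes it through the data= argument instead of a raw DataFlow through dataflow=. Wrapping both the fake and the real training data in QueueInput up front makes the two branches produce the same kind of object, which is then passed straight to TrainConfig. A small sketch of the fake-data branch in isolation (batch = 64 is an arbitrary stand-in for the per-tower batch size; assumes tensorpack is installed):

from tensorpack import QueueInput
from tensorpack.dataflow import FakeData

batch = 64  # stand-in for the per-tower batch size used in the script
# A constant uint8 dataset with ImageNet-shaped images and labels, wrapped in
# a QueueInput so batches are prefetched into a TF queue; this object is what
# gets handed to TrainConfig(data=...).
data = QueueInput(FakeData(
    [[batch, 224, 224, 3], [batch]], 1000, random=False, dtype='uint8'))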
@@ -330,6 +330,7 @@ class GradientPacker(object):
         if len(dtypes) != 1:
             logger.info("Skip GradientPacker due to inconsistent gradient types.")
             return False
+        self._grad_dtype = grads[0].dtype
 
         split_size = self._total_size // self._num_split
         split_size_last = self._total_size - split_size * (self._num_split - 1)
@@ -352,12 +353,14 @@ class GradientPacker(object):
         with cached_name_scope("GradientPacker", top_level=False):
             concat_grads = tf.concat([tf.reshape(g, [-1]) for g in grads], 0, name='concatenated_grads')
+            # concat_grads = tf.cast(concat_grads, tf.float16)
             grad_packs = tf.split(concat_grads, self._split_sizes)
             return grad_packs
 
     def unpack(self, grad_packs):
         with cached_name_scope("GradientPacker", top_level=False):
             concat_grads = tf.concat(grad_packs, 0, name='concatenated_packs')
+            # concat_grads = tf.cast(concat_grads, self._grad_dtype)
             flattened_grads = tf.split(concat_grads, self._sizes)
             grads = [tf.reshape(g, shape) for g, shape in zip(flattened_grads, self._shapes)]
             return grads
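The two additions above store the incoming gradients' dtype and sketch, in commented-out form, an optional float16 cast of the packed buffer before aggregation, with the saved self._grad_dtype restored after unpacking. The pack/unpack round trip itself is concatenate-flatten-split and its inverse; here is a NumPy stand-in (not the tensorpack class, names are illustrative) showing that the round trip recovers the original gradients exactly:

import numpy as np

grads = [np.random.rand(3, 4).astype('float32'),
         np.random.rand(7).astype('float32')]
shapes = [g.shape for g in grads]
sizes = [g.size for g in grads]
total_size = sum(sizes)

num_split = 2
split_size = total_size // num_split
split_sizes = [split_size] * (num_split - 1) + [total_size - split_size * (num_split - 1)]

# pack(): flatten and concatenate all gradients, then cut the buffer into
# num_split roughly equal packs (e.g. one per all-reduce call).
flat = np.concatenate([g.reshape(-1) for g in grads])
packs = np.split(flat, np.cumsum(split_sizes)[:-1])

# unpack(): concatenate the packs back, re-split per original gradient size,
# and restore the original shapes.
flat_back = np.concatenate(packs)
pieces = np.split(flat_back, np.cumsum(sizes)[:-1])
restored = [p.reshape(s) for p, s in zip(pieces, shapes)]

assert all(np.array_equal(a, b) for a, b in zip(grads, restored))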