Commit fb5a99e0 authored by Yuxin Wu

small change

parent 32e70abf
@@ -16,7 +16,7 @@ from imagenet_utils import (
     ImageNetModel, get_imagenet_dataflow, fbresnet_augmentor)
 
 
-def GroupNorm(x, group):
+def GroupNorm(x, group, gamma_initializer=tf.constant_initializer(1.)):
     """
     https://arxiv.org/abs/1803.08494
     """
@@ -39,7 +39,7 @@ def GroupNorm(x, group):
     beta = tf.get_variable('beta', [chan], initializer=tf.constant_initializer())
     beta = tf.reshape(beta, new_shape)
 
-    gamma = tf.get_variable('gamma', [chan], initializer=tf.constant_initializer(1.0))
+    gamma = tf.get_variable('gamma', [chan], initializer=gamma_initializer)
     gamma = tf.reshape(gamma, new_shape)
 
     out = tf.nn.batch_normalization(x, mean, var, beta, gamma, 1e-5, name='output')
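For reference, a minimal sketch of the updated GroupNorm with the new gamma_initializer argument. The lines not shown in the hunks above (shape handling and the moment computation) are reconstructed from standard group normalization and are assumptions, not part of this commit; NCHW layout is also assumed.

import tensorflow as tf

def GroupNorm(x, group, gamma_initializer=tf.constant_initializer(1.)):
    """
    https://arxiv.org/abs/1803.08494
    """
    shape = x.get_shape().as_list()
    assert len(shape) == 4, shape
    chan = shape[1]                      # assumes NCHW layout
    assert chan % group == 0, chan
    group_size = chan // group

    orig_shape = tf.shape(x)
    h, w = orig_shape[2], orig_shape[3]

    # Normalize over (channels within a group, H, W), independently per group.
    x = tf.reshape(x, tf.stack([-1, group, group_size, h, w]))
    mean, var = tf.nn.moments(x, [2, 3, 4], keep_dims=True)

    new_shape = [1, group, group_size, 1, 1]
    beta = tf.get_variable('beta', [chan], initializer=tf.constant_initializer())
    beta = tf.reshape(beta, new_shape)
    # The scale initializer is now configurable instead of being fixed to 1.0.
    gamma = tf.get_variable('gamma', [chan], initializer=gamma_initializer)
    gamma = tf.reshape(gamma, new_shape)

    out = tf.nn.batch_normalization(x, mean, var, beta, gamma, 1e-5, name='output')
    return tf.reshape(out, orig_shape, name='output')

A typical use of such an argument is passing a zero initializer for the last normalization layer of a residual branch, though the commit itself does not show a caller.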
@@ -64,12 +64,11 @@ def get_config(model, fake=False):
     logger.info("Running on {} towers. Batch size per tower: {}".format(nr_tower, batch))
 
     if fake:
-        dataset_train = FakeData(
-            [[batch, 224, 224, 3], [batch]], 1000, random=False, dtype='uint8')
+        data = QueueInput(FakeData(
+            [[batch, 224, 224, 3], [batch]], 1000, random=False, dtype='uint8'))
         callbacks = []
     else:
-        dataset_train = get_data('train', batch)
-        dataset_val = get_data('val', batch)
+        data = QueueInput(get_data('train', batch))
 
         START_LR = 0.1
         BASE_LR = START_LR * (args.batch / 256.0)
@@ -87,6 +86,7 @@ def get_config(model, fake=False):
 
         infs = [ClassificationError('wrong-top1', 'val-error-top1'),
                 ClassificationError('wrong-top5', 'val-error-top5')]
+        dataset_val = get_data('val', batch)
         if nr_tower == 1:
             # single-GPU inference with queue prefetch
             callbacks.append(InferenceRunner(QueueInput(dataset_val), infs))
@@ -97,7 +97,7 @@ def get_config(model, fake=False):
 
     return TrainConfig(
         model=model,
-        dataflow=dataset_train,
+        data=data,
         callbacks=callbacks,
         steps_per_epoch=100 if args.fake else 1280000 // args.batch,
         max_epoch=105,
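This change routes the training input through QueueInput up front and hands it to TrainConfig via data= (an InputSource) instead of dataflow= (a raw DataFlow); the validation dataflow is now built only where inference callbacks need it. A minimal sketch of the resulting pipeline, with FakeData standing in for the real ImageNet dataflow and the batch size chosen here for illustration; model and callbacks come from the surrounding script:

from tensorpack import QueueInput, TrainConfig
from tensorpack.dataflow import FakeData

batch = 64  # hypothetical per-tower batch size

# Wrap the DataFlow in an InputSource; the trainer then pulls tensors from a prefetch queue.
dataset_train = FakeData([[batch, 224, 224, 3], [batch]], 1000, random=False, dtype='uint8')
data = QueueInput(dataset_train)

# The training configuration now takes `data=` instead of `dataflow=`:
# config = TrainConfig(model=model, data=data, callbacks=callbacks,
#                      steps_per_epoch=1280000 // batch, max_epoch=105)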
@@ -330,6 +330,7 @@ class GradientPacker(object):
         if len(dtypes) != 1:
             logger.info("Skip GradientPacker due to inconsistent gradient types.")
             return False
+        self._grad_dtype = grads[0].dtype
 
         split_size = self._total_size // self._num_split
         split_size_last = self._total_size - split_size * (self._num_split - 1)
@@ -352,12 +353,14 @@ class GradientPacker(object):
         with cached_name_scope("GradientPacker", top_level=False):
             concat_grads = tf.concat([tf.reshape(g, [-1]) for g in grads], 0, name='concatenated_grads')
+            # concat_grads = tf.cast(concat_grads, tf.float16)
             grad_packs = tf.split(concat_grads, self._split_sizes)
             return grad_packs
 
     def unpack(self, grad_packs):
         with cached_name_scope("GradientPacker", top_level=False):
             concat_grads = tf.concat(grad_packs, 0, name='concatenated_packs')
+            # concat_grads = tf.cast(concat_grads, self._grad_dtype)
             flattened_grads = tf.split(concat_grads, self._sizes)
             grads = [tf.reshape(g, shape) for g, shape in zip(flattened_grads, self._shapes)]
             return grads
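The GradientPacker change records the original gradient dtype so that the (still commented-out) fp16 cast before communication could be undone on unpack. A self-contained sketch of the pack/unpack idea, not tensorpack's actual class; the function names and the meta tuple are illustrative:

import numpy as np
import tensorflow as tf

def pack(grads, num_split):
    # Flatten all gradients into one vector and cut it into num_split chunks.
    shapes = [g.shape.as_list() for g in grads]
    sizes = [int(np.prod(s)) for s in shapes]
    grad_dtype = grads[0].dtype                  # remembered, like the new self._grad_dtype
    flat = tf.concat([tf.reshape(g, [-1]) for g in grads], 0)
    # Optional compression before communication (mirrors the commented-out cast):
    # flat = tf.cast(flat, tf.float16)
    total = sum(sizes)
    split_size = total // num_split
    split_sizes = [split_size] * (num_split - 1) + [total - split_size * (num_split - 1)]
    return tf.split(flat, split_sizes), (shapes, sizes, grad_dtype)

def unpack(packs, meta):
    # Reassemble the chunks and restore the original shapes (and, if cast, the dtype).
    shapes, sizes, grad_dtype = meta
    flat = tf.concat(packs, 0)
    # flat = tf.cast(flat, grad_dtype)
    pieces = tf.split(flat, sizes)
    return [tf.reshape(p, s) for p, s in zip(pieces, shapes)]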