Commit 5ef0578f authored by Yuxin Wu

regularization handles parameters of different dtypes

parent 1139854d
@@ -32,7 +32,8 @@ See [tutorials](http://tensorpack.readthedocs.io/en/latest/tutorial/index.html)
 ## [Examples](examples):
-Instead of showing you 10 random networks trained on toy datasets,
+We refuse toy examples.
+Instead of showing you 10 arbitrary networks trained on toy datasets,
 [tensorpack examples](examples) faithfully replicate papers and care about reproducing numbers,
 demonstrating its flexibility for actual research.
 Some highlights:
......
@@ -8,13 +8,14 @@ With the magic of SGD, wrong deep learning code often appears to still work,
 especially if you try it on toy datasets.
 See [Unawareness of Deep Learning Mistakes](https://medium.com/@ppwwyyxx/unawareness-of-deep-learning-mistakes-d5b5774da0ba).
+We refuse toy examples.
 Instead of showing you 10 arbitrary networks trained on toy datasets with random final performance,
 tensorpack examples try to faithfully replicate experiments and performance in the paper as much as possible,
 so you're confident that they are correct.

 ## Getting Started:
-These examples don't have meaningful performance numbers. They are supposed to be just demos.
+These are all the toy examples in tensorpack. They are supposed to be just demos.
 + [An illustrative MNIST example with explanation of the framework](basics/mnist-convnet.py)
 + Tensorpack supports any symbolic libraries. See the same MNIST example written with [tf.layers](basics/mnist-tflayers.py), [tf-slim](basics/mnist-tfslim.py), and [with weights visualizations](basics/mnist-visualizations.py)
 + A tiny [Cifar ConvNet](basics/cifar-convnet.py) and [SVHN ConvNet](basics/svhn-digit-convnet.py)
......
@@ -81,7 +81,12 @@ def resnet_bottleneck(l, ch_out, stride, stride_first=False):
     """
     shortcut = l
     l = Conv2D('conv1', l, ch_out, 1, strides=stride if stride_first else 1, activation=BNReLU)
-    l = Conv2D('conv2', l, ch_out, 3, strides=1 if stride_first else stride, activation=BNReLU)
+    if stride == 2:
+        l = tf.pad(l, [[0, 0], [0, 0], [1, 1], [1, 1]])
+        l = Conv2D('conv2', l, ch_out, 3, strides=1 if stride_first else stride,
+                   activation=BNReLU, padding='VALID')
+    else:
+        l = Conv2D('conv2', l, ch_out, 3, strides=1 if stride_first else stride, activation=BNReLU)
     l = Conv2D('conv3', l, ch_out * 4, 1, activation=get_bn(zero_init=True))
     return l + resnet_shortcut(shortcut, ch_out * 4, stride, activation=get_bn(zero_init=False))
@@ -117,8 +122,11 @@ def resnet_backbone(image, num_blocks, group_func, block_func):
     with argscope(Conv2D, use_bias=False,
                   kernel_initializer=tf.variance_scaling_initializer(scale=2.0, mode='fan_out')):
         logits = (LinearWrap(image)
-                  .Conv2D('conv0', 64, 7, strides=2, activation=BNReLU)
-                  .MaxPooling('pool0', shape=3, stride=2, padding='SAME')
+                  .tf.pad([[0, 0], [0, 0], [3, 3], [3, 3]])
+                  .Conv2D('conv0', 64, 7, strides=2, activation=BNReLU,
+                          padding='VALID')
+                  .tf.pad([[0, 0], [0, 0], [1, 1], [1, 1]])
+                  .MaxPooling('pool0', shape=3, stride=2, padding='VALID')
                   .apply(group_func, 'group0', block_func, 64, num_blocks[0], 1)
                   .apply(group_func, 'group1', block_func, 128, num_blocks[1], 2)
                   .apply(group_func, 'group2', block_func, 256, num_blocks[2], 2)
......
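The ResNet diff above replaces `padding='SAME'` on the strided conv0/pool0 and the strided 3x3 bottleneck conv with an explicit symmetric `tf.pad` followed by `padding='VALID'`. With stride 2, TensorFlow's `SAME` padding places the extra padded pixel on the bottom/right only, whereas the reference (Caffe) ResNet pads symmetrically. Below is a minimal sketch, not part of the commit, assuming the TF 1.x `tf.layers` API and NCHW layout; it shows that both variants keep the same output size but align the kernel against the input differently:

```python
import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 64, 224, 224])  # NCHW: (batch, channel, H, W)

# 'SAME' with stride 2 needs 1 padded pixel per spatial dim for a 3x3 kernel
# on a 224 input; TF puts it on the bottom/right only (asymmetric).
y_same = tf.layers.conv2d(x, 64, 3, strides=2, padding='same',
                          data_format='channels_first')

# Explicit symmetric padding of (kernel_size - 1) // 2 = 1 on each side,
# then a VALID conv -- the Caffe-style convention used in the diff above.
x_pad = tf.pad(x, [[0, 0], [0, 0], [1, 1], [1, 1]])
y_valid = tf.layers.conv2d(x_pad, 64, 3, strides=2, padding='valid',
                           data_format='channels_first')

print(y_same.shape)   # (?, 64, 112, 112)
print(y_valid.shape)  # (?, 64, 112, 112) -- same size, different pixel alignment
```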
@@ -8,7 +8,7 @@ Keras alone has various overhead. In particular, it is not efficient when working
 The article [Towards Efficient Multi-GPU Training in Keras with TensorFlow](https://medium.com/rossum/towards-efficient-multi-gpu-training-in-keras-with-tensorflow-8a0091074fb2)
 has mentioned some of it.
-Even on a single GPU, tensorpack can run [1.1~2x faster](https://github.com/tensorpack/benchmarks/tree/master/other-wrappers)
+Even on a single GPU, tensorpack can run [1.2~2x faster](https://github.com/tensorpack/benchmarks/tree/master/other-wrappers)
 than the equivalent Keras code. The gap becomes larger when you scale.
 Tensorpack and [horovod](https://github.com/uber/horovod/blob/master/examples/keras_imagenet_resnet50.py)
 are the only two tools I know that can scale the training of a large Keras model.
@@ -26,7 +26,7 @@ reproduce exactly the same setting of [tensorpack ResNet example](../ResNet) on
 It has:
 + ResNet-50 model modified from [keras.applications](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/_impl/keras/applications/resnet50.py).
-  (We put stride on 3x3 conv in each bottleneck, which is different from some other implementations).
+  (We put stride on 3x3 conv in each bottleneck, which is different from certain other implementations).
 + Multi-GPU data-parallel __training and validation__ which scales
   + Finished 100 epochs in 19.5 hours on 8 V100s, with >90% GPU utilization.
   + Still slightly slower than native tensorpack examples.
......
@@ -64,7 +64,13 @@ def regularize_cost(regex, func, name='regularize_cost'):
     for p in params:
         para_name = p.op.name
         if re.search(regex, para_name):
-            costs.append(func(p))
+            regloss = func(p)
+            assert regloss.dtype.is_floating, regloss
+            # Some variables may not be fp32, but it should
+            # be fine to assume regularization in fp32
+            if regloss.dtype != tf.float32:
+                regloss = tf.cast(regloss, tf.float32)
+            costs.append(regloss)
             names.append(p.name)
     if not costs:
         return tf.constant(0, dtype=tf.float32, name='empty_' + name)
@@ -112,6 +118,14 @@ def regularize_cost_from_collection(name='regularize_cost'):
     if len(losses) > 0:
         logger.info("regularize_cost_from_collection() found {} regularizers "
                     "in REGULARIZATION_LOSSES collection.".format(len(losses)))
+
+        def maploss(l):
+            assert l.dtype.is_floating, l
+            if l.dtype != tf.float32:
+                l = tf.cast(l, tf.float32)
+            return l
+
+        losses = [maploss(l) for l in losses]
         reg_loss = tf.add_n(losses, name=name)
         return reg_loss
     else:
......
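This is the change named in the commit message: regularization losses computed from variables that are not fp32 (e.g. fp16 variables under mixed-precision training) are cast to fp32 before being summed, because `tf.add_n` requires all of its inputs to share one dtype. A standalone sketch of the failure mode and the fix, not taken from the library (variable names are illustrative, TF 1.x API assumed):

```python
import tensorflow as tf

# Two regularized variables of different dtypes, as can happen with
# mixed-precision training.
w32 = tf.get_variable('w32', [10], dtype=tf.float32, initializer=tf.ones_initializer())
w16 = tf.get_variable('w16', [10], dtype=tf.float16, initializer=tf.ones_initializer())

losses = [tf.nn.l2_loss(w32), tf.nn.l2_loss(w16)]  # float32 and float16 tensors

# tf.add_n(losses) would fail here because the inputs have mixed dtypes.
# Casting every loss to fp32 first (what the commit does) makes the sum valid:
losses = [tf.cast(l, tf.float32) if l.dtype != tf.float32 else l for l in losses]
reg_loss = tf.add_n(losses, name='regularize_cost')
```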