Commit a963b01d authored by Yuxin Wu

Use the initializer from tf core instead of contrib

parent c2661527
@@ -90,9 +90,9 @@ class Model(GANModelDesc):
         A = tf.transpose(A / 255.0, [0, 3, 1, 2])
         B = tf.transpose(B / 255.0, [0, 3, 1, 2])
-        # use the initializers from torch
+        # use the torch initializers
         with argscope([Conv2D, Deconv2D, FullyConnected],
-                      W_init=tf.contrib.layers.variance_scaling_initializer(factor=0.333, uniform=True),
+                      W_init=tf.variance_scaling_initializer(scale=0.333, distribution='uniform'),
                       use_bias=False), \
                 argscope(BatchNorm, gamma_init=tf.random_uniform_initializer()), \
                 argscope([Conv2D, Deconv2D, BatchNorm], data_format='NCHW'):
...
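The change in this hunk is a pure rename: contrib's `factor` becomes `scale` and `uniform=True` becomes `distribution='uniform'`. With a uniform distribution, both initializers sample from [-limit, limit] with limit = sqrt(3 * scale / fan_in). A minimal sketch verifying the bound, with a made-up kernel shape:

    import numpy as np
    import tensorflow as tf

    shape = [3, 3, 64, 64]   # hypothetical conv kernel; fan_in = 3 * 3 * 64
    init = tf.variance_scaling_initializer(scale=0.333, distribution='uniform')
    with tf.Session() as sess:
        w = sess.run(init(shape))
    limit = np.sqrt(3 * 0.333 / (3 * 3 * 64))
    print(w.min() >= -limit, w.max() <= limit)   # True True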
@@ -13,7 +13,6 @@ from tensorpack.utils.gpu import get_nr_gpu
 from tensorpack.dataflow import dataset
 import tensorflow as tf
-from tensorflow.contrib.layers import variance_scaling_initializer
 
 """
 CIFAR10 ResNet example. See:
@@ -75,7 +74,7 @@ class Model(ModelDesc):
         with argscope([Conv2D, AvgPooling, BatchNorm, GlobalAvgPooling], data_format='NCHW'), \
                 argscope(Conv2D, nl=tf.identity, use_bias=False, kernel_shape=3,
-                         W_init=variance_scaling_initializer(mode='FAN_OUT')):
+                         W_init=tf.variance_scaling_initializer(scale=2.0, mode='fan_out')):
             l = Conv2D('conv0', image, 16, nl=BNReLU)
             l = residual('res1.0', l, first=True)
             for k in range(1, self.n):
...
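Two details of this migration are easy to miss: the contrib initializer defaulted to factor=2.0 while the core one defaults to scale=1.0, so the 2.0 must now be spelled out, and the core initializer only accepts lowercase mode names ('FAN_OUT' raises a ValueError). A before/after sketch:

    import tensorflow as tf

    # old (contrib): factor defaulted to 2.0, mode names were uppercase
    # from tensorflow.contrib.layers import variance_scaling_initializer
    # init = variance_scaling_initializer(mode='FAN_OUT')

    # new (core): scale is explicit, mode names are lowercase
    init = tf.variance_scaling_initializer(scale=2.0, mode='fan_out')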
@@ -3,7 +3,6 @@
 # File: resnet_model.py
 import tensorflow as tf
-from tensorflow.contrib.layers import variance_scaling_initializer
 from tensorpack.tfutils.argscope import argscope, get_arg_scope
@@ -116,7 +115,7 @@ def resnet_group(l, name, block_func, features, count, stride):
 def resnet_backbone(image, num_blocks, group_func, block_func):
     with argscope(Conv2D, nl=tf.identity, use_bias=False,
-                  W_init=variance_scaling_initializer(mode='FAN_OUT')):
+                  W_init=tf.variance_scaling_initializer(scale=2.0, mode='fan_out')):
         logits = (LinearWrap(image)
                   .Conv2D('conv0', 64, 7, stride=2, nl=BNReLU)
                   .MaxPooling('pool0', shape=3, stride=2, padding='SAME')
...
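With mode='fan_out' and the default truncated-normal distribution this is the He/MSRA scheme used for ResNets: nominal stddev sqrt(2 / fan_out), where fan_out is kernel_h * kernel_w * out_channels. A quick sanity check, with an assumed kernel shape:

    import numpy as np
    import tensorflow as tf

    shape = [3, 3, 64, 128]    # kH, kW, in_channels, out_channels (assumed)
    fan_out = 3 * 3 * 128
    init = tf.variance_scaling_initializer(scale=2.0, mode='fan_out')
    with tf.Session() as sess:
        w = sess.run(init(shape))
    # sample std falls slightly below the nominal value because of
    # the +/- 2 sigma truncation
    print(w.std(), np.sqrt(2.0 / fan_out))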
@@ -11,7 +11,6 @@ import multiprocessing
 import tensorflow as tf
-from tensorflow.contrib.layers import variance_scaling_initializer
 from tensorpack import *
 from tensorpack.dataflow import dataset
 from tensorpack.tfutils import optimizer
@@ -48,7 +47,7 @@ class Model(ModelDesc):
         defs, block_func = cfg[DEPTH]
         with argscope(Conv2D, nl=tf.identity, use_bias=False,
-                      W_init=variance_scaling_initializer(mode='FAN_OUT')), \
+                      W_init=tf.variance_scaling_initializer(scale=2.0, mode='fan_out')), \
                 argscope([Conv2D, MaxPooling, GlobalAvgPooling, BatchNorm], data_format='NCHW'):
             convmaps = (LinearWrap(image)
                         .Conv2D('conv0', 64, 7, stride=2, nl=BNReLU)
...
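Because the initializer is injected through argscope, the swap happens once per model; every Conv2D under the scope inherits W_init unless a call overrides it. A minimal sketch of the pattern, assuming the tensorpack API of this vintage and a made-up input shape:

    import tensorflow as tf
    from tensorpack import argscope, Conv2D

    image = tf.placeholder(tf.float32, [None, 224, 224, 3])
    with argscope(Conv2D, nl=tf.identity, use_bias=False,
                  W_init=tf.variance_scaling_initializer(scale=2.0, mode='fan_out')):
        l = Conv2D('conv1', image, 64, 3)   # both layers pick up W_init
        l = Conv2D('conv2', l, 64, 3)       # from the enclosing argscope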
@@ -35,7 +35,7 @@ def DepthConv(x, out_channel, kernel_shape, padding='SAME', stride=1,
     channel_mult = out_channel // in_channel
     if W_init is None:
-        W_init = tf.contrib.layers.variance_scaling_initializer()
+        W_init = tf.variance_scaling_initializer(2.0)
     kernel_shape = [kernel_shape, kernel_shape]
     filter_shape = kernel_shape + [in_channel, channel_mult]
...
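The bare 2.0 is the scale argument. Since contrib's zero-argument call already defaulted to factor=2.0, mode='FAN_IN' and a (truncated) normal distribution, the depthwise kernels keep their old statistics; spelled out:

    import tensorflow as tf

    # these two construct the same He-style 'fan_in' initializer
    a = tf.variance_scaling_initializer(2.0)
    b = tf.variance_scaling_initializer(scale=2.0, mode='fan_in',
                                        distribution='normal')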
@@ -27,7 +27,7 @@ def Conv2D(x, out_channel, kernel_shape,
         stride: (h, w) tuple or a int.
         padding (str): 'valid' or 'same'. Case insensitive.
         split (int): Split channels as used in Alexnet. Defaults to 1 (no split).
-        W_init: initializer for W. Defaults to `variance_scaling_initializer`.
+        W_init: initializer for W. Defaults to `variance_scaling_initializer(2.0)`, i.e. kaiming-normal.
         b_init: initializer for b. Defaults to zero.
         nl: a nonlinearity function.
         use_bias (bool): whether to use bias.
@@ -53,7 +53,7 @@ def Conv2D(x, out_channel, kernel_shape,
     stride = shape4d(stride, data_format=data_format)
     if W_init is None:
-        W_init = tf.contrib.layers.variance_scaling_initializer()
+        W_init = tf.variance_scaling_initializer(scale=2.0)
     if b_init is None:
         b_init = tf.constant_initializer()
@@ -94,7 +94,7 @@ def Deconv2D(x, out_channel, kernel_shape,
         kernel_shape: (h, w) tuple or a int.
         stride: (h, w) tuple or a int.
         padding (str): 'valid' or 'same'. Case insensitive.
-        W_init: initializer for W. Defaults to `variance_scaling_initializer`.
+        W_init: initializer for W. Defaults to `tf.variance_scaling_initializer(2.0)`, i.e. kaiming-normal.
         b_init: initializer for b. Defaults to zero.
         nl: a nonlinearity function.
         use_bias (bool): whether to use bias.
@@ -115,7 +115,7 @@ def Deconv2D(x, out_channel, kernel_shape,
     assert isinstance(out_channel, int), out_channel
     if W_init is None:
-        W_init = tf.contrib.layers.xavier_initializer_conv2d()
+        W_init = tf.variance_scaling_initializer(scale=2.0)
     if b_init is None:
         b_init = tf.constant_initializer()
...
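For Deconv2D this is more than a rename: the old default was Xavier/Glorot (scale 1, fan_avg, uniform), the new one is He-normal, so freshly created deconvolution weights follow a different distribution. Code that wants the old behavior can still express it with the core initializer:

    import tensorflow as tf

    # old Deconv2D default, written with the core API
    xavier = tf.variance_scaling_initializer(scale=1.0, mode='fan_avg',
                                             distribution='uniform')
    # new Deconv2D default, now consistent with Conv2D and FullyConnected
    he = tf.variance_scaling_initializer(scale=2.0)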
@@ -22,7 +22,7 @@ def FullyConnected(x, out_dim,
     Args:
         x (tf.Tensor): a tensor to be flattened except for the first dimension.
         out_dim (int): output dimension
-        W_init: initializer for W. Defaults to `variance_scaling_initializer`.
+        W_init: initializer for W. Defaults to `variance_scaling_initializer(2.0)`, i.e. kaiming-normal.
         b_init: initializer for b. Defaults to zero.
         nl: a nonlinearity function
         use_bias (bool): whether to use bias.
@@ -38,7 +38,7 @@ def FullyConnected(x, out_dim,
     x = symbf.batch_flatten(x)
     if W_init is None:
-        W_init = tf.contrib.layers.variance_scaling_initializer()
+        W_init = tf.variance_scaling_initializer(2.0)
     if b_init is None:
         b_init = tf.constant_initializer()
...
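Callers that relied on the documented default can pin W_init explicitly instead of tracking the library default; a hypothetical usage sketch:

    import tensorflow as tf
    from tensorpack import FullyConnected

    x = tf.placeholder(tf.float32, [None, 128])
    # pass W_init explicitly to decouple the layer from the default
    fc = FullyConnected('fc0', x, 64,
                        W_init=tf.variance_scaling_initializer(2.0))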
@@ -211,4 +211,6 @@ def is_training_name(name):
         return True
     if name.startswith('EMA/'):  # all the moving average summaries
         return True
+    if name.startswith('AccumGrad') or name.endswith('/AccumGrad'):
+        return True
     return False
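is_training_name guesses whether a variable only matters during training (and can therefore be skipped when exporting or restoring a model); the new clause extends the guess to accumulator slots created by gradient-accumulation optimizers. A hypothetical check, with made-up variable names:

    assert is_training_name('AccumGrad/counter')   # accumulator state
    assert is_training_name('conv0/W/AccumGrad')   # per-variable slot
    assert not is_training_name('conv0/W')         # real weights are kept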