Commit 3d2c2f6e authored by Yuxin Wu

[DoReFa] Use `tf.custom_gradient` for dorefa.

parent f7ab74a3
@@ -82,7 +82,7 @@ class Model(ModelDesc):
     def optimizer(self):
         lr = tf.get_variable('learning_rate', initializer=self.learning_rate, trainable=False)
-        opt = tf.train.AdamOptimizer(lr, epsilon=1e-3)
+        opt = tf.train.RMSPropOptimizer(lr, epsilon=1e-5)
         return optimizer.apply_grad_processors(opt, [gradproc.SummaryGradient()])

     @staticmethod
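For context on the wrapped optimizer: tensorpack's `apply_grad_processors` intercepts gradients before they are applied and runs each processor on them, and `SummaryGradient` records per-gradient summaries. A rough raw-TensorFlow sketch of what the new line amounts to (the toy `loss` below is mine, not from this example):

```python
import tensorflow as tf

lr = tf.get_variable('learning_rate', initializer=1e-4, trainable=False)
opt = tf.train.RMSPropOptimizer(lr, epsilon=1e-5)

w = tf.get_variable('w', initializer=1.0)
loss = tf.square(w - 3.0)   # stand-in loss for this sketch

# Roughly what SummaryGradient does: summarize each gradient
# before handing it to apply_gradients.
grads_and_vars = opt.compute_gradients(loss)
for g, v in grads_and_vars:
    if g is not None:
        tf.summary.histogram(v.op.name + '/grad', g)
train_op = opt.apply_gradients(grads_and_vars)
```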
@@ -37,6 +37,8 @@ Alternative link to this page: [http://dorefa.net](http://dorefa.net)

 ## Use

++ Install TensorFlow>=1.7. For TensorFlow<1.7, you can use an earlier implementation of `dorefa.py` at [here](https://github.com/tensorpack/tensorpack/blob/58529de18e9bdad1bab31aed9c397a8f340e7f94/examples/DoReFa-Net/dorefa.py)
 + Install tensorpack and scipy.
 + Look at the docstring in `*-dorefa.py` to see detailed usage and performance.
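For anyone following the README pointer, this is roughly how the three returned functions are consumed; the bit widths below are the paper's W1A2G6 setting, and the surrounding tensors are illustrative only:

```python
import tensorflow as tf
from dorefa import get_dorefa   # examples/DoReFa-Net/dorefa.py in this repo

# 1-bit weights, 2-bit activations, 6-bit gradients (the paper's W1A2G6).
fw, fa, fg = get_dorefa(1, 2, 6)

W = tf.get_variable('W', shape=[256, 10])
Wq = fw(W)                                # binarized weights, straight-through gradient

act = tf.nn.relu(tf.random_normal([32, 256]))
act = tf.clip_by_value(act, 0.0, 1.0)     # fa assumes its input lies in [0, 1]
actq = fa(act)                            # 2-bit activations
```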
@@ -12,20 +12,28 @@ def get_dorefa(bitW, bitA, bitG):
     return the three quantization functions fw, fa, fg, for weights, activations and gradients respectively
     It's unsafe to call this function multiple times with different parameters
     """
-    G = tf.get_default_graph()
-
     def quantize(x, k):
-        n = float(2**k - 1)
-        with G.gradient_override_map({"Round": "Identity"}):
-            return tf.round(x * n) / n
+        n = float(2 ** k - 1)
+
+        @tf.custom_gradient
+        def _quantize(x):
+            return tf.round(x * n) / n, lambda dy: dy
+
+        return _quantize(x)

     def fw(x):
         if bitW == 32:
             return x
+
         if bitW == 1:   # BWN
-            with G.gradient_override_map({"Sign": "Identity"}):
-                E = tf.stop_gradient(tf.reduce_mean(tf.abs(x)))
-                return tf.sign(x / E) * E
+            E = tf.stop_gradient(tf.reduce_mean(tf.abs(x)))
+
+            @tf.custom_gradient
+            def _sign(x):
+                return tf.sign(x / E) * E, lambda dy: dy
+
+            return _sign(x)
+
         x = tf.tanh(x)
         x = x / tf.reduce_max(tf.abs(x)) * 0.5 + 0.5
         return 2 * quantize(x, bitW) - 1
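Both sides of this hunk implement the same straight-through estimator: round in the forward pass, let gradients flow as if the op were the identity. The new form attaches `lambda dy: dy` via `tf.custom_gradient` instead of remapping the registered `Round` gradient. A self-contained sketch of the pattern (TF >= 1.7, not part of the commit):

```python
import tensorflow as tf

def quantize_ste(x, k):
    """Round x in [0, 1] to 2**k - 1 levels; backward pass is identity."""
    n = float(2 ** k - 1)

    @tf.custom_gradient
    def _quantize(x):
        # Forward: tf.round, whose true gradient is zero almost everywhere.
        # Backward: the lambda passes dy through untouched.
        return tf.round(x * n) / n, lambda dy: dy

    return _quantize(x)

x = tf.constant([0.1, 0.4, 0.9])
y = quantize_ste(x, 2)
g = tf.gradients(y, x)[0]   # ones, not zeros, thanks to the custom gradient

with tf.Session() as sess:
    print(sess.run([y, g]))  # y ~ [0.0, 0.333, 1.0], g == [1.0, 1.0, 1.0]
```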
@@ -35,8 +43,13 @@ def get_dorefa(bitW, bitA, bitG):
             return x
         return quantize(x, bitA)

-    @tf.RegisterGradient("FGGrad")
-    def grad_fg(op, x):
+    def fg(x):
+        if bitG == 32:
+            return x
+
+        @tf.custom_gradient
+        def _identity(input):
+            def grad_fg(x):
                 rank = x.get_shape().ndims
                 assert rank is not None
                 maxx = tf.reduce_max(tf.abs(x), list(range(1, rank)), keep_dims=True)
@@ -48,11 +61,9 @@ def get_dorefa(bitW, bitA, bitG):
                 x = quantize(x, bitG) - 0.5
                 return x * maxx * 2

-    def fg(x):
-        if bitG == 32:
-            return x
-        with G.gradient_override_map({"Identity": "FGGrad"}):
-            return tf.identity(x)
+            return input, grad_fg
+
+        return _identity(x)
     return fw, fa, fg
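`fg` uses the decorator the other way around: the forward pass is a plain identity, and all the work (per-sample rescaling, stochastic noise, quantization) happens to the incoming gradient inside `grad_fg`. A stripped-down sketch of that shape, with a trivial backward standing in for the real gradient quantizer:

```python
import tensorflow as tf

@tf.custom_gradient
def quantize_grad_sketch(x):
    # Forward: identity, so the op is invisible in the forward pass.
    def grad(dy):
        # Stand-in for grad_fg: the real code rescales dy, injects
        # uniform noise, and quantizes it to 2**bitG - 1 levels.
        return 0.5 * dy
    return tf.identity(x), grad

x = tf.constant([1.0, 2.0])
y = quantize_grad_sketch(x)
g = tf.gradients(y, x)[0]   # [0.5, 0.5]: only the backward pass changed
```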
@@ -64,7 +75,6 @@ def ternarize(x, thresh=0.05):
     Code modified from the authors' at:
     https://github.com/czhu95/ternarynet/blob/master/examples/Ternary-Net/ternary.py
     """
-    G = tf.get_default_graph()
     shape = x.get_shape()

     thre_x = tf.stop_gradient(tf.reduce_max(tf.abs(x)) * thresh)
@@ -80,8 +90,11 @@ def ternarize(x, thresh=0.05):
     mask_np = tf.where(x < -thre_x, tf.ones(shape) * w_n, mask_p)
     mask_z = tf.where((x < thre_x) & (x > - thre_x), tf.zeros(shape), mask)

-    with G.gradient_override_map({"Sign": "Identity", "Mul": "Add"}):
-        w = tf.sign(x) * tf.stop_gradient(mask_z)
+    @tf.custom_gradient
+    def _sign_mask(x):
+        return tf.sign(x) * mask_z, lambda dy: dy
+
+    w = _sign_mask(x)
     w = w * mask_np
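The single `lambda dy: dy` here stands in for the old pair of overrides: `Sign -> Identity` made `d sign(x)/dx` equal to one, and `Mul -> Add` forwarded the incoming gradient past the stopped mask unchanged, so the net effect on `x` was already `dy`. A quick illustrative check of that equivalence (the values are mine, not from the commit):

```python
import tensorflow as tf

mask_z = tf.constant([0.0, 1.0, 1.0])

@tf.custom_gradient
def _sign_mask(x):
    # Forward: masked ternary sign, as in ternarize().
    # Backward: identity, matching {"Sign": "Identity", "Mul": "Add"}.
    return tf.sign(x) * mask_z, lambda dy: dy

x = tf.constant([-2.0, 0.3, 2.0])
w = _sign_mask(x)
g = tf.gradients(w, x)[0]

with tf.Session() as sess:
    print(sess.run([w, g]))   # w == [-0., 1., 1.], g == [1., 1., 1.]
```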