Commit 3d2c2f6e authored by Yuxin Wu

[DoReFa] Use `tf.custom_gradient` for dorefa.

parent f7ab74a3
@@ -82,7 +82,7 @@ class Model(ModelDesc):
     def optimizer(self):
         lr = tf.get_variable('learning_rate', initializer=self.learning_rate, trainable=False)
-        opt = tf.train.AdamOptimizer(lr, epsilon=1e-3)
+        opt = tf.train.RMSPropOptimizer(lr, epsilon=1e-5)
         return optimizer.apply_grad_processors(opt, [gradproc.SummaryGradient()])

     @staticmethod
......
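The optimizer swap above (Adam, `epsilon=1e-3` → RMSProp, `epsilon=1e-5`) is independent of the `tf.custom_gradient` migration. A minimal sketch of the updated method in isolation, assuming the tensorpack modules this example already imports; `build_optimizer` and its default learning rate are illustrative, not from the commit:

```python
import tensorflow as tf
from tensorpack.tfutils import gradproc, optimizer

def build_optimizer(learning_rate=1e-4):   # placeholder default, not from the commit
    lr = tf.get_variable('learning_rate', initializer=learning_rate, trainable=False)
    # RMSProp with an epsilon well above TF's tiny default, for numerical stability.
    opt = tf.train.RMSPropOptimizer(lr, epsilon=1e-5)
    # Wrap the optimizer so every gradient is also summarized to TensorBoard.
    return optimizer.apply_grad_processors(opt, [gradproc.SummaryGradient()])
```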
@@ -37,6 +37,8 @@ Alternative link to this page: [http://dorefa.net](http://dorefa.net)

 ## Use

++ Install TensorFlow>=1.7. For TensorFlow<1.7, an earlier implementation of `dorefa.py` is available [here](https://github.com/tensorpack/tensorpack/blob/58529de18e9bdad1bab31aed9c397a8f340e7f94/examples/DoReFa-Net/dorefa.py).
 + Install tensorpack and scipy.
 + Look at the docstring in `*-dorefa.py` to see detailed usage and performance.
......
@@ -12,20 +12,28 @@ def get_dorefa(bitW, bitA, bitG):
     return the three quantization functions fw, fa, fg, for weights, activations and gradients respectively
     It's unsafe to call this function multiple times with different parameters
     """
-    G = tf.get_default_graph()
-
     def quantize(x, k):
         n = float(2 ** k - 1)
-        with G.gradient_override_map({"Round": "Identity"}):
-            return tf.round(x * n) / n
+
+        @tf.custom_gradient
+        def _quantize(x):
+            return tf.round(x * n) / n, lambda dy: dy
+
+        return _quantize(x)

     def fw(x):
         if bitW == 32:
             return x
         if bitW == 1:   # BWN
-            with G.gradient_override_map({"Sign": "Identity"}):
-                E = tf.stop_gradient(tf.reduce_mean(tf.abs(x)))
-                return tf.sign(x / E) * E
+            E = tf.stop_gradient(tf.reduce_mean(tf.abs(x)))
+
+            @tf.custom_gradient
+            def _sign(x):
+                return tf.sign(x / E) * E, lambda dy: dy
+
+            return _sign(x)
         x = tf.tanh(x)
         x = x / tf.reduce_max(tf.abs(x)) * 0.5 + 0.5
         return 2 * quantize(x, bitW) - 1
@@ -35,8 +43,13 @@ def get_dorefa(bitW, bitA, bitG):
             return x
         return quantize(x, bitA)

-    @tf.RegisterGradient("FGGrad")
-    def grad_fg(op, x):
+    def fg(x):
+        if bitG == 32:
+            return x
+
+        @tf.custom_gradient
+        def _identity(input):
+            def grad_fg(x):
                 rank = x.get_shape().ndims
                 assert rank is not None
                 maxx = tf.reduce_max(tf.abs(x), list(range(1, rank)), keep_dims=True)
@@ -48,11 +61,9 @@ def get_dorefa(bitW, bitA, bitG):
                 x = quantize(x, bitG) - 0.5
                 return x * maxx * 2

-    def fg(x):
-        if bitG == 32:
-            return x
-        with G.gradient_override_map({"Identity": "FGGrad"}):
-            return tf.identity(x)
+            return input, grad_fg
+
+        return _identity(x)

     return fw, fa, fg
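The two `fg` hunks are hard to read in isolation: the forward pass is now a plain identity, and all the work happens in the backward function `grad_fg`, which rescales each sample's gradient into `[0, 1]`, adds uniform dither, quantizes to `bitG` bits, and restores the scale. The diff view elides the middle of `grad_fg`; the sketch below fills that gap from the DoReFa paper's gradient-quantization formula, so treat the dither and clipping lines as an assumption rather than a quote of this commit:

```python
import tensorflow as tf

def make_fg(bitG):
    """Sketch of DoReFa gradient quantization via tf.custom_gradient.
    The dither/clip lines are reconstructed from the paper, not this diff."""

    def quantize(x, k):
        n = float(2 ** k - 1)

        @tf.custom_gradient
        def _q(x):
            return tf.round(x * n) / n, lambda dy: dy

        return _q(x)

    def fg(x):
        if bitG == 32:
            return x

        @tf.custom_gradient
        def _identity(input):
            def grad_fg(x):   # x is the incoming gradient dy
                rank = x.get_shape().ndims
                assert rank is not None
                # Per-sample maximum over all non-batch axes.
                maxx = tf.reduce_max(tf.abs(x), list(range(1, rank)), keep_dims=True)
                x = x / maxx * 0.5 + 0.5              # rescale into [0, 1]
                n = float(2 ** bitG - 1)
                x += tf.random_uniform(tf.shape(x),   # dither before rounding
                                       minval=-0.5 / n, maxval=0.5 / n)
                x = tf.clip_by_value(x, 0.0, 1.0)
                x = quantize(x, bitG) - 0.5           # back to [-0.5, 0.5]
                return x * maxx * 2                   # restore original scale
            return input, grad_fg

        return _identity(x)

    return fg
```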
@@ -64,7 +75,6 @@ def ternarize(x, thresh=0.05):
     Code modified from the authors' at:
     https://github.com/czhu95/ternarynet/blob/master/examples/Ternary-Net/ternary.py
     """
-    G = tf.get_default_graph()
     shape = x.get_shape()

     thre_x = tf.stop_gradient(tf.reduce_max(tf.abs(x)) * thresh)
@@ -80,8 +90,11 @@ def ternarize(x, thresh=0.05):
     mask_np = tf.where(x < -thre_x, tf.ones(shape) * w_n, mask_p)
     mask_z = tf.where((x < thre_x) & (x > - thre_x), tf.zeros(shape), mask)

-    with G.gradient_override_map({"Sign": "Identity", "Mul": "Add"}):
-        w = tf.sign(x) * tf.stop_gradient(mask_z)
+    @tf.custom_gradient
+    def _sign_mask(x):
+        return tf.sign(x) * mask_z, lambda dy: dy
+
+    w = _sign_mask(x)

     w = w * mask_np
......
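`ternarize` gets the same treatment: the old version needed two overrides (`Sign → Identity` and `Mul → Add`) plus `tf.stop_gradient` on the mask to make the gradient of `tf.sign(x) * mask_z` behave as the identity; `_sign_mask` expresses that in one place. A quick behavioral check, using hypothetical tensor values:

```python
import tensorflow as tf

mask_z = tf.constant([1., 0., 1.])   # hypothetical zero-mask, not from the commit

@tf.custom_gradient
def _sign_mask(x):
    # Forward: ternary output in {-1, 0, +1}; backward: identity.
    return tf.sign(x) * mask_z, lambda dy: dy

x = tf.constant([-0.3, 0.01, 0.7])
w = _sign_mask(x)            # [-1., 0., 1.]
g = tf.gradients(w, x)[0]    # [1., 1., 1.]: gradient flows even where the mask is 0,
                             # matching the old {"Sign": "Identity", "Mul": "Add"} overrides
```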