Commit 3270acb8 authored by Yuxin Wu

simplify code and add bit==32

parent 28e8f8db
@@ -30,27 +30,28 @@ BITG = 4
 GRAD_DEFINED = False
 def get_dorefa(bitW, bitA, bitG):
+    """ return the three quantization functions fw, fa, fg, for weights,
+    activations and gradients respectively"""
     G = tf.get_default_graph()
-    global GRAD_DEFINED
-    if not GRAD_DEFINED:
-        @tf.RegisterGradient("IdentityGrad")
-        def ident_grad(op, grad):
-            return [grad] * len(op.inputs)
     def quantize(x, k):
         n = float(2**k-1)
-        with G.gradient_override_map({"Floor": "IdentityGrad"}):
+        with G.gradient_override_map({"Floor": "Identity"}):
             return tf.round(x * n) / n
     def fw(x):
+        if bitW == 32:
+            return x
         x = tf.tanh(x)
         x = x / tf.reduce_max(tf.abs(x)) * 0.5 + 0.5
         return 2 * quantize(x, bitW) - 1
     def fa(x):
+        if bitA == 32:
+            return x
         return quantize(x, bitA)
+    global GRAD_DEFINED
     if not GRAD_DEFINED:
         @tf.RegisterGradient("FGGrad")
         def grad_fg(op, x):
@@ -66,6 +67,8 @@ def get_dorefa(bitW, bitA, bitG):
             return x * maxx * 2
     def fg(x):
+        if bitG == 32:
+            return x
         with G.gradient_override_map({"Identity": "FGGrad"}):
             return tf.identity(x)
     GRAD_DEFINED = True
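
A minimal usage sketch of the new behaviour (not part of the diff): with the added branches, passing 32 for a bit width turns the corresponding function into a plain pass-through, so the same model code runs in full precision. It assumes TensorFlow 1.x graph mode, the get_dorefa defined above, and made-up shapes:

    import tensorflow as tf  # the file already imports this; repeated here for a standalone sketch

    fw, fa, fg = get_dorefa(bitW=1, bitA=2, bitG=32)

    W = tf.get_variable('W', shape=[3, 3, 16, 32])
    Wb = fw(W)               # bitW == 1: tanh, rescale to [0, 1], quantize, map back to {-1, +1}
    a = fa(tf.sigmoid(Wb))   # bitA == 2: quantize(x, 2) on a value already in [0, 1]
    g = fg(a)                # bitG == 32: returns its input unchanged, no gradient quantization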
@@ -84,7 +87,8 @@ class Model(ModelDesc):
         old_get_variable = tf.get_variable
         def new_get_variable(name, shape=None, **kwargs):
             v = old_get_variable(name, shape, **kwargs)
-            if name != 'W' or 'conv0' in v.op.name or 'fc'in v.op.name:
+            # don't binarize first and last layer
+            if name != 'W' or 'conv0' in v.op.name or 'fc' in v.op.name:
                 return v
             else:
                 logger.info("Binarizing weight {}".format(v.op.name))
@@ -137,12 +141,12 @@ class Model(ModelDesc):
         cost = tf.reduce_mean(cost, name='cross_entropy_loss')
         tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost)
-        # compute the number of failed samples, for ClassificationError to use at test time
+        # compute the number of failed samples
         wrong = prediction_incorrect(logits, label)
         nr_wrong = tf.reduce_sum(wrong, name='wrong')
         # monitor training error
-        tf.add_to_collection(
-            MOVING_SUMMARY_VARS_KEY, tf.reduce_mean(wrong, name='train_error'))
+        tf.add_to_collection(MOVING_SUMMARY_VARS_KEY,
+            tf.reduce_mean(wrong, name='train_error'))
         # weight decay on all W of fc layers
         wd_cost = regularize_cost('fc.*/W', l2_regularizer(1e-7))
...