Commit 3270acb8 authored by Yuxin Wu

simplify code and add bit==32

parent 28e8f8db
@@ -30,27 +30,28 @@ BITG = 4
 GRAD_DEFINED = False

 def get_dorefa(bitW, bitA, bitG):
     """ return the three quantization functions fw, fa, fg, for weights,
     activations and gradients respectively"""
     G = tf.get_default_graph()

-    global GRAD_DEFINED
-    if not GRAD_DEFINED:
-        @tf.RegisterGradient("IdentityGrad")
-        def ident_grad(op, grad):
-            return [grad] * len(op.inputs)

     def quantize(x, k):
         n = float(2**k-1)
-        with G.gradient_override_map({"Floor": "IdentityGrad"}):
+        with G.gradient_override_map({"Floor": "Identity"}):
             return tf.round(x * n) / n

     def fw(x):
+        if bitW == 32:
+            return x
         x = tf.tanh(x)
         x = x / tf.reduce_max(tf.abs(x)) * 0.5 + 0.5
         return 2 * quantize(x, bitW) - 1

     def fa(x):
+        if bitA == 32:
+            return x
         return quantize(x, bitA)

+    global GRAD_DEFINED
+    if not GRAD_DEFINED:
         @tf.RegisterGradient("FGGrad")
         def grad_fg(op, x):
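The `quantize` helper above is a straight-through estimator: the forward pass rounds onto a uniform grid of 2**k levels, while `gradient_override_map` makes the `Floor` op (which `tf.round` lowers to in TF versions of this era) borrow the registered gradient of the built-in `Identity` op, so the backward pass treats rounding as the identity. That reuse of `Identity` is exactly the simplification this commit makes, replacing the hand-registered `IdentityGrad`. A minimal standalone sketch of the same idea, assuming the TF 1.x graph-mode API:

```python
import tensorflow as tf

def quantize_ste(x, k):
    # k-bit uniform quantizer over [0, 1]: forward pass rounds x onto
    # the grid {0, 1/n, ..., 1} with n = 2**k - 1.
    n = float(2 ** k - 1)
    G = tf.get_default_graph()
    # Backward pass: Floor (the op underlying tf.round here) reuses
    # Identity's gradient, i.e. d(out)/d(in) is taken to be 1 -- the
    # straight-through trick.
    with G.gradient_override_map({"Floor": "Identity"}):
        return tf.round(x * n) / n
```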
@@ -66,6 +67,8 @@ def get_dorefa(bitW, bitA, bitG):
             return x * maxx * 2

     def fg(x):
+        if bitG == 32:
+            return x
         with G.gradient_override_map({"Identity": "FGGrad"}):
             return tf.identity(x)
     GRAD_DEFINED = True
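`fg` inverts the trick used by `quantize`: the forward pass is a plain `tf.identity`, and the override routes the backward pass through `FGGrad`, which rescales and stochastically quantizes the incoming gradient. With the new `bit == 32` branches, all three functions degrade to no-ops, so full-precision training needs no special-casing at the call sites. A hypothetical usage sketch (the bit widths and shapes are illustrative; `get_dorefa` as defined above):

```python
import tensorflow as tf

fw, fa, fg = get_dorefa(bitW=1, bitA=2, bitG=4)  # binary W, 2-bit A, 4-bit G

x = tf.placeholder(tf.float32, [None, 32, 32, 3])
w = tf.get_variable('W', [3, 3, 3, 16])

y = tf.nn.conv2d(x, fw(w), strides=[1, 1, 1, 1], padding='SAME')
y = fg(y)                              # identity forward, FGGrad backward
y = fa(tf.clip_by_value(y, 0.0, 1.0))  # quantize activations in [0, 1]
```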
@@ -84,7 +87,8 @@ class Model(ModelDesc):
         old_get_variable = tf.get_variable
         def new_get_variable(name, shape=None, **kwargs):
             v = old_get_variable(name, shape, **kwargs)
-            if name != 'W' or 'conv0' in v.op.name or 'fc'in v.op.name:
+            # don't binarize first and last layer
+            if name != 'W' or 'conv0' in v.op.name or 'fc' in v.op.name:
                 return v
             else:
                 logger.info("Binarizing weight {}".format(v.op.name))
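This hunk wraps `tf.get_variable` so that every weight named `W`, except those of the first convolution and the fully-connected layers, is routed through `fw` at creation time. A sketch of how such a monkey-patch is typically scoped, where `build_model` is a hypothetical stand-in for the graph-construction code:

```python
# Install the wrapper only while the graph is built, then restore the
# original so the patch does not leak into unrelated variable creation.
tf.get_variable = new_get_variable
try:
    logits = build_model(image)  # hypothetical model-construction call
finally:
    tf.get_variable = old_get_variable
```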
@@ -137,12 +141,12 @@ class Model(ModelDesc):
         cost = tf.reduce_mean(cost, name='cross_entropy_loss')
         tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost)

-        # compute the number of failed samples, for ClassificationError to use at test time
+        # compute the number of failed samples
         wrong = prediction_incorrect(logits, label)
         nr_wrong = tf.reduce_sum(wrong, name='wrong')
         # monitor training error
-        tf.add_to_collection(
-            MOVING_SUMMARY_VARS_KEY, tf.reduce_mean(wrong, name='train_error'))
+        tf.add_to_collection(MOVING_SUMMARY_VARS_KEY,
+                             tf.reduce_mean(wrong, name='train_error'))

         # weight decay on all W of fc layers
         wd_cost = regularize_cost('fc.*/W', l2_regularizer(1e-7))
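`prediction_incorrect` produces the per-sample 0/1 error vector consumed both by the `wrong` sum and by the `train_error` moving summary; a minimal equivalent, assuming top-1 classification with integer labels:

```python
import tensorflow as tf

def prediction_incorrect(logits, label):
    # 1.0 where the top-1 prediction disagrees with the integer class label.
    return tf.cast(tf.logical_not(tf.nn.in_top_k(logits, label, 1)),
                   tf.float32)
```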