Commit 03f18976 authored by Yuxin Wu's avatar Yuxin Wu

Add TTQ inside DoReFa

parent 5cdf1d33
......@@ -11,7 +11,9 @@ Pretrained model for (1,4,32)-ResNet18 and (1,2,6)-AlexNet are available at
They're provided in the format of numpy dictionary, so it should be very easy to port into other applications.
The __binary-weight 4-bit-activation ResNet-18__ model has 59.2% top-1 validation accuracy.
Note that when (W,A,G) is set to (1,32,32), this code is also an implementation of [Binary Weight Network](https://arxiv.org/abs/1511.00363).
Note that when (W,A,G) is set to (1,32,32), this code implements [Binary Weight Network](https://arxiv.org/abs/1511.00363).
When (W,A,G) is set to (t,32,32), this code implements
[Trained Ternary Quantization](https://arxiv.org/abs/1612.01064).
But with (W,A,G) set to (1,1,32), it is not equivalent to [XNOR-Net](https://arxiv.org/abs/1603.05279), although it won't be hard to implement it.
Alternative link to this page: [http://dorefa.net](http://dorefa.net)
......
......@@ -19,7 +19,7 @@ from tensorpack.dataflow import dataset
from tensorpack.utils.gpu import get_nr_gpu
from imagenet_utils import get_imagenet_dataflow, fbresnet_augmentor
from dorefa import get_dorefa
from dorefa import get_dorefa, ternarize
"""
This is a tensorpack script for the ImageNet results in paper:
......@@ -35,7 +35,9 @@ Accuracy:
due to more sophisticated augmentations.
With (W,A,G)=(32,32,32) -- full precision baseline, 41.4% error.
With (W,A,G)=(t,32,32) -- TTQ, 42.3% error
With (W,A,G)=(1,32,32) -- BWN, 44.3% error
With (W,A,G)=(1,1,32) -- BNN, 53.4% error
With (W,A,G)=(1,2,6), 47.6% error
With (W,A,G)=(1,2,4), 58.4% error
......@@ -84,7 +86,11 @@ class Model(ModelDesc):
def build_graph(self, image, label):
image = image / 255.0
fw, fa, fg = get_dorefa(BITW, BITA, BITG)
if BITW == 't':
fw, fa, fg = get_dorefa(32, 32, 32)
fw = ternarize
else:
fw, fa, fg = get_dorefa(BITW, BITA, BITG)
# monkey-patch tf.get_variable to apply fw
def new_get_variable(v):
......@@ -93,7 +99,7 @@ class Model(ModelDesc):
if not name.endswith('W') or 'conv0' in name or 'fct' in name:
return v
else:
logger.info("Binarizing weight {}".format(v.op.name))
logger.info("Quantizing weight {}".format(v.op.name))
return fw(v)
def nonlin(x):
......@@ -175,7 +181,6 @@ def get_data(dataset_name):
def get_config():
logger.auto_set_dir()
data_train = get_data('train')
data_test = get_data('val')
......@@ -242,12 +247,17 @@ if __name__ == '__main__':
parser.add_argument('--gpu', help='the physical ids of GPUs to use')
parser.add_argument('--load', help='load a checkpoint, or a npz (given as the pretrained model)')
parser.add_argument('--data', help='ILSVRC dataset dir')
parser.add_argument('--dorefa',
help='number of bits for W,A,G, separated by comma', required=True)
parser.add_argument('--dorefa', required=True,
help='number of bits for W,A,G, separated by comma. W="t" means TTQ')
parser.add_argument('--run', help='run on a list of images with the pretrained model', nargs='*')
args = parser.parse_args()
BITW, BITA, BITG = map(int, args.dorefa.split(','))
dorefa = args.dorefa.split(',')
if dorefa[0] == 't':
assert dorefa[1] == '32' and dorefa[2] == '32'
BITW, BITA, BITG = 't', 32, 32
else:
BITW, BITA, BITG = map(int, dorefa)
if args.gpu:
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
......@@ -259,6 +269,8 @@ if __name__ == '__main__':
nr_tower = max(get_nr_gpu(), 1)
BATCH_SIZE = TOTAL_BATCH_SIZE // nr_tower
logger.set_logger_dir(os.path.join(
'train_log', 'alexnet-dorefa-{}'.format(args.dorefa)))
logger.info("Batch per tower: {}".format(BATCH_SIZE))
config = get_config()
......
......@@ -55,3 +55,36 @@ def get_dorefa(bitW, bitA, bitG):
with G.gradient_override_map({"Identity": "FGGrad"}):
return tf.identity(x)
return fw, fa, fg
def ternarize(x, thresh=0.05):
"""
Implemented Trained Ternary Quantization:
https://arxiv.org/abs/1612.01064
Code modified from the authors' at:
https://github.com/czhu95/ternarynet/blob/master/examples/Ternary-Net/ternary.py
"""
G = tf.get_default_graph()
shape = x.get_shape()
thre_x = tf.stop_gradient(tf.reduce_max(tf.abs(x)) * thresh)
w_p = tf.get_variable('Wp', initializer=1.0, dtype=tf.float32)
w_n = tf.get_variable('Wn', initializer=1.0, dtype=tf.float32)
tf.summary.scalar(w_p.op.name + '-summary', w_p)
tf.summary.scalar(w_n.op.name + '-summary', w_n)
mask = tf.ones(shape)
mask_p = tf.where(x > thre_x, tf.ones(shape) * w_p, mask)
mask_np = tf.where(x < -thre_x, tf.ones(shape) * w_n, mask_p)
mask_z = tf.where((x < thre_x) & (x > - thre_x), tf.zeros(shape), mask)
with G.gradient_override_map({"Sign": "Identity", "Mul": "Add"}):
w = tf.sign(x) * tf.stop_gradient(mask_z)
w = w * mask_np
tf.summary.histogram(w.name, w)
return w
......@@ -51,14 +51,15 @@ Evaluate the performance of a model and save to json.
These models are trained with different configurations on trainval35k and evaluated on minival using mAP@IoU=0.50:0.95.
MaskRCNN results contain both bbox and segm mAP.
|Backbone|`FASTRCNN_BATCH`|resolution |schedule|mAP (bbox/segm)|Time |
| - | - | - | - | - | - |
|Backbone|`FASTRCNN_BATCH`|resolution |schedule|mAP (bbox/segm)|Time |
| - | - | - | - | - | - |
|R-50 |64 |(600, 1024)|280k |33.1 |18h on 8 V100s|
|R-50 |512 |(800, 1333)|280k |35.6 |55h on 8 P100s|
|R-50 |512 |(800, 1333)|360k |36.6 |49h on 8 V100s|
|R-50 |256 |(800, 1333)|280k |36.8/32.1 |39h on 8 P100s|
|R-50 |512 |(800, 1333)|360k |37.8/33.2 |51h on 8 V100s|
|R-101 |512 |(800, 1333)|280k |40.1/34.4 |70h on 8 P100s|
|R-101 |512 |(800, 1333)|360k |40.8/35.1 |63h on 8 V100s|
The two 360k models have identical configurations with
`R50-C4-2x` configuration in
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment