Commit 03f18976 authored by Yuxin Wu

Add TTQ inside DoReFa

parent 5cdf1d33
@@ -11,7 +11,9 @@ Pretrained model for (1,4,32)-ResNet18 and (1,2,6)-AlexNet are available at
 They're provided in the format of numpy dictionary, so it should be very easy to port into other applications.
 The __binary-weight 4-bit-activation ResNet-18__ model has 59.2% top-1 validation accuracy.
-Note that when (W,A,G) is set to (1,32,32), this code is also an implementation of [Binary Weight Network](https://arxiv.org/abs/1511.00363).
+Note that when (W,A,G) is set to (1,32,32), this code implements [Binary Weight Network](https://arxiv.org/abs/1511.00363).
+When (W,A,G) is set to (t,32,32), this code implements
+[Trained Ternary Quantization](https://arxiv.org/abs/1612.01064).
 But with (W,A,G) set to (1,1,32), it is not equivalent to [XNOR-Net](https://arxiv.org/abs/1603.05279), although it won't be hard to implement it.
 Alternative link to this page: [http://dorefa.net](http://dorefa.net)
...
@@ -19,7 +19,7 @@ from tensorpack.dataflow import dataset
 from tensorpack.utils.gpu import get_nr_gpu
 from imagenet_utils import get_imagenet_dataflow, fbresnet_augmentor
-from dorefa import get_dorefa
+from dorefa import get_dorefa, ternarize

 """
 This is a tensorpack script for the ImageNet results in paper:
@@ -35,7 +35,9 @@ Accuracy:
     due to more sophisticated augmentations.

     With (W,A,G)=(32,32,32) -- full precision baseline, 41.4% error.
+    With (W,A,G)=(t,32,32) -- TTQ, 42.3% error
     With (W,A,G)=(1,32,32) -- BWN, 44.3% error
+    With (W,A,G)=(1,1,32) -- BNN, 53.4% error
     With (W,A,G)=(1,2,6), 47.6% error
     With (W,A,G)=(1,2,4), 58.4% error
@@ -84,7 +86,11 @@ class Model(ModelDesc):
     def build_graph(self, image, label):
         image = image / 255.0

-        fw, fa, fg = get_dorefa(BITW, BITA, BITG)
+        if BITW == 't':
+            fw, fa, fg = get_dorefa(32, 32, 32)
+            fw = ternarize
+        else:
+            fw, fa, fg = get_dorefa(BITW, BITA, BITG)

         # monkey-patch tf.get_variable to apply fw
         def new_get_variable(v):
@@ -93,7 +99,7 @@ class Model(ModelDesc):
             if not name.endswith('W') or 'conv0' in name or 'fct' in name:
                 return v
             else:
-                logger.info("Binarizing weight {}".format(v.op.name))
+                logger.info("Quantizing weight {}".format(v.op.name))
                 return fw(v)

         def nonlin(x):
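The getter above only takes effect once it is hooked into variable creation. A minimal sketch of that pattern, assuming tensorpack's `remap_variables` context manager (used by the surrounding script, not shown in this hunk) and a hypothetical `build_network` helper:

```python
# Sketch under assumptions: every variable created inside the block is
# routed through new_get_variable, so weights named '*/W' (except the
# first conv and the last fc layer) come back quantized by fw.
from tensorpack.tfutils.varreplace import remap_variables

with remap_variables(new_get_variable):
    logits = build_network(image)  # hypothetical graph-building call
```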
@@ -175,7 +181,6 @@ def get_data(dataset_name):
 def get_config():
-    logger.auto_set_dir()
     data_train = get_data('train')
     data_test = get_data('val')
@@ -242,12 +247,17 @@ if __name__ == '__main__':
     parser.add_argument('--gpu', help='the physical ids of GPUs to use')
     parser.add_argument('--load', help='load a checkpoint, or a npz (given as the pretrained model)')
     parser.add_argument('--data', help='ILSVRC dataset dir')
-    parser.add_argument('--dorefa',
-                        help='number of bits for W,A,G, separated by comma', required=True)
+    parser.add_argument('--dorefa', required=True,
+                        help='number of bits for W,A,G, separated by comma. W="t" means TTQ')
     parser.add_argument('--run', help='run on a list of images with the pretrained model', nargs='*')
     args = parser.parse_args()

-    BITW, BITA, BITG = map(int, args.dorefa.split(','))
+    dorefa = args.dorefa.split(',')
+    if dorefa[0] == 't':
+        assert dorefa[1] == '32' and dorefa[2] == '32'
+        BITW, BITA, BITG = 't', 32, 32
+    else:
+        BITW, BITA, BITG = map(int, dorefa)

     if args.gpu:
         os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
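The flag handling above is easy to check in isolation. A self-contained sketch (the function name `parse_dorefa` is ours; the logic mirrors the diff):

```python
def parse_dorefa(spec):
    """Parse a --dorefa value such as '1,2,6' or 't,32,32' into (BITW, BITA, BITG)."""
    parts = spec.split(',')
    if parts[0] == 't':
        # TTQ mode: only full-precision activations and gradients are supported.
        assert parts[1:] == ['32', '32']
        return 't', 32, 32
    return tuple(map(int, parts))

assert parse_dorefa('t,32,32') == ('t', 32, 32)
assert parse_dorefa('1,2,6') == (1, 2, 6)
```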
@@ -259,6 +269,8 @@ if __name__ == '__main__':
     nr_tower = max(get_nr_gpu(), 1)
     BATCH_SIZE = TOTAL_BATCH_SIZE // nr_tower
+    logger.set_logger_dir(os.path.join(
+        'train_log', 'alexnet-dorefa-{}'.format(args.dorefa)))
     logger.info("Batch per tower: {}".format(BATCH_SIZE))

     config = get_config()
...
@@ -55,3 +55,36 @@ def get_dorefa(bitW, bitA, bitG):
         with G.gradient_override_map({"Identity": "FGGrad"}):
             return tf.identity(x)
     return fw, fa, fg
+
+
+def ternarize(x, thresh=0.05):
+    """
+    Implements Trained Ternary Quantization:
+    https://arxiv.org/abs/1612.01064
+
+    Code modified from the authors' at:
+    https://github.com/czhu95/ternarynet/blob/master/examples/Ternary-Net/ternary.py
+    """
+    G = tf.get_default_graph()
+
+    shape = x.get_shape()
+    thre_x = tf.stop_gradient(tf.reduce_max(tf.abs(x)) * thresh)
+
+    w_p = tf.get_variable('Wp', initializer=1.0, dtype=tf.float32)
+    w_n = tf.get_variable('Wn', initializer=1.0, dtype=tf.float32)
+
+    tf.summary.scalar(w_p.op.name + '-summary', w_p)
+    tf.summary.scalar(w_n.op.name + '-summary', w_n)
+
+    mask = tf.ones(shape)
+    mask_p = tf.where(x > thre_x, tf.ones(shape) * w_p, mask)
+    mask_np = tf.where(x < -thre_x, tf.ones(shape) * w_n, mask_p)
+    mask_z = tf.where((x < thre_x) & (x > -thre_x), tf.zeros(shape), mask)
+
+    with G.gradient_override_map({"Sign": "Identity", "Mul": "Add"}):
+        w = tf.sign(x) * tf.stop_gradient(mask_z)
+
+    w = w * mask_np
+
+    tf.summary.histogram(w.name, w)
+    return w
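In words: with threshold Δ = thresh · max|x|, entries of x above Δ map to +Wp, entries below -Δ map to -Wn, and the rest map to 0, where Wp and Wn are learned per-layer scaling factors. The `gradient_override_map` backpropagates `Sign` as identity and replaces the gradient of the multiplication inside the block with that of `Add`, a straight-through estimator that lets upstream gradients reach x unscaled; Wp and Wn receive their gradients through the ordinary product `w * mask_np` outside the block. Since Wp/Wn are created with `tf.get_variable`, each quantized layer needs its own variable scope. A minimal usage sketch (scope name and shape are ours):

```python
import tensorflow as tf

with tf.variable_scope('conv1'):
    w = tf.get_variable('W', shape=[3, 3, 64, 64])
    w_q = ternarize(w, thresh=0.05)  # entries in {-Wn, 0, +Wp}
```

In the AlexNet script above this happens implicitly: when BITW is 't', `fw` is `ternarize`, and `new_get_variable` applies it to each eligible weight inside that layer's scope.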
@@ -51,14 +51,15 @@ Evaluate the performance of a model and save to json.
 These models are trained with different configurations on trainval35k and evaluated on minival using mAP@IoU=0.50:0.95.
 MaskRCNN results contain both bbox and segm mAP.

 |Backbone|`FASTRCNN_BATCH`|resolution |schedule|mAP (bbox/segm)|Time |
 | - | - | - | - | - | - |
 |R-50 |64 |(600, 1024)|280k |33.1 |18h on 8 V100s|
 |R-50 |512 |(800, 1333)|280k |35.6 |55h on 8 P100s|
 |R-50 |512 |(800, 1333)|360k |36.6 |49h on 8 V100s|
 |R-50 |256 |(800, 1333)|280k |36.8/32.1 |39h on 8 P100s|
 |R-50 |512 |(800, 1333)|360k |37.8/33.2 |51h on 8 V100s|
 |R-101 |512 |(800, 1333)|280k |40.1/34.4 |70h on 8 P100s|
+|R-101 |512 |(800, 1333)|360k |40.8/35.1 |63h on 8 V100s|

 The two 360k models have identical configurations with
 `R50-C4-2x` configuration in
...