Shashank Suhas / seminar-breakout · Commits

Commit 3da0c3ec, authored Jun 23, 2016 by Yuxin Wu

svhn-dorefa

parent e56dfb5f

Showing 3 changed files with 234 additions and 7 deletions (+234 −7):
examples/DoReFa-Net/README.md            +9   −4
examples/DoReFa-Net/svhn-digit-dorefa.py +222  −0
tensorpack/dataflow/common.py            +3   −3
examples/DoReFa-Net/README.md  +9 −4  (view file @ 3da0c3ec)

-This is the official script to load and run pretrained model for the paper:
+This is the official script to train, or run pretrained model for the paper:
 [DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients](http://arxiv.org/abs/1606.06160), by Zhou et al.
-The provided model is an AlexNet with 1 bit weights, 2 bit activations, trained with 4 bit gradients.
-Training code available soon.
+Training code for SVHN is available.
+The provided pretrained model is an AlexNet with 1 bit weights, 2 bit activations, trained with 4 bit gradients.

 ## Preparation:
...
@@ -22,7 +22,12 @@ pip install --user -r tensorpack/requirements.txt
 export PYTHONPATH=$PYTHONPATH:`readlink -f tensorpack`
 ```
-Download the model at [google drive](https://drive.google.com/open?id=0B308TeQzmFDLa0xOeVQwcXg1ZjQ)
+To perform training, you'll also need [pyzmq](https://github.com/zeromq/pyzmq):
+```
+pip install --user pyzmq
+```
+
+Pretrained model is hosted at [google drive](https://drive.google.com/open?id=0B308TeQzmFDLa0xOeVQwcXg1ZjQ)

 ## Load and run the model
 We published the model in two file formats:
...
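For orientation (not part of this commit): the bit-widths quoted in the README — 1-bit weights, 2-bit activations, 4-bit gradients — refer to the k-bit uniform quantizer implemented by the `quantize` helper in the new `svhn-digit-dorefa.py` below, which maps a value in [0, 1] onto `2**k - 1` evenly spaced steps. A minimal NumPy sketch of what the weight and activation settings produce:

```python
import numpy as np

def quantize(x, k):
    """k-bit uniform quantization of x in [0, 1] (forward pass only)."""
    n = float(2 ** k - 1)
    return np.round(x * n) / n

# 2-bit activations (BITA=2): four levels {0, 1/3, 2/3, 1}
a = np.array([0.05, 0.4, 0.8, 1.0])
print(quantize(a, 2))            # [0.      0.3333  0.6667  1.    ]

# 1-bit weights (BITW=1): tanh-rescale to [0, 1], quantize, map back to {-1, +1}
w = np.array([-0.7, -0.1, 0.2, 0.9])
t = np.tanh(w)
t = t / np.max(np.abs(t)) * 0.5 + 0.5
print(2 * quantize(t, 1) - 1)    # [-1. -1.  1.  1.]
```

In the script itself the rounding happens under a `gradient_override_map`, so the non-differentiable rounding op passes gradients straight through during training.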
examples/DoReFa-Net/svhn-digit-dorefa.py  +222 −0  (new file, mode 100755, view file @ 3da0c3ec)

```python
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# File: svhn-digit-dorefa.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>

import tensorflow as tf
import argparse
import numpy as np
import os

from tensorpack import *
from tensorpack.tfutils.symbolic_functions import *
from tensorpack.tfutils.summary import *

"""
Code for the paper:
DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients
http://arxiv.org/abs/1606.06160

The original experiements are performed on a proprietary framework.
This is our attempt to reproduce it on tensorpack.

This config, with (W,A,G)=(1,1,4), can reach 3.1~3.2% error after 150 epochs.
With the GaussianDeform augmentor, it will reach 2.8~2.9%.
"""

BITW = 1
BITA = 2
BITG = 4
GRAD_DEFINED = False

def get_dorefa(bitW, bitA, bitG):
    G = tf.get_default_graph()

    global GRAD_DEFINED
    if not GRAD_DEFINED:
        @tf.RegisterGradient("IdentityGrad")
        def ident_grad(op, grad):
            return [grad] * len(op.inputs)

    def quantize(x, k):
        n = float(2 ** k - 1)
        with G.gradient_override_map({"Floor": "IdentityGrad"}):
            return tf.round(x * n) / n

    def fw(x):
        x = tf.tanh(x)
        x = x / tf.reduce_max(tf.abs(x)) * 0.5 + 0.5
        return 2 * quantize(x, bitW) - 1

    def fa(x):
        return quantize(x, bitA)

    if not GRAD_DEFINED:
        @tf.RegisterGradient("FGGrad")
        def grad_fg(op, x):
            rank = x.get_shape().ndims
            assert rank is not None
            maxx = tf.reduce_max(tf.abs(x), list(range(1, rank)), keep_dims=True)
            x = x / maxx
            n = float(2 ** bitG - 1)
            x = x * 0.5 + 0.5 + tf.random_uniform(
                    tf.shape(x), minval=-0.5 / n, maxval=0.5 / n)
            x = tf.clip_by_value(x, 0.0, 1.0)
            x = quantize(x, bitG) - 0.5
            return x * maxx * 2

    def fg(x):
        with G.gradient_override_map({"Identity": "FGGrad"}):
            return tf.identity(x)

    GRAD_DEFINED = True
    return fw, fa, fg

class Model(ModelDesc):
    def _get_input_vars(self):
        return [InputVar(tf.float32, [None, 40, 40, 3], 'input'),
                InputVar(tf.int32, [None], 'label')]

    def _build_graph(self, input_vars, is_training):
        image, label = input_vars

        fw, fa, fg = get_dorefa(BITW, BITA, BITG)

        # monkey-patch tf.get_variable to apply fw
        old_get_variable = tf.get_variable
        def new_get_variable(name, shape=None, **kwargs):
            v = old_get_variable(name, shape, **kwargs)
            if name != 'W' or 'conv0' in v.op.name or 'fc' in v.op.name:
                return v
            else:
                logger.info("Binarizing weight {}".format(v.op.name))
                return fw(v)
        tf.get_variable = new_get_variable

        def cabs(x):
            return tf.minimum(1.0, tf.abs(x), name='cabs')

        def activate(x):
            return fa(cabs(x))

        l = image / 256.0

        with argscope(BatchNorm, decay=0.9, epsilon=1e-4, use_local_stat=is_training), \
                argscope(Conv2D, use_bias=False, nl=tf.identity):
            l = Conv2D('conv0', l, 48, 5, padding='VALID', use_bias=True)
            l = MaxPooling('pool0', l, 2, padding='SAME')
            l = activate(l)
            # 18
            l = Conv2D('conv1', l, 64, 3, padding='SAME')
            l = activate(BatchNorm('bn1', fg(l)))

            l = Conv2D('conv2', l, 64, 3, padding='SAME')
            l = BatchNorm('bn2', fg(l))
            l = MaxPooling('pool1', l, 2, padding='SAME')
            l = activate(l)
            # 9
            l = Conv2D('conv3', l, 128, 3, padding='VALID')
            l = activate(BatchNorm('bn3', fg(l)))
            # 7
            l = Conv2D('conv4', l, 128, 3, padding='SAME')
            l = activate(BatchNorm('bn4', fg(l)))

            l = Conv2D('conv5', l, 128, 3, padding='VALID')
            l = activate(BatchNorm('bn5', fg(l)))
            # 5
            l = tf.nn.dropout(l, 0.5 if is_training else 1.0)
            l = Conv2D('conv6', l, 512, 5, padding='VALID')
            l = BatchNorm('bn6', fg(l))
            l = cabs(l)

            logits = FullyConnected('fc1', l, 10, nl=tf.identity)
        prob = tf.nn.softmax(logits, name='output')

        cost = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, label)
        cost = tf.reduce_mean(cost, name='cross_entropy_loss')
        tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, cost)

        # compute the number of failed samples, for ClassificationError to use at test time
        wrong = prediction_incorrect(logits, label)
        nr_wrong = tf.reduce_sum(wrong, name='wrong')
        # monitor training error
        tf.add_to_collection(
            MOVING_SUMMARY_VARS_KEY, tf.reduce_mean(wrong, name='train_error'))

        # weight decay on all W of fc layers
        wd_cost = regularize_cost('fc.*/W', l2_regularizer(1e-7))
        tf.add_to_collection(MOVING_SUMMARY_VARS_KEY, wd_cost)

        add_param_summary([('.*/W', ['histogram', 'rms'])])
        self.cost = tf.add_n([cost, wd_cost], name='cost')

def get_config():
    basename = os.path.basename(__file__)
    logger.set_logger_dir(
        os.path.join('train_log', basename[:basename.rfind('.')]))

    # prepare dataset
    d1 = dataset.SVHNDigit('train')
    d2 = dataset.SVHNDigit('extra')
    data_train = RandomMixData([d1, d2])
    data_test = dataset.SVHNDigit('test')

    augmentors = [
        imgaug.Resize((40, 40)),
        imgaug.Brightness(30),
        imgaug.Contrast((0.5, 1.5)),
        #imgaug.GaussianDeform(  # this is slow but helpful. only use it when you have lots of cpus
            #[(0.2, 0.2), (0.2, 0.8), (0.8, 0.8), (0.8, 0.2)],
            #(40, 40), 0.2, 3),
    ]
    data_train = AugmentImageComponent(data_train, augmentors)
    data_train = BatchData(data_train, 128)
    data_train = PrefetchDataZMQ(data_train, 5)
    step_per_epoch = data_train.size()

    augmentors = [imgaug.Resize((40, 40))]
    data_test = AugmentImageComponent(data_test, augmentors)
    data_test = BatchData(data_test, 128, remainder=True)

    lr = tf.train.exponential_decay(
        learning_rate=1e-3,
        global_step=get_global_step_var(),
        decay_steps=data_train.size() * 100,
        decay_rate=0.5, staircase=True, name='learning_rate')
    tf.scalar_summary('learning_rate', lr)

    return TrainConfig(
        dataset=data_train,
        optimizer=tf.train.AdamOptimizer(lr, epsilon=1e-5),
        callbacks=Callbacks([
            StatPrinter(),
            ModelSaver(),
            InferenceRunner(data_test,
                [ScalarStats('cost'), ClassificationError()])
        ]),
        model=Model(),
        step_per_epoch=step_per_epoch,
        max_epoch=200,
    )

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', help='the GPU to use')
    # nargs='*' in multi mode
    parser.add_argument('--load', help='load a checkpoint')
    parser.add_argument('--dorefa',
            help='number of bits for W,A,G, separated by comma. Defaults to \'1,2,4\'',
            default='1,2,4')
    args = parser.parse_args()

    BITW, BITA, BITG = map(int, args.dorefa.split(','))

    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    config = get_config()
    if args.load:
        config.session_init = SaverRestore(args.load)
    if args.gpu:
        config.nr_tower = len(args.gpu.split(','))
    QueueInputTrainer(config).train()
```
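The inline comments `# 18`, `# 9`, `# 7`, `# 5` above track the spatial size of the feature map as it moves through the network. A quick sanity check of that arithmetic (not part of the commit), assuming stride-1 convolutions and stride-2 pooling as in the `Conv2D`/`MaxPooling` calls:

```python
import math

def valid_conv(size, k):   # 'VALID' padding, stride 1: no zero padding
    return size - k + 1

def same_conv(size, k):    # 'SAME' padding, stride 1: size unchanged
    return size

def pool(size, k=2):       # stride-2 pooling with 'SAME' padding
    return math.ceil(size / k)

s = 40                                   # 40x40 SVHN crop
s = pool(valid_conv(s, 5))               # conv0 (5x5 VALID) + pool0 -> 18
assert s == 18
s = pool(same_conv(same_conv(s, 3), 3))  # conv1, conv2 (3x3 SAME) + pool1 -> 9
assert s == 9
s = same_conv(valid_conv(s, 3), 3)       # conv3 (VALID) -> 7, conv4 (SAME) stays 7
assert s == 7
s = valid_conv(s, 3)                     # conv5 (3x3 VALID) -> 5
assert s == 5
s = valid_conv(s, 5)                     # conv6 (5x5 VALID) -> 1x1 before the FC layer
assert s == 1
```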
tensorpack/dataflow/common.py  +3 −3  (view file @ 3da0c3ec)

...
@@ -5,7 +5,7 @@
 from __future__ import division
 import copy
 import numpy as np
-from six.moves import range
+from six.moves import range, map
 from .base import DataFlow, ProxyDataFlow
 from ..utils import *
...
@@ -251,8 +251,8 @@ class RandomMixData(DataFlow):
         sums = np.cumsum(self.sizes)
         idxs = np.arange(self.size())
         self.rng.shuffle(idxs)
-        idxs = np.array(map(
-            lambda x: np.searchsorted(sums, x, 'right'), idxs))
+        idxs = np.array(list(map(
+            lambda x: np.searchsorted(sums, x, 'right'), idxs)))
         itrs = [k.get_data() for k in self.df_lists]
         assert idxs.max() == len(itrs) - 1, "{}!={}".format(idxs.max(), len(itrs) - 1)
         for k in idxs:
...
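The `list(map(...))` change above reads like a Python 3 compatibility fix: with `six.moves.map` (or the builtin `map` on Python 3), `map` returns a lazy iterator, and `np.array` wraps an iterator as a 0-d object array instead of building the index array the shuffle logic expects. A minimal sketch, independent of tensorpack, of the difference:

```python
import numpy as np
from six.moves import map  # iterator-style map on both Python 2 and 3

idxs = np.arange(5)
bad = np.array(map(lambda x: x + 1, idxs))         # 0-d object array holding a map iterator
good = np.array(list(map(lambda x: x + 1, idxs)))  # 1-d integer array, as intended

print(bad.shape, bad.dtype)    # ()  object
print(good.shape, good.dtype)  # (5,)  int64 (platform-dependent)
```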