seminar-breakout · Commit cf1ca7ae
authored Dec 13, 2016 by Yuxin Wu

    add timit

parent 6beb258b

Showing 4 changed files with 298 additions and 3 deletions:
examples/DoReFa-Net/alexnet-dorefa.py   +3    -3
examples/TIMIT/create-lmdb.py           +121  -0
examples/TIMIT/timitdata.py             +54   -0
examples/TIMIT/train-timit.py           +120  -0
examples/DoReFa-Net/alexnet-dorefa.py

@@ -22,7 +22,7 @@ DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients
 http://arxiv.org/abs/1606.06160
 The original experiments are performed on a proprietary framework.
-This is our attempt to reproduce it on tensorpack/tensorflow.
+This is our attempt to reproduce it on tensorpack & TensorFlow.
 Accuracy:
 Trained with 4 GPUs and (W,A,G)=(1,2,6), it can reach top-1 single-crop validation error of 51%,
@@ -40,8 +40,8 @@ Accuracy:
 Speed:
 About 2.8 iteration/s on 1 TitanX. (Each epoch is set to 10000 iterations)
-To Train:
-./alexnet-dorefa.py --dorefa 1,2,6 --data PATH --gpu 0,1,2,3
+To Train, for example:
+./alexnet-dorefa.py --dorefa 1,2,6 --data PATH --gpu 0,1
 PATH should look like:
 PATH/
 ...
examples/TIMIT/create-lmdb.py (new file, mode 100755)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: create-lmdb.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

import sys, os
import scipy.io.wavfile as wavfile
import string
import numpy as np
import argparse

from tensorpack import *
from tensorpack.utils.argtools import memoized
from tensorpack.utils.stats import OnlineMoments
import bob.ap

CHARSET = set(string.ascii_lowercase + ' ')
PHONEME_LIST = "aa,ae,ah,ao,aw,ax,ax-h,axr,ay,b,bcl,ch,d,dcl,dh,dx,eh,el,em,en,eng,epi,er,ey,f,g,gcl,h#,hh,hv,ih,ix,iy,jh,k,kcl,l,m,n,ng,nx,ow,oy,p,pau,pcl,q,r,s,sh,t,tcl,th,uh,uw,ux,v,w,y,z,zh".split(',')

PHONEME_DIC = {v: k for k, v in enumerate(PHONEME_LIST)}
WORD_DIC = {v: k for k, v in enumerate(string.ascii_lowercase + ' ')}


def read_timit_txt(f):
    # read a TIMIT .TXT transcript and map each character to an integer id
    f = open(f)
    line = f.readlines()[0].strip().split(' ')
    line = line[2:]   # drop the two leading sample-index fields
    line = ' '.join(line)
    line = line.replace('.', '').lower()
    line = filter(lambda c: c in CHARSET, line)
    f.close()
    ret = []
    for c in line:
        ret.append(WORD_DIC[c])
    return np.asarray(ret)


def read_timit_phoneme(f):
    # read a TIMIT .PHN annotation and map each phoneme to an integer id
    f = open(f)
    pho = []
    for line in f:
        line = line.strip().split(' ')[-1]
        pho.append(PHONEME_DIC[line])
    f.close()
    return np.asarray(pho)


@memoized
def get_bob_extractor(fs, win_length_ms=10, win_shift_ms=5, n_filters=55,
                      n_ceps=15, f_min=0., f_max=6000, delta_win=2,
                      pre_emphasis_coef=0.95, dct_norm=True, mel_scale=True):
    ret = bob.ap.Ceps(fs, win_length_ms, win_shift_ms, n_filters, n_ceps,
                      f_min, f_max, delta_win, pre_emphasis_coef,
                      mel_scale, dct_norm)
    return ret


def diff_feature(feat, nd=1):
    # append first (and optionally second) order temporal differences
    diff = feat[1:] - feat[:-1]
    feat = feat[1:]
    if nd == 1:
        return np.concatenate((feat, diff), axis=1)
    elif nd == 2:
        d2 = diff[1:] - diff[:-1]
        return np.concatenate((feat[1:], diff[1:], d2), axis=1)


def get_feature(f):
    fs, signal = wavfile.read(f)
    signal = signal.astype('float64')
    feat = get_bob_extractor(fs, n_filters=26, n_ceps=13)(signal)
    feat = diff_feature(feat, nd=2)
    return feat


class RawTIMIT(DataFlow):
    def __init__(self, dirname, label='phoneme'):
        self.dirname = dirname
        assert os.path.isdir(dirname), dirname
        self.filelists = [k for k in fs.recursive_walk(self.dirname)
                          if k.endswith('.wav')]
        logger.info("Found {} wav files ...".format(len(self.filelists)))
        assert label in ['phoneme', 'letter'], label
        self.label = label

    def size(self):
        return len(self.filelists)

    def get_data(self):
        for f in self.filelists:
            feat = get_feature(f)
            if self.label == 'phoneme':
                label = read_timit_phoneme(f[:-4] + '.PHN')
            elif self.label == 'letter':
                label = read_timit_txt(f[:-4] + '.TXT')
            yield [feat, label]


def compute_mean_std(db, fname):
    ds = LMDBDataPoint(db, shuffle=False)
    o = OnlineMoments()
    with get_tqdm(total=ds.size()) as bar:
        for dp in ds.get_data():
            feat = dp[0]   # len x dim
            for f in feat:
                o.feed(f)
            bar.update()
    logger.info("Writing to {} ...".format(fname))
    with open(fname, 'wb') as f:
        f.write(serialize.dumps([o.mean, o.std]))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    subparsers = parser.add_subparsers(title='command', dest='command')
    parser_db = subparsers.add_parser('build', help='build a LMDB database')
    parser_db.add_argument('--dataset',
                           help='path to TIMIT TRAIN or TEST directory',
                           required=True)
    parser_db.add_argument('--db', help='output lmdb file', required=True)
    parser_stat = subparsers.add_parser(
        'stat', help='compute statistics (mean/std) of dataset')
    parser_stat.add_argument('--db', help='input lmdb file', required=True)
    parser_stat.add_argument('-o', '--output',
                             help='output statistics file', default='stats.data')
    args = parser.parse_args()

    if args.command == 'build':
        ds = RawTIMIT(args.dataset)
        dftools.dump_dataflow_to_lmdb(ds, args.db)
    elif args.command == 'stat':
        compute_mean_std(args.db, args.output)
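For reference, a typical pair of invocations given the subcommands defined above; the TIMIT directory and the LMDB file name are placeholders, and stats.data is the parser's default output name:

./create-lmdb.py build --dataset /path/to/TIMIT/TRAIN --db train.lmdb
./create-lmdb.py stat --db train.lmdb -o stats.data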
examples/TIMIT/timitdata.py (new file, mode 100644)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: timitdata.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

from tensorpack import ProxyDataFlow
import numpy as np
from six.moves import range

__all__ = ['TIMITBatch']


def batch_feature(feats):
    # zero-pad every feature array to the longest sequence in the batch
    maxlen = max([k.shape[0] for k in feats])
    bsize = len(feats)
    ret = np.zeros((bsize, maxlen, feats[0].shape[1]))
    for idx, feat in enumerate(feats):
        ret[idx, :feat.shape[0], :] = feat
    return ret


def sparse_label(labels):
    # encode a batch of variable-length label sequences in sparse form
    maxlen = max([k.shape[0] for k in labels])
    shape = [len(labels), maxlen]   # bxt
    indices = []
    values = []
    for bid, lab in enumerate(labels):
        for tid, c in enumerate(lab):
            indices.append([bid, tid])
            values.append(c)
    indices = np.asarray(indices)
    values = np.asarray(values)
    return (indices, values, shape)


class TIMITBatch(ProxyDataFlow):
    def __init__(self, ds, batch):
        self.batch = batch
        self.ds = ds

    def size(self):
        return self.ds.size() // self.batch

    def get_data(self):
        itr = self.ds.get_data()
        for _ in range(self.size()):
            feats = []
            labs = []
            for b in range(self.batch):
                feat, lab = next(itr)
                feats.append(feat)
                labs.append(lab)
            batchfeat = batch_feature(feats)
            batchlab = sparse_label(labs)
            seqlen = np.asarray([k.shape[0] for k in feats])
            yield [batchfeat, batchlab[0], batchlab[1], batchlab[2], seqlen]
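As an illustration of the sparse encoding (not part of the commit), here is a minimal sketch of what sparse_label returns for two made-up label sequences; the triple matches the (indices, values, dense_shape) form that tf.SparseTensor consumes in train-timit.py:

import numpy as np
labs = [np.asarray([3, 1]), np.asarray([2])]
indices, values, shape = sparse_label(labs)
# indices -> [[0, 0], [0, 1], [1, 0]]   (batch id, time id) of every label entry
# values  -> [3, 1, 2]                  the label ids, flattened batch-major
# shape   -> [2, 2]                     batch size x longest sequence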
examples/TIMIT/train-timit.py (new file, mode 100755)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: train-timit.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>

import tensorflow as tf
import numpy as np
import os, sys
import argparse
from collections import Counter
import operator
import six
from six.moves import map, range

from tensorpack import *
from tensorpack.tfutils.gradproc import *
from tensorpack.utils.globvars import globalns as param
import tensorpack.tfutils.symbolic_functions as symbf

from timitdata import TIMITBatch

BATCH = 64
NLAYER = 2
HIDDEN = 128
NR_CLASS = 61 + 1   # 61 phonemes plus one CTC blank
FEATUREDIM = 39


class Model(ModelDesc):
    def _get_input_vars(self):
        return [InputVar(tf.float32, [None, None, FEATUREDIM], 'feat'),   # bxmaxseqx39
                InputVar(tf.int64, None, 'labelidx'),     # label is b x maxlen, sparse
                InputVar(tf.int32, None, 'labelvalue'),
                InputVar(tf.int64, None, 'labelshape'),
                InputVar(tf.int32, [None], 'seqlen'),     # b
                ]

    def _build_graph(self, input_vars):
        feat, labelidx, labelvalue, labelshape, seqlen = input_vars
        label = tf.SparseTensor(labelidx, labelvalue, labelshape)

        cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=HIDDEN)
        cell = tf.nn.rnn_cell.MultiRNNCell([cell] * NLAYER)

        initial = cell.zero_state(tf.shape(feat)[0], tf.float32)
        outputs, last_state = tf.nn.dynamic_rnn(
            cell, feat, seqlen, initial, dtype=tf.float32, scope='rnn')

        # o: b x t x HIDDEN
        output = tf.reshape(outputs, [-1, HIDDEN])   # (Bxt) x rnnsize
        logits = FullyConnected('fc', output, NR_CLASS, nl=tf.identity,
                                W_init=tf.truncated_normal_initializer(stddev=0.01))
        logits = tf.reshape(logits, (BATCH, -1, NR_CLASS))

        loss = tf.nn.ctc_loss(logits, label, seqlen, time_major=False)
        self.cost = tf.reduce_mean(loss, name='cost')

        logits = tf.transpose(logits, [1, 0, 2])
        predictions = tf.to_int32(
            tf.nn.ctc_beam_search_decoder(logits, seqlen)[0][0])
        err = tf.edit_distance(predictions, label, normalize=True)
        err.set_shape([None])
        err = tf.reduce_mean(err, name='error')
        summary.add_moving_summary(err)

    def get_gradient_processor(self):
        return [GlobalNormClip(5), SummaryGradient()]


def get_data(path, isTrain, stat_file):
    ds = LMDBDataPoint(path, shuffle=isTrain)
    mean, std = serialize.loads(open(stat_file).read())
    ds = MapDataComponent(ds, lambda x: (x - mean) / std)
    ds = TIMITBatch(ds, BATCH)
    if isTrain:
        ds = PrefetchDataZMQ(ds, 1)
    return ds


def get_config(ds_train, ds_test):
    step_per_epoch = ds_train.size()

    lr = symbolic_functions.get_scalar_var('learning_rate', 5e-3, summary=True)

    return TrainConfig(
        dataset=ds_train,
        optimizer=tf.train.AdamOptimizer(lr, epsilon=1e-3),
        callbacks=Callbacks([
            StatPrinter(), ModelSaver(),
            StatMonitorParamSetter('learning_rate', 'error',
                                   lambda x: x * 0.2, 0, 5),
            HumanHyperParamSetter('learning_rate'),
            PeriodicCallback(
                InferenceRunner(ds_test, [ScalarStats('error')]), 2),
        ]),
        model=Model(),
        step_per_epoch=step_per_epoch,
        max_epoch=500,
    )


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', help='comma separated list of GPU(s) to use.')
    parser.add_argument('--load', help='load model')
    parser.add_argument('--train', help='path to training lmdb', required=True)
    parser.add_argument('--test', help='path to testing lmdb', required=True)
    parser.add_argument('--stat', help='path to the mean/std statistics file',
                        default='stats.data')
    args = parser.parse_args()
    if args.gpu:
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    logger.auto_set_dir()
    ds_train = get_data(args.train, True, args.stat)
    ds_test = get_data(args.test, False, args.stat)

    config = get_config(ds_train, ds_test)
    if args.load:
        config.session_init = SaverRestore(args.load)
    QueueInputTrainer(config).train()
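For reference, a typical training run wired to the flags defined above, assuming the LMDB files and the statistics file were produced by create-lmdb.py; the file names are placeholders except stats.data, which is the default:

./train-timit.py --train train.lmdb --test test.lmdb --stat stats.data --gpu 0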