Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
fbc13fb4
Commit
fbc13fb4
authored
Dec 10, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
notes, logs, online moments
parent
99a8ee54
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
45 additions
and
34 deletions
+45
-34
examples/DoReFa-Net/alexnet-dorefa.py
examples/DoReFa-Net/alexnet-dorefa.py
+3
-1
tensorpack/dataflow/common.py
tensorpack/dataflow/common.py
+4
-5
tensorpack/dataflow/format.py
tensorpack/dataflow/format.py
+10
-5
tensorpack/dataflow/raw.py
tensorpack/dataflow/raw.py
+2
-1
tensorpack/train/input_data.py
tensorpack/train/input_data.py
+1
-1
tensorpack/utils/argtools.py
tensorpack/utils/argtools.py
+14
-9
tensorpack/utils/stats.py
tensorpack/utils/stats.py
+11
-12
No files found.
examples/DoReFa-Net/alexnet-dorefa.py
View file @
fbc13fb4
...
...
@@ -54,9 +54,10 @@ To Train:
ILSVRC2012_val_00000001.JPEG
...
And you'll need the following to be able to fetch data efficiently:
Fast disk random access (Not necessarily SSD. I used a RAID of HDD, but not sure if plain HDD is enough)
More than 12 CPU cores (for data processing)
More than 10G of free memory
To Run Pretrained Model:
./alexnet-dorefa.py --load alexnet-126.npy --run a.jpg --dorefa 1,2,6
...
...
@@ -303,6 +304,7 @@ if __name__ == '__main__':
assert
args
.
gpu
is
not
None
,
"Need to specify a list of gpu for training!"
NR_GPU
=
len
(
args
.
gpu
.
split
(
','
))
BATCH_SIZE
=
TOTAL_BATCH_SIZE
//
NR_GPU
logger
.
info
(
"Batch per tower: {}"
.
format
(
BATCH_SIZE
))
config
=
get_config
()
if
args
.
load
:
...
...
tensorpack/dataflow/common.py
View file @
fbc13fb4
...
...
@@ -21,16 +21,15 @@ class TestDataSpeed(ProxyDataFlow):
self
.
test_size
=
size
def get_data(self):
    """Run the speed test once, then produce datapoints from the wrapped dataflow.

    Calls :meth:`start_test` first (which consumes one full pass for timing),
    then yields every datapoint of the underlying dataflow unchanged.
    """
    self.start_test()
    for dp in self.ds.get_data():
        yield dp
def start_test(self):
    """Reset the wrapped dataflow and iterate it once, showing speed with a tqdm bar.

    The bar total is ``self.test_size``; each datapoint advances it by one.
    """
    self.ds.reset_state()
    with get_tqdm(total=self.test_size) as pbar:
        for dp in self.ds.get_data():
            pbar.update()
class
BatchData
(
ProxyDataFlow
):
def
__init__
(
self
,
ds
,
batch_size
,
remainder
=
False
):
...
...
tensorpack/dataflow/format.py
View file @
fbc13fb4
...
...
@@ -72,13 +72,18 @@ class HDF5Data(RNGDataFlow):
class
LMDBData
(
RNGDataFlow
):
""" Read a lmdb and produce k,v pair """
def __init__(self, lmdb_path, shuffle=True):
    """
    :param lmdb_path: path to an lmdb file or directory
    :param shuffle: whether to produce the k,v pairs in random order
    """
    self._lmdb_path = lmdb_path
    self._shuffle = shuffle
    # the actual lmdb.open() lives in open_lmdb() so that reset_state()
    # can re-open the environment later
    self.open_lmdb()
def
open_lmdb
(
self
):
self
.
_lmdb
=
lmdb
.
open
(
self
.
_lmdb_path
,
subdir
=
os
.
path
.
isdir
(
self
.
_lmdb_path
),
readonly
=
True
,
lock
=
False
,
readahead
=
False
,
map_size
=
1099511627776
*
2
,
max_readers
=
100
)
self
.
_txn
=
self
.
_lmdb
.
begin
()
self
.
_shuffle
=
shuffle
self
.
_size
=
self
.
_txn
.
stat
()[
'entries'
]
if
shuffle
:
if
s
elf
.
_s
huffle
:
# get the list of keys either from __keys__ or by iterating
self
.
keys
=
loads
(
self
.
_txn
.
get
(
'__keys__'
))
if
not
self
.
keys
:
...
...
@@ -92,7 +97,7 @@ class LMDBData(RNGDataFlow):
def reset_state(self):
    """Reset the RNG state and re-open the lmdb environment and read transaction."""
    super(LMDBData, self).reset_state()
    self.open_lmdb()
def size(self):
    """Return the number of entries in the lmdb (computed when the db was opened)."""
    return self._size
...
...
tensorpack/dataflow/raw.py
View file @
fbc13fb4
...
...
@@ -4,6 +4,7 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import
numpy
as
np
import
copy
from
six.moves
import
range
from
.base
import
DataFlow
,
RNGDataFlow
from
..utils.serialize
import
loads
...
...
@@ -41,7 +42,7 @@ class FakeData(RNGDataFlow):
else
:
v
=
[
self
.
rng
.
rand
(
*
k
)
.
astype
(
self
.
dtype
)
for
k
in
self
.
shapes
]
for
_
in
range
(
self
.
_size
):
yield
v
yield
copy
.
deepcopy
(
v
)
class
DataFromQueue
(
DataFlow
):
""" Produce data from a queue """
...
...
tensorpack/train/input_data.py
View file @
fbc13fb4
...
...
@@ -74,7 +74,7 @@ class EnqueueThread(threading.Thread):
if
self
.
coord
.
should_stop
():
return
feed
=
dict
(
zip
(
self
.
placehdrs
,
dp
))
#print '
TFQ
:', self.sess.run([self.op, self.size_op], feed_dict=feed)[1]
#print '
qsize
:', self.sess.run([self.op, self.size_op], feed_dict=feed)[1]
self
.
op
.
run
(
feed_dict
=
feed
)
except
tf
.
errors
.
CancelledError
as
e
:
pass
...
...
tensorpack/utils/argtools.py
View file @
fbc13fb4
...
...
@@ -4,6 +4,7 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import
operator
import
inspect
,
six
,
functools
import
collections
...
...
@@ -34,16 +35,20 @@ class memoized(object):
self
.
func
=
func
self
.
cache
=
{}
def __call__(self, *args, **kwargs):
    """Return the cached result for (args, kwargs), computing and caching it on first use.

    Fix: the cache key must include the keyword *values*, not only the
    keyword names (``list(kwargs)`` yields names only), otherwise two calls
    differing only in a keyword value would share one cache entry.
    """
    # sort kwarg items by name so the key does not depend on call order
    kwlist = tuple(sorted(kwargs.items(), key=operator.itemgetter(0)))
    key = (args, kwlist)
    try:
        # probing hash() catches unhashable *contents* (e.g. a list inside
        # args), which an isinstance(..., Hashable) check on the tuple cannot
        hash(key)
    except TypeError:
        logger.warn("Arguments to memoized call is unhashable!")
        # uncacheable. a list, for instance.
        # better to not cache than blow up.
        return self.func(*args, **kwargs)
    if key in self.cache:
        return self.cache[key]
    value = self.func(*args, **kwargs)
    self.cache[key] = value
    return value
def
__repr__
(
self
):
...
...
@@ -57,9 +62,9 @@ class memoized(object):
_MEMOIZED_NOARGS
=
{}
def
memoized_ignoreargs
(
func
):
h
=
hash
(
func
)
# make sure it is hashable. is it necessary?
def
wrapper
(
*
args
):
def
wrapper
(
*
args
,
**
kwargs
):
if
func
not
in
_MEMOIZED_NOARGS
:
res
=
func
(
*
args
)
res
=
func
(
*
args
,
**
kwargs
)
_MEMOIZED_NOARGS
[
func
]
=
res
return
res
return
_MEMOIZED_NOARGS
[
func
]
...
...
tensorpack/utils/stats.py
View file @
fbc13fb4
...
...
@@ -119,21 +119,20 @@ class BinaryStatistics(object):
return
1
-
self
.
recall
class
OnlineMoments
(
object
):
"""Compute 1st and 2nd moments online
See algorithm at: https://www.wikiwand.com/en/Algorithms_for_calculating_variance#/Online_algorithm
"""
def __init__(self):
    # running mean of all values fed so far
    self._mean = 0
    # running sum of squared deviations from the current mean (the "M2"
    # accumulator of the online variance algorithm cited on the class)
    self._M2 = 0
    # count of values fed so far
    self._n = 0
def feed(self, x):
    """Feed one value and update the running first/second moments (Welford update).

    :param x: the new value
    """
    self._n += 1
    delta = x - self._mean
    self._mean += delta * (1.0 / self._n)
    # delta2 uses the *updated* mean; delta * delta2 is the standard
    # numerically-stable M2 increment
    delta2 = x - self._mean
    self._M2 += delta * delta2
@
property
def
mean
(
self
):
...
...
@@ -141,8 +140,8 @@ class OnlineMoments(object):
@property
def variance(self):
    """Sample (unbiased, n-1 denominator) variance of everything fed so far.

    Only meaningful after at least two values have been fed.
    """
    return self._M2 / (self._n - 1)
@property
def std(self):
    """Sample standard deviation: the square root of :attr:`variance`."""
    return np.sqrt(self.variance)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment