Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
60e52b94
Commit
60e52b94
authored
Mar 22, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
download / ilsvrcmeta
parent
d979ab78
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
72 additions
and
21 deletions
+72
-21
examples/load_alexnet.py
examples/load_alexnet.py
+5
-1
examples/load_vgg16.py
examples/load_vgg16.py
+5
-1
tensorpack/dataflow/dataset/.gitignore
tensorpack/dataflow/dataset/.gitignore
+1
-0
tensorpack/dataflow/dataset/cifar10.py
tensorpack/dataflow/dataset/cifar10.py
+4
-12
tensorpack/dataflow/dataset/ilsvrc.py
tensorpack/dataflow/dataset/ilsvrc.py
+34
-0
tensorpack/dataflow/dataset/mnist.py
tensorpack/dataflow/dataset/mnist.py
+2
-5
tensorpack/dataflow/dataset/svhn.py
tensorpack/dataflow/dataset/svhn.py
+1
-1
tensorpack/utils/fs.py
tensorpack/utils/fs.py
+20
-1
No files found.
examples/load_alexnet.py
View file @
60e52b94
...
...
@@ -18,6 +18,7 @@ from tensorpack.tfutils.symbolic_functions import *
from
tensorpack.tfutils.summary
import
*
from
tensorpack.callbacks
import
*
from
tensorpack.dataflow
import
*
from
tensorpack.dataflow.dataset
import
ILSVRCMeta
BATCH_SIZE
=
10
MIN_AFTER_DEQUEUE
=
500
...
...
@@ -132,12 +133,15 @@ def run_test(path, input):
im
=
cv2
.
cvtColor
(
im
,
cv2
.
COLOR_BGR2RGB
)
im
=
cv2
.
resize
(
im
,
(
227
,
227
))
im
=
np
.
reshape
(
im
,
(
1
,
227
,
227
,
3
))
.
astype
(
'float32'
)
im
=
im
-
110
outputs
=
predict_func
([
im
])[
0
]
prob
=
outputs
[
0
]
print
prob
.
shape
ret
=
prob
.
argsort
()[
-
10
:][::
-
1
]
print
ret
meta
=
ILSVRCMeta
()
.
get_synset_words_1000
()
print
[
meta
[
k
]
for
k
in
ret
]
if
__name__
==
'__main__'
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--gpu'
,
help
=
'comma separated list of GPU(s) to use.'
)
# nargs='*' in multi mode
...
...
examples/load_vgg16.py
View file @
60e52b94
...
...
@@ -18,6 +18,7 @@ from tensorpack.tfutils.symbolic_functions import *
from
tensorpack.tfutils.summary
import
*
from
tensorpack.callbacks
import
*
from
tensorpack.dataflow
import
*
from
tensorpack.dataflow.dataset
import
ILSVRCMeta
class
Model
(
ModelDesc
):
def
_get_input_vars
(
self
):
...
...
@@ -104,12 +105,15 @@ def run_test(path, input):
im
=
cv2
.
cvtColor
(
im
,
cv2
.
COLOR_BGR2RGB
)
im
=
cv2
.
resize
(
im
,
(
224
,
224
))
im
=
np
.
reshape
(
im
,
(
1
,
224
,
224
,
3
))
.
astype
(
'float32'
)
im
=
im
-
110
outputs
=
predict_func
([
im
])[
0
]
prob
=
outputs
[
0
]
print
prob
.
shape
ret
=
prob
.
argsort
()[
-
10
:][::
-
1
]
print
ret
meta
=
ILSVRCMeta
()
.
get_synset_words_1000
()
print
[
meta
[
k
]
for
k
in
ret
]
if
__name__
==
'__main__'
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--gpu'
,
default
=
'0'
,
...
...
tensorpack/dataflow/dataset/.gitignore
View file @
60e52b94
mnist_data
cifar10_data
svhn_data
ilsvrc_metadata
tensorpack/dataflow/dataset/cifar10.py
View file @
60e52b94
...
...
@@ -13,6 +13,7 @@ import tarfile
import
logging
from
...utils
import
logger
,
get_rng
from
...utils.fs
import
download
from
..base
import
DataFlow
__all__
=
[
'Cifar10'
]
...
...
@@ -23,22 +24,13 @@ DATA_URL = 'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
def
maybe_download_and_extract
(
dest_directory
):
"""Download and extract the tarball from Alex's website.
copied from tensorflow example """
if
not
os
.
path
.
exists
(
dest_directory
):
os
.
makedirs
(
dest_directory
)
filename
=
DATA_URL
.
split
(
'/'
)[
-
1
]
filepath
=
os
.
path
.
join
(
dest_directory
,
filename
)
if
os
.
path
.
isdir
(
os
.
path
.
join
(
dest_directory
,
'cifar-10-batches-py'
)):
logger
.
info
(
"Found cifar10 data in {}."
.
format
(
dest_directory
))
return
else
:
def
_progress
(
count
,
block_size
,
total_size
):
sys
.
stdout
.
write
(
'
\r
>> Downloading
%
s
%.1
f
%%
'
%
(
filepath
,
float
(
count
*
block_size
)
/
float
(
total_size
)
*
100.0
))
sys
.
stdout
.
flush
()
filepath
,
_
=
urllib
.
request
.
urlretrieve
(
DATA_URL
,
filepath
,
reporthook
=
_progress
)
print
()
statinfo
=
os
.
stat
(
filepath
)
print
(
'Succesfully downloaded'
,
filename
,
statinfo
.
st_size
,
'bytes.'
)
download
(
URL
,
dest_directory
)
filename
=
DATA_URL
.
split
(
'/'
)[
-
1
]
filepath
=
os
.
path
.
join
(
dest_directory
,
filename
)
tarfile
.
open
(
filepath
,
'r:gz'
)
.
extractall
(
dest_directory
)
def
read_cifar10
(
filenames
):
...
...
tensorpack/dataflow/dataset/ilsvrc.py
0 → 100644
View file @
60e52b94
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# File: ilsvrc.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import
os
import
tarfile
from
...utils.fs
import
mkdir_p
,
download
__all__
=
[
'ILSVRCMeta'
]
CAFFE_URL
=
"http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz"
class
ILSVRCMeta
(
object
):
def
__init__
(
self
,
dir
=
None
):
if
dir
is
None
:
dir
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'ilsvrc_metadata'
)
self
.
dir
=
dir
mkdir_p
(
self
.
dir
)
def
get_synset_words_1000
(
self
):
fname
=
os
.
path
.
join
(
self
.
dir
,
'synset_words.txt'
)
if
not
os
.
path
.
isfile
(
fname
):
self
.
download_caffe_meta
()
assert
os
.
path
.
isfile
(
fname
)
lines
=
[
x
.
strip
()
for
x
in
open
(
fname
)
.
readlines
()]
return
dict
(
enumerate
(
lines
))
def
download_caffe_meta
(
self
):
fpath
=
download
(
CAFFE_URL
,
self
.
dir
)
tarfile
.
open
(
fpath
,
'r:gz'
)
.
extractall
(
self
.
dir
)
if
__name__
==
'__main__'
:
meta
=
ILSVRCMeta
()
print
meta
.
get_synset_words_1000
()
tensorpack/dataflow/dataset/mnist.py
View file @
60e52b94
...
...
@@ -10,6 +10,7 @@ import numpy
from
six.moves
import
urllib
,
range
from
...utils
import
logger
from
...utils.fs
import
download
from
..base
import
DataFlow
__all__
=
[
'Mnist'
]
...
...
@@ -20,14 +21,10 @@ SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'
def
maybe_download
(
filename
,
work_directory
):
"""Download the data from Yann's website, unless it's already here."""
if
not
os
.
path
.
exists
(
work_directory
):
os
.
mkdir
(
work_directory
)
filepath
=
os
.
path
.
join
(
work_directory
,
filename
)
if
not
os
.
path
.
exists
(
filepath
):
logger
.
info
(
"Downloading mnist data to {}..."
.
format
(
filepath
))
filepath
,
_
=
urllib
.
request
.
urlretrieve
(
SOURCE_URL
+
filename
,
filepath
)
statinfo
=
os
.
stat
(
filepath
)
logger
.
info
(
'Successfully downloaded to '
+
filename
)
download
(
SOURCE_URL
+
filename
,
work_directory
)
return
filepath
def
_read32
(
bytestream
):
...
...
tensorpack/dataflow/dataset/svhn.py
View file @
60e52b94
...
...
@@ -31,7 +31,7 @@ class SVHNDigit(DataFlow):
assert
name
in
[
'train'
,
'test'
,
'extra'
],
name
filename
=
os
.
path
.
join
(
data_dir
,
name
+
'_32x32.mat'
)
assert
os
.
path
.
isfile
(
filename
),
\
"File {} not found!
D
ownload it from
\
"File {} not found!
Please d
ownload it from
\
http://ufldl.stanford.edu/housenumbers/"
.
format
(
filename
)
logger
.
info
(
"Loading {} ..."
.
format
(
filename
))
data
=
scipy
.
io
.
loadmat
(
filename
)
...
...
tensorpack/utils/fs.py
View file @
60e52b94
...
...
@@ -3,7 +3,8 @@
# File: fs.py
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import
os
import
os
,
sys
from
six.moves
import
urllib
def
mkdir_p
(
dirname
):
assert
dirname
is
not
None
...
...
@@ -15,3 +16,21 @@ def mkdir_p(dirname):
if
e
.
errno
!=
17
:
raise
e
def
download
(
url
,
dir
):
mkdir_p
(
dir
)
fname
=
url
.
split
(
'/'
)[
-
1
]
fpath
=
os
.
path
.
join
(
dir
,
fname
)
def
_progress
(
count
,
block_size
,
total_size
):
sys
.
stdout
.
write
(
'
\r
>> Downloading
%
s
%.1
f
%%
'
%
(
fname
,
float
(
count
*
block_size
)
/
float
(
total_size
)
*
100.0
))
sys
.
stdout
.
flush
()
fpath
,
_
=
urllib
.
request
.
urlretrieve
(
url
,
fpath
,
reporthook
=
_progress
)
statinfo
=
os
.
stat
(
fpath
)
sys
.
stdout
.
write
(
'
\n
'
)
print
(
'Succesfully downloaded '
+
fname
+
" "
+
str
(
statinfo
.
st_size
)
+
' bytes.'
)
return
fpath
if
__name__
==
'__main__'
:
download
(
'http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz'
,
'.'
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment