Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
d5d7270a
Commit
d5d7270a
authored
May 24, 2016
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
some further simplification
parent
2d96baca
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
17 additions
and
24 deletions
+17
-24
examples/cifar-convnet.py
examples/cifar-convnet.py
+6
-9
tensorpack/dataflow/dataset/cifar.py
tensorpack/dataflow/dataset/cifar.py
+11
-14
tensorpack/train/trainer.py
tensorpack/train/trainer.py
+0
-1
No files found.
examples/cifar-convnet.py
100644 → 100755
View file @
d5d7270a
#!/usr/bin/env python
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# -*- coding: UTF-8 -*-
# File: cifar
10
-convnet.py
# File: cifar-convnet.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import
numpy
import
numpy
import
tensorflow
as
tf
import
tensorflow
as
tf
...
@@ -19,7 +19,8 @@ from tensorpack.dataflow import *
...
@@ -19,7 +19,8 @@ from tensorpack.dataflow import *
"""
"""
A small convnet model for cifar 10 or cifar100 dataset.
A small convnet model for cifar 10 or cifar100 dataset.
90
%
validation accuracy after 40k step.
For Cifar10: 90
%
validation accuracy after 40k step.
"""
"""
class
Model
(
ModelDesc
):
class
Model
(
ModelDesc
):
...
@@ -141,7 +142,8 @@ if __name__ == '__main__':
...
@@ -141,7 +142,8 @@ if __name__ == '__main__':
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'--gpu'
,
help
=
'comma separated list of GPU(s) to use.'
)
# nargs='*' in multi mode
parser
.
add_argument
(
'--gpu'
,
help
=
'comma separated list of GPU(s) to use.'
)
# nargs='*' in multi mode
parser
.
add_argument
(
'--load'
,
help
=
'load model'
)
parser
.
add_argument
(
'--load'
,
help
=
'load model'
)
parser
.
add_argument
(
'--classnum'
,
help
=
'specify cifar10 or cifar100, input 10 for cifar10 or 100 for cifar100'
)
parser
.
add_argument
(
'--classnum'
,
help
=
'10 for cifar10 or 100 for cifar100'
,
type
=
int
,
default
=
10
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
basename
=
os
.
path
.
basename
(
__file__
)
basename
=
os
.
path
.
basename
(
__file__
)
...
@@ -153,13 +155,8 @@ if __name__ == '__main__':
...
@@ -153,13 +155,8 @@ if __name__ == '__main__':
else
:
else
:
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
'0'
os
.
environ
[
'CUDA_VISIBLE_DEVICES'
]
=
'0'
if
args
.
classnum
:
cifar_classnum
=
int
(
args
.
classnum
)
else
:
cifar_classnum
=
10
with
tf
.
Graph
()
.
as_default
():
with
tf
.
Graph
()
.
as_default
():
config
=
get_config
(
cifar_
classnum
)
config
=
get_config
(
args
.
classnum
)
if
args
.
load
:
if
args
.
load
:
config
.
session_init
=
SaverRestore
(
args
.
load
)
config
.
session_init
=
SaverRestore
(
args
.
load
)
if
args
.
gpu
:
if
args
.
gpu
:
...
...
tensorpack/dataflow/dataset/cifar.py
View file @
d5d7270a
#!/usr/bin/env python2
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# -*- coding: UTF-8 -*-
# File: cifar
10
.py
# File: cifar.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
# Yukun Chen <cykustc@gmail.com>
import
os
,
sys
import
os
,
sys
import
pickle
import
pickle
import
numpy
as
np
import
numpy
as
np
...
@@ -19,7 +21,7 @@ __all__ = ['Cifar10', 'Cifar100']
...
@@ -19,7 +21,7 @@ __all__ = ['Cifar10', 'Cifar100']
DATA_URL_CIFAR_10
=
'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
DATA_URL_CIFAR_10
=
'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
DATA_URL_CIFAR_100
=
'http
s
://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz'
DATA_URL_CIFAR_100
=
'http://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz'
def
maybe_download_and_extract
(
dest_directory
,
cifar_classnum
):
def
maybe_download_and_extract
(
dest_directory
,
cifar_classnum
):
"""Download and extract the tarball from Alex's website.
"""Download and extract the tarball from Alex's website.
...
@@ -52,7 +54,7 @@ def read_cifar(filenames, cifar_classnum):
...
@@ -52,7 +54,7 @@ def read_cifar(filenames, cifar_classnum):
data
=
dic
[
b
'data'
]
data
=
dic
[
b
'data'
]
if
cifar_classnum
==
10
:
if
cifar_classnum
==
10
:
label
=
dic
[
b
'labels'
]
label
=
dic
[
b
'labels'
]
IMG_NUM
=
10000
IMG_NUM
=
10000
# cifar10 data are split into blocks of 10000
elif
cifar_classnum
==
100
:
elif
cifar_classnum
==
100
:
label
=
dic
[
b
'fine_labels'
]
label
=
dic
[
b
'fine_labels'
]
IMG_NUM
=
50000
if
'train'
in
fname
else
10000
IMG_NUM
=
50000
if
'train'
in
fname
else
10000
...
@@ -71,10 +73,8 @@ def get_filenames(dir, cifar_classnum):
...
@@ -71,10 +73,8 @@ def get_filenames(dir, cifar_classnum):
filenames
.
append
(
os
.
path
.
join
(
filenames
.
append
(
os
.
path
.
join
(
dir
,
'cifar-10-batches-py'
,
'test_batch'
))
dir
,
'cifar-10-batches-py'
,
'test_batch'
))
elif
cifar_classnum
==
100
:
elif
cifar_classnum
==
100
:
filenames
=
[
os
.
path
.
join
(
filenames
=
[
os
.
path
.
join
(
dir
,
'cifar-100-python'
,
'train'
),
dir
,
'cifar-100-python'
,
'train'
)]
os
.
path
.
join
(
dir
,
'cifar-100-python'
,
'test'
)]
filenames
.
append
(
os
.
path
.
join
(
dir
,
'cifar-100-python'
,
'test'
))
return
filenames
return
filenames
class
CifarBase
(
DataFlow
):
class
CifarBase
(
DataFlow
):
...
@@ -92,15 +92,12 @@ class CifarBase(DataFlow):
...
@@ -92,15 +92,12 @@ class CifarBase(DataFlow):
assert
cifar_classnum
==
10
or
cifar_classnum
==
100
assert
cifar_classnum
==
10
or
cifar_classnum
==
100
self
.
cifar_classnum
=
cifar_classnum
self
.
cifar_classnum
=
cifar_classnum
if
dir
is
None
:
if
dir
is
None
:
dir
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'cifar-10-batches-py'
dir
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
if
cifar_classnum
==
10
else
'cifar100_data'
)
'cifar{}_data'
.
format
(
cifar_classnum
)
)
maybe_download_and_extract
(
dir
,
self
.
cifar_classnum
)
maybe_download_and_extract
(
dir
,
self
.
cifar_classnum
)
if
self
.
cifar_classnum
==
10
:
fnames
=
get_filenames
(
dir
,
cifar_classnum
)
fnames
=
get_filenames
(
dir
,
10
)
else
:
fnames
=
get_filenames
(
dir
,
100
)
if
train_or_test
==
'train'
:
if
train_or_test
==
'train'
:
self
.
fs
=
fnames
[:
5
]
if
cifar_classnum
==
10
else
fnames
[:
1
]
self
.
fs
=
fnames
[:
-
1
]
else
:
else
:
self
.
fs
=
[
fnames
[
-
1
]]
self
.
fs
=
[
fnames
[
-
1
]]
for
f
in
self
.
fs
:
for
f
in
self
.
fs
:
...
...
tensorpack/train/trainer.py
View file @
d5d7270a
...
@@ -125,7 +125,6 @@ class QueueInputTrainer(Trainer):
...
@@ -125,7 +125,6 @@ class QueueInputTrainer(Trainer):
def
_get_model_inputs
(
self
):
def
_get_model_inputs
(
self
):
""" Dequeue a datapoint from input_queue and return"""
""" Dequeue a datapoint from input_queue and return"""
ret
=
self
.
input_queue
.
dequeue
(
name
=
'input_deque'
)
ret
=
self
.
input_queue
.
dequeue
(
name
=
'input_deque'
)
print
ret
if
isinstance
(
ret
,
tf
.
Tensor
):
# only one input
if
isinstance
(
ret
,
tf
.
Tensor
):
# only one input
ret
=
[
ret
]
ret
=
[
ret
]
assert
len
(
ret
)
==
len
(
self
.
input_vars
)
assert
len
(
ret
)
==
len
(
self
.
input_vars
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment