Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
5102a8f3
Commit
5102a8f3
authored
Dec 27, 2015
by
ppwwyyxx
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
type check, dataflow base
parent
63e3a42f
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
126 additions
and
99 deletions
+126
-99
dataflow/base.py
dataflow/base.py
+23
-0
dataflow/batch.py
dataflow/batch.py
+2
-1
dataflow/dataset/mnist.py
dataflow/dataset/mnist.py
+85
-89
train.py
train.py
+4
-1
utils/callback.py
utils/callback.py
+3
-0
utils/logger.py
utils/logger.py
+9
-8
No files found.
dataflow/base.py
0 → 100644
View file @
5102a8f3
#!/usr/bin/env python2
# -*- coding: UTF-8 -*-
# File: base.py
# Author: Yuxin Wu <ppwwyyxx@gmail.com>

from abc import ABCMeta, abstractmethod

__all__ = ['DataFlow']


class DataFlow(object):
    """Base class of a data flow: an iterable source of datapoints.

    Subclasses must implement :meth:`get_data`, and should implement
    :meth:`size` when the number of datapoints is known.
    """
    # Without an ABC metaclass, @abstractmethod is purely decorative on
    # Python 2: an incomplete subclass would instantiate silently. This
    # Py2-style declaration makes the abstractness actually enforced.
    __metaclass__ = ABCMeta

    @abstractmethod
    def get_data(self):
        """
        A generator to generate data as tuple.

        Yields:
            tuple: the components of one datapoint.
        """

    @abstractmethod
    def size(self):
        """
        Size of this data flow.

        Returns:
            int: the number of datapoints this flow produces.
        """
dataflow/batch.py
View file @
5102a8f3
...
...
@@ -4,10 +4,11 @@
# Author: Yuxin Wu <ppwwyyxx@gmail.com>
import
numpy
as
np
from
.base
import
DataFlow
__all__
=
[
'BatchData'
]
class
BatchData
(
object
):
class
BatchData
(
DataFlow
):
def
__init__
(
self
,
ds
,
batch_size
,
remainder
=
False
):
"""
Args:
...
...
dataflow/dataset/mnist.py
View file @
5102a8f3
...
...
@@ -8,9 +8,10 @@ import gzip
import
numpy
from
six.moves
import
urllib
from
utils
import
logger
from
..base
import
DataFlow
__all__
=
[
'Mnist'
]
SOURCE_URL
=
'http://yann.lecun.com/exdb/mnist/'
...
...
@@ -28,97 +29,71 @@ def maybe_download(filename, work_directory):
return
filepath
def
_read32
(
bytestream
):
dt
=
numpy
.
dtype
(
numpy
.
uint32
)
.
newbyteorder
(
'>'
)
return
numpy
.
frombuffer
(
bytestream
.
read
(
4
),
dtype
=
dt
)[
0
]
dt
=
numpy
.
dtype
(
numpy
.
uint32
)
.
newbyteorder
(
'>'
)
return
numpy
.
frombuffer
(
bytestream
.
read
(
4
),
dtype
=
dt
)[
0
]
def extract_images(filename):
    """Extract the images into a 4D uint8 numpy array [index, y, x, depth]."""
    with gzip.open(filename) as stream:
        # Header: magic, image count, rows, cols -- all big-endian uint32.
        magic = _read32(stream)
        if magic != 2051:
            raise ValueError(
                'Invalid magic number %d in MNIST image file: %s' %
                (magic, filename))
        count = _read32(stream)
        height = _read32(stream)
        width = _read32(stream)
        # Remainder of the archive is the raw pixel payload.
        raw = stream.read(height * width * count)
        pixels = numpy.frombuffer(raw, dtype=numpy.uint8)
        # Trailing singleton axis keeps the [index, y, x, depth] layout.
        return pixels.reshape(count, height, width, 1)
def extract_labels(filename):
    """Extract the labels into a 1D uint8 numpy array [index]."""
    with gzip.open(filename) as stream:
        # Header: magic then item count, both big-endian uint32.
        magic = _read32(stream)
        if magic != 2049:
            raise ValueError(
                'Invalid magic number %d in MNIST label file: %s' %
                (magic, filename))
        count = _read32(stream)
        # One byte per label follows the header.
        raw = stream.read(count)
        return numpy.frombuffer(raw, dtype=numpy.uint8)
class DataSet(object):
    """One MNIST split: images flattened and scaled to [0, 1], plus labels."""

    def __init__(self, images, labels, fake_data=False):
        """Construct a DataSet. """
        assert images.shape[0] == labels.shape[0], (
            'images.shape: %s labels.shape: %s' % (images.shape, labels.shape))
        self._num_examples = images.shape[0]
        # Convert shape from [num examples, rows, columns, depth]
        # to [num examples, rows*columns] (assuming depth == 1)
        assert images.shape[3] == 1
        num, rows, cols = images.shape[0], images.shape[1], images.shape[2]
        flat = images.reshape(num, rows * cols)
        # Convert from [0, 255] -> [0.0, 1.0].
        flat = flat.astype(numpy.float32)
        self._images = numpy.multiply(flat, 1.0 / 255.0)
        self._labels = labels

    @property
    def images(self):
        # Flattened float32 images, shape [num_examples, rows*cols].
        return self._images

    @property
    def labels(self):
        # Labels as provided to the constructor, shape [num_examples].
        return self._labels

    @property
    def num_examples(self):
        # Number of (image, label) pairs in this split.
        return self._num_examples
def read_data_sets(train_dir):
    """Download (if necessary) the four MNIST archives into *train_dir* and
    parse them.

    Returns:
        an object with ``train`` and ``test`` attributes, each a DataSet.
    """
    class DataSets(object):
        pass

    def _fetch(extract, basename):
        # Download the archive into train_dir when missing, then parse it.
        return extract(maybe_download(basename, train_dir))

    data_sets = DataSets()
    data_sets.train = DataSet(
        _fetch(extract_images, 'train-images-idx3-ubyte.gz'),
        _fetch(extract_labels, 'train-labels-idx1-ubyte.gz'))
    data_sets.test = DataSet(
        _fetch(extract_images, 't10k-images-idx3-ubyte.gz'),
        _fetch(extract_labels, 't10k-labels-idx1-ubyte.gz'))
    return data_sets
class
Mnist
(
object
):
def
__init__
(
self
,
images
,
labels
,
fake_data
=
False
):
"""Construct a DataSet. """
assert
images
.
shape
[
0
]
==
labels
.
shape
[
0
],
(
'images.shape:
%
s labels.shape:
%
s'
%
(
images
.
shape
,
labels
.
shape
))
self
.
_num_examples
=
images
.
shape
[
0
]
# Convert shape from [num examples, rows, columns, depth]
# to [num examples, rows*columns] (assuming depth == 1)
assert
images
.
shape
[
3
]
==
1
images
=
images
.
reshape
(
images
.
shape
[
0
],
images
.
shape
[
1
]
*
images
.
shape
[
2
])
# Convert from [0, 255] -> [0.0, 1.0].
images
=
images
.
astype
(
numpy
.
float32
)
images
=
numpy
.
multiply
(
images
,
1.0
/
255.0
)
self
.
_images
=
images
self
.
_labels
=
labels
@
property
def
images
(
self
):
return
self
.
_images
@
property
def
labels
(
self
):
return
self
.
_labels
@
property
def
num_examples
(
self
):
return
self
.
_num_examples
class
Mnist
(
DataFlow
):
def
__init__
(
self
,
train_or_test
,
dir
=
None
):
"""
Args:
...
...
@@ -126,15 +101,35 @@ class Mnist(object):
"""
if
dir
is
None
:
dir
=
os
.
path
.
join
(
os
.
path
.
dirname
(
__file__
),
'mnist_data'
)
self
.
dataset
=
read_data_sets
(
dir
)
assert
train_or_test
in
[
'train'
,
'test'
]
self
.
train_or_test
=
train_or_test
TRAIN_IMAGES
=
'train-images-idx3-ubyte.gz'
TRAIN_LABELS
=
'train-labels-idx1-ubyte.gz'
TEST_IMAGES
=
't10k-images-idx3-ubyte.gz'
TEST_LABELS
=
't10k-labels-idx1-ubyte.gz'
local_file
=
maybe_download
(
TRAIN_IMAGES
,
dir
)
train_images
=
extract_images
(
local_file
)
local_file
=
maybe_download
(
TRAIN_LABELS
,
dir
)
train_labels
=
extract_labels
(
local_file
)
local_file
=
maybe_download
(
TEST_IMAGES
,
dir
)
test_images
=
extract_images
(
local_file
)
local_file
=
maybe_download
(
TEST_LABELS
,
dir
)
test_labels
=
extract_labels
(
local_file
)
self
.
train
=
DataSet
(
train_images
,
train_labels
)
self
.
test
=
DataSet
(
test_images
,
test_labels
)
def
size
(
self
):
ds
=
self
.
dataset
.
train
if
self
.
train_or_test
==
'train'
else
self
.
dataset
.
test
ds
=
self
.
train
if
self
.
train_or_test
==
'train'
else
self
.
test
return
ds
.
num_examples
def
get_data
(
self
):
ds
=
self
.
dataset
.
train
if
self
.
train_or_test
==
'train'
else
self
.
dataset
.
test
ds
=
self
.
train
if
self
.
train_or_test
==
'train'
else
self
.
test
for
k
in
xrange
(
ds
.
num_examples
):
img
=
ds
.
images
[
k
]
.
reshape
((
28
,
28
))
label
=
ds
.
labels
[
k
]
...
...
@@ -144,4 +139,5 @@ if __name__ == '__main__':
ds
=
Mnist
(
'train'
)
for
(
img
,
label
)
in
ds
.
get_data
():
from
IPython
import
embed
;
embed
()
break
train.py
View file @
5102a8f3
...
...
@@ -5,6 +5,7 @@
import
tensorflow
as
tf
from
utils
import
*
from
dataflow
import
DataFlow
from
itertools
import
count
def
prepare
():
...
...
@@ -20,17 +21,19 @@ def start_train(config):
Args:
config: a tensorpack config dictionary
"""
# a Dataflow instance
dataset_train
=
config
[
'dataset_train'
]
assert
isinstance
(
dataset_train
,
DataFlow
),
dataset_train
.
__class__
# a tf.train.Optimizer instance
optimizer
=
config
[
'optimizer'
]
assert
isinstance
(
optimizer
,
tf
.
train
.
Optimizer
),
optimizer
.
__class__
# a list of Callback instance
callbacks
=
Callbacks
(
config
.
get
(
'callbacks'
,
[]))
# a tf.ConfigProto instance
sess_config
=
config
.
get
(
'session_config'
,
None
)
assert
isinstance
(
sess_config
,
tf
.
ConfigProto
),
sess_config
.
__class__
# a list of input/output variables
input_vars
=
config
[
'inputs'
]
...
...
utils/callback.py
View file @
5102a8f3
...
...
@@ -89,6 +89,9 @@ class SummaryWriter(Callback):
class
Callbacks
(
Callback
):
def
__init__
(
self
,
callbacks
):
for
cb
in
callbacks
:
assert
isinstance
(
cb
,
Callback
),
cb
.
__class__
# put SummaryWriter to the first
for
idx
,
cb
in
enumerate
(
callbacks
):
if
type
(
cb
)
==
SummaryWriter
:
...
...
utils/logger.py
View file @
5102a8f3
...
...
@@ -5,6 +5,7 @@
import
logging
import
os
import
os.path
from
termcolor
import
colored
__all__
=
[]
...
...
@@ -38,14 +39,14 @@ for func in ['info', 'warning', 'error', 'critical', 'warn']:
def set_file(path):
    """Send the module logger's output to the file *path*.

    An existing file at *path* is first moved aside to a timestamped
    backup, and the parent directory is created when missing, before a
    fresh FileHandler is attached.
    """
    if os.path.isfile(path):
        from datetime import datetime
        backup_name = path + datetime.now().strftime('.%d-%H%M%S')
        import shutil
        shutil.move(path, backup_name)
        info("Log file '{}' backuped to '{}'".format(path, backup_name))
    dirname = os.path.dirname(path)
    # Guard against a bare filename: os.path.dirname('x.log') == '' and
    # os.makedirs('') raises OSError, so only create a real directory.
    if dirname and not os.path.isdir(dirname):
        os.makedirs(dirname)
    hdl = logging.FileHandler(
        filename=path, encoding='utf-8', mode='w')
    logger.addHandler(hdl)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment