Shashank Suhas / seminar-breakout / Commits

Commit bbaf8d12, parent d97c0081
authored Feb 11, 2017 by Yuxin Wu

    use Variable.load to avoid assign ops

Showing 6 changed files with 17 additions and 24 deletions (+17 -24)
docs/tutorial/efficient-dataflow.md      +6  -5
examples/GAN/DCGAN-CelebA.py             +2  -2
examples/ResNet/imagenet-resnet.py       +0  -1
tensorpack/callbacks/param.py            +1  -5
tensorpack/dataflow/dataset/ilsvrc.py    +2  -1
tensorpack/tfutils/varmanip.py           +6  -10
docs/tutorial/efficient-dataflow.md

 # Efficient DataFlow
 This tutorial gives an overview of how to build an efficient DataFlow, using ImageNet
 dataset as an example.
 Our goal in the end is to have
 a generator which yields ImageNet datapoints (after proper preprocessing) as fast as possible.
+Since it is simply a generator interface, you can use the DataFlow in other frameworks (e.g. Keras)
+or your own code as well.
 We use ILSVRC12 training set, which contains 1.28 million images.
 Following the [ResNet example](../examples/ResNet), our pre-processing need images in their original resolution,
 ...
@@ -120,7 +121,7 @@ It will generate a database file of 140G. We build a DataFlow to read the LMDB f
 ```
 from tensorpack import *
 ds = LMDBData('/path/to/ILSVRC-train.lmdb', shuffle=False)
-ds = BatchData(ds, 256, allow_list=True)
+ds = BatchData(ds, 256, use_list=True)
 TestDataSpeed(ds).start_test()
 ```
 Depending on whether the OS has cached the file for you (and how large the RAM is), the above script
 ...
@@ -134,7 +135,7 @@ As a reference, on Samsung SSD 850, the uncached speed is about 16it/s.
 ds = LMDBData('/path/to/ILSVRC-train.lmdb', shuffle=False)
 ds = LocallyShuffleData(ds, 50000)
-ds = BatchData(ds, 256, allow_list=True)
+ds = BatchData(ds, 256, use_list=True)
 ```
 Instead of shuffling all the training data in every epoch (which would require random read),
 the added line above maintains a buffer of datapoints and shuffle them once a while.
 ...
@@ -153,7 +154,7 @@ Then we add necessary transformations:
 ds = AugmentImageComponent(ds, lots_of_augmentors)
 ds = BatchData(ds, 256)
 ```
-1. `LMDBData` deserialized the datapoints (from string to [jpeg_string, label])
+1. `LMDBData` deserialize the datapoints (from string to [jpeg_string, label])
 2. Use opencv to decode the first component into ndarray
 3. Apply augmentations to the ndarray
 ...
@@ -172,7 +173,7 @@ Both imdecode and the augmentors can be quite slow. We can parallelize them like
 ds = BatchData(ds, 256)
 ```
-Since we are reading the database sequentially, have multiple identical instances of the
+Since we are reading the database sequentially, having multiple identical instances of the
 underlying DataFlow will result in biased data distribution. Therefore we use `PrefetchData` to
 launch the underlying DataFlow in one independent process, and only parallelize the transformations.
 (`PrefetchDataZMQ` is faster but not fork-safe, so the first prefetch has to be `PrefetchData`. This is [issue#138])
 ...
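The `LocallyShuffleData` step in the hunk above replaces a full epoch shuffle with a bounded buffer that is shuffled incrementally. A minimal pure-Python sketch of that idea (an illustration of the buffering scheme, not tensorpack's implementation):

```python
import random

def locally_shuffle(iterable, buffer_size, rng=None):
    """Yield items in a locally-shuffled order using a bounded buffer.

    Sketch of the idea behind LocallyShuffleData: sequential reads go into
    a buffer, and a random buffered item is emitted for each new arrival.
    """
    rng = rng or random.Random(0)
    buf = []
    for x in iterable:
        buf.append(x)
        if len(buf) >= buffer_size:
            # emit one random buffered item instead of reshuffling everything
            i = rng.randrange(len(buf))
            yield buf.pop(i)
    rng.shuffle(buf)  # drain the remainder in random order
    yield from buf

out = list(locally_shuffle(range(10), buffer_size=4))
```

The trade-off is that items can only move within roughly `buffer_size` positions of their original order, which is why a large buffer (50000 above) is used.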
examples/GAN/DCGAN-CelebA.py

@@ -20,9 +20,9 @@ from GAN import GANTrainer, RandomZData, GANModelDesc
 """
 DCGAN on CelebA dataset.
-The original code (dcgan.torch) uses kernel_shape=4, but I found the difference not significant.
-1. Download the 'aligned&cropped' version of CelebA dataset.
+1. Download the 'aligned&cropped' version of CelebA dataset
+(or just use any directory of jpg files).
 2. Start training:
     ./DCGAN-CelebA.py --data /path/to/image_align_celeba/
 3. Visualize samples of a trained model:
 ...
examples/ResNet/imagenet-resnet.py
old mode 100644, new mode 100755

@@ -132,7 +132,6 @@ def get_data(train_or_test):
 ...
     crop 8%~100% of the original image
     See `Going Deeper with Convolutions` by Google.
     """
     def _augment(self, img, _):
         h, w = img.shape[:2]
         area = h * w
 ...
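The `_augment` context above computes the image area for the GoogleNet-style crop (8%~100% of the original area). A hedged sketch of that sampling scheme — the function name and the aspect-ratio range are my assumptions, not the example's exact code:

```python
import random

def sample_crop(h, w, rng=None):
    """Sample a crop size covering 8%~100% of the image area
    (Inception-style random crop; a sketch, not the example's exact logic)."""
    rng = rng or random.Random(0)
    area = h * w
    target_area = rng.uniform(0.08, 1.0) * area   # the 8%~100% range
    aspect = rng.uniform(3.0 / 4, 4.0 / 3)        # assumed aspect-ratio jitter
    ch = int(round((target_area / aspect) ** 0.5))
    cw = int(round((target_area * aspect) ** 0.5))
    # clamp to the image in case the sampled rectangle does not fit
    return min(ch, h), min(cw, w)

ch, cw = sample_crop(224, 224)
```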
tensorpack/callbacks/param.py

@@ -74,13 +74,9 @@ class GraphVarParam(HyperParam):
         else:
             raise ValueError("{} is not a VARIABLE in the graph!".format(self.var_name))

-        self.val_holder = tf.placeholder(tf.float32, shape=self.shape,
-                                         name=self._readable_name + '_feed')
-        self.assign_op = self.var.assign(self.val_holder)
-
     def set_value(self, v):
         """ Assign the variable a new value. """
-        self.assign_op.eval(feed_dict={self.val_holder: v})
+        self.var.load(v)

     def get_value(self):
         """ Evaluate the variable. """
 ...
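The change above drops the placeholder-plus-assign-op plumbing entirely: `Variable.load` writes a value into a variable without any set-value nodes living in the graph. A toy pure-Python model (not TensorFlow) of why that matters in a TF1-style static graph:

```python
# Toy model of a static graph: `assign` constructs a graph node,
# `load` writes the value directly and leaves the graph untouched.
class Graph:
    def __init__(self):
        self.ops = []

class Variable:
    def __init__(self, graph, value=None):
        self.graph = graph
        self.value = value

    def assign(self, placeholder):
        # TF1-style: every assign() call adds an op node to the graph
        op = ('assign', placeholder)
        self.graph.ops.append(op)
        return op

    def load(self, v):
        # direct write; no graph node is created
        self.value = v

g = Graph()
variables = [Variable(g) for _ in range(10)]

# old approach: one assign op per variable clutters the graph
for var in variables:
    var.assign('placeholder')

# new approach: direct load, graph unchanged
for i, var in enumerate(variables):
    var.load(i)
```

With many hyperparameters or checkpoint restores, the assign-op approach leaves a trail of dead nodes in the graph; `load` avoids that bookkeeping.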
tensorpack/dataflow/dataset/ilsvrc.py

@@ -146,7 +146,8 @@ class ILSVRC12(RNGDataFlow):
             mkdir train && tar xvf ILSVRC12_img_train.tar -C train && cd train
             find -type f -name '*.tar' | parallel -P 10 'echo {} && mkdir -p {/.} && tar xf {} -C {/.}'
         """
-        assert name in ['train', 'test', 'val']
+        assert name in ['train', 'test', 'val'], name
+        assert os.path.isdir(dir), dir
         self.full_dir = os.path.join(dir, name)
         self.name = name
         assert os.path.isdir(self.full_dir), self.full_dir
 ...
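The hunk above only adds messages to the asserts (and checks `dir` earlier); the payoff is that a failure now reports the offending value instead of a bare `AssertionError`. A small illustration (function name is mine, not the library's):

```python
import os

def check_split_and_dir(name, dir):
    # The expression after the comma becomes the AssertionError message,
    # so a typo like 'vall' or a bad path is visible in the traceback.
    assert name in ['train', 'test', 'val'], name
    assert os.path.isdir(dir), dir

try:
    check_split_and_dir('vall', os.getcwd())
except AssertionError as e:
    msg = str(e)  # 'vall' — the bad value, not an empty message
```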
tensorpack/tfutils/varmanip.py

@@ -54,14 +54,10 @@ class SessionUpdate(object):
             vars_to_update: a collection of variables to update
         """
         self.sess = sess
-        self.assign_ops = defaultdict(list)
+        self.name_map = defaultdict(list)
         for v in vars_to_update:
-            # p = tf.placeholder(v.dtype, shape=v.get_shape())
-            with tf.device('/cpu:0'):
-                p = tf.placeholder(v.dtype)
-                savename = get_savename_from_varname(v.name)
-                # multiple vars might share one savename
-                self.assign_ops[savename].append((p, v, v.assign(p)))
+            savename = get_savename_from_varname(v.name)
+            self.name_map[savename].append(v)

     def update(self, prms):
         """
 ...
@@ -70,8 +66,8 @@ class SessionUpdate(object):
         Any name in prms must be in the graph and in vars_to_update.
         """
         for name, value in six.iteritems(prms):
-            assert name in self.assign_ops
-            for p, v, op in self.assign_ops[name]:
+            assert name in self.name_map
+            for v in self.name_map[name]:
                 varshape = tuple(v.get_shape().as_list())
                 if varshape != value.shape:
                     # TODO only allow reshape when shape different by empty axis
 ...
@@ -79,7 +75,7 @@ class SessionUpdate(object):
                         "{}: {}!={}".format(name, varshape, value.shape)
                     logger.warn("Param {} is reshaped during assigning".format(name))
                     value = value.reshape(varshape)
-                self.sess.run(op, feed_dict={p: value})
+                v.load(value, session=self.sess)

 def dump_session_params(path):
 ...
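The new `name_map` groups every variable that shares one savename (the old code's comment notes that "multiple vars might share one savename", e.g. per-tower replicas of the same weight), so a single checkpoint value can be loaded into each replica. A sketch with a hypothetical stand-in for `get_savename_from_varname` that strips a `towerN/` prefix — the real function's rules may differ:

```python
from collections import defaultdict
import re

def get_savename(varname):
    # Hypothetical stand-in: drop a leading 'towerN/' scope, keep the rest.
    return re.sub(r'^tower\d+/', '', varname)

name_map = defaultdict(list)
for v in ['tower0/conv1/W', 'tower1/conv1/W', 'fc/b']:
    name_map[get_savename(v)].append(v)

# Loading then walks every replica under one savename, analogous to:
# for name, value in checkpoint_params.items():
#     for v in name_map[name]:
#         v.load(value, session=sess)
```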