Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
d8935ef3
Commit
d8935ef3
authored
Jan 05, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
update sphinx doc for dataflow/
parent
edecca96
Changes
27
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
27 changed files
with
583 additions
and
373 deletions
+583
-373
docs/conf.py
docs/conf.py
+5
-6
tensorpack/dataflow/base.py
tensorpack/dataflow/base.py
+18
-13
tensorpack/dataflow/common.py
tensorpack/dataflow/common.py
+140
-91
tensorpack/dataflow/dataset/bsds500.py
tensorpack/dataflow/dataset/bsds500.py
+8
-8
tensorpack/dataflow/dataset/cifar.py
tensorpack/dataflow/dataset/cifar.py
+11
-12
tensorpack/dataflow/dataset/ilsvrc.py
tensorpack/dataflow/dataset/ilsvrc.py
+35
-31
tensorpack/dataflow/dataset/mnist.py
tensorpack/dataflow/dataset/mnist.py
+4
-3
tensorpack/dataflow/dataset/svhn.py
tensorpack/dataflow/dataset/svhn.py
+6
-4
tensorpack/dataflow/dataset/visualqa.py
tensorpack/dataflow/dataset/visualqa.py
+2
-4
tensorpack/dataflow/dftools.py
tensorpack/dataflow/dftools.py
+30
-19
tensorpack/dataflow/format.py
tensorpack/dataflow/format.py
+45
-15
tensorpack/dataflow/image.py
tensorpack/dataflow/image.py
+20
-15
tensorpack/dataflow/imgaug/base.py
tensorpack/dataflow/imgaug/base.py
+10
-4
tensorpack/dataflow/imgaug/crop.py
tensorpack/dataflow/imgaug/crop.py
+27
-17
tensorpack/dataflow/imgaug/deform.py
tensorpack/dataflow/imgaug/deform.py
+13
-10
tensorpack/dataflow/imgaug/geometry.py
tensorpack/dataflow/imgaug/geometry.py
+16
-5
tensorpack/dataflow/imgaug/imgproc.py
tensorpack/dataflow/imgaug/imgproc.py
+43
-20
tensorpack/dataflow/imgaug/meta.py
tensorpack/dataflow/imgaug/meta.py
+20
-7
tensorpack/dataflow/imgaug/noise.py
tensorpack/dataflow/imgaug/noise.py
+17
-4
tensorpack/dataflow/imgaug/noname.py
tensorpack/dataflow/imgaug/noname.py
+24
-16
tensorpack/dataflow/imgaug/paste.py
tensorpack/dataflow/imgaug/paste.py
+10
-6
tensorpack/dataflow/prefetch.py
tensorpack/dataflow/prefetch.py
+31
-17
tensorpack/dataflow/raw.py
tensorpack/dataflow/raw.py
+17
-33
tensorpack/dataflow/remote.py
tensorpack/dataflow/remote.py
+25
-9
tensorpack/dataflow/tf_func.py
tensorpack/dataflow/tf_func.py
+3
-1
tensorpack/models/model_desc.py
tensorpack/models/model_desc.py
+1
-1
tensorpack/models/regularize.py
tensorpack/models/regularize.py
+2
-2
No files found.
docs/conf.py
View file @
d8935ef3
...
@@ -69,9 +69,11 @@ extensions = [
...
@@ -69,9 +69,11 @@ extensions = [
#'sphinx.ext.coverage',
#'sphinx.ext.coverage',
#'sphinx.ext.mathjax',
#'sphinx.ext.mathjax',
'sphinx.ext.mathbase'
,
'sphinx.ext.mathbase'
,
'sphinx.ext.intersphinx'
,
'sphinx.ext.viewcode'
,
'sphinx.ext.viewcode'
,
]
]
napoleon_google_docstring
=
True
napoleon_google_docstring
=
True
napoleon_include_init_with_doc
=
True
napoleon_numpy_docstring
=
False
napoleon_numpy_docstring
=
False
napoleon_use_rtype
=
False
napoleon_use_rtype
=
False
...
@@ -332,11 +334,9 @@ texinfo_documents = [
...
@@ -332,11 +334,9 @@ texinfo_documents = [
# If true, do not generate a @detailmenu in the "Top" node's menu.
# If true, do not generate a @detailmenu in the "Top" node's menu.
#texinfo_no_detailmenu = False
#texinfo_no_detailmenu = False
def
skip
(
app
,
what
,
name
,
obj
,
skip
,
options
):
# keep __init__
intersphinx_timeout
=
0.1
if
name
==
"__init__"
:
intersphinx_mapping
=
{
'python'
:
(
'https://docs.python.org/3.4'
,
None
)}
return
False
return
skip
def
process_signature
(
app
,
what
,
name
,
obj
,
options
,
signature
,
def
process_signature
(
app
,
what
,
name
,
obj
,
options
,
signature
,
return_annotation
):
return_annotation
):
...
@@ -350,7 +350,6 @@ def process_signature(app, what, name, obj, options, signature,
...
@@ -350,7 +350,6 @@ def process_signature(app, what, name, obj, options, signature,
def
setup
(
app
):
def
setup
(
app
):
from
recommonmark.transform
import
AutoStructify
from
recommonmark.transform
import
AutoStructify
app
.
connect
(
'autodoc-process-signature'
,
process_signature
)
app
.
connect
(
'autodoc-process-signature'
,
process_signature
)
app
.
connect
(
"autodoc-skip-member"
,
skip
)
app
.
add_config_value
(
app
.
add_config_value
(
'recommonmark_config'
,
'recommonmark_config'
,
{
'url_resolver'
:
lambda
url
:
\
{
'url_resolver'
:
lambda
url
:
\
...
...
tensorpack/dataflow/base.py
View file @
d8935ef3
...
@@ -15,37 +15,41 @@ __all__ = ['DataFlow', 'ProxyDataFlow', 'RNGDataFlow']
...
@@ -15,37 +15,41 @@ __all__ = ['DataFlow', 'ProxyDataFlow', 'RNGDataFlow']
class
DataFlow
(
object
):
class
DataFlow
(
object
):
""" Base class for all DataFlow """
""" Base class for all DataFlow """
class
Infinity
:
pass
@
abstractmethod
@
abstractmethod
def
get_data
(
self
):
def
get_data
(
self
):
"""
"""
A generator to generate data as a list.
The method to generate datapoints.
Datapoint should be a mutable list.
Each component should be assumed immutable.
Yields:
list: The datapoint, i.e. list of components.
"""
"""
def
size
(
self
):
def
size
(
self
):
"""
"""
Size of this data flow.
Returns:
int: size of this data flow.
Raises:
:class:`NotImplementedError` if this DataFlow doesn't have a size.
"""
"""
raise
NotImplementedError
()
raise
NotImplementedError
()
def
reset_state
(
self
):
def
reset_state
(
self
):
"""
"""
Reset state of the dataflow. Will always be called before consuming data points.
Reset state of the dataflow. It has to be called before producing datapoints.
for example, RNG **HAS** to be reset here if used in the DataFlow.
Otherwise it may not work well with prefetching, because different
For example, RNG **has to** be reset if used in the DataFlow,
otherwise it won't work well with prefetching, because different
processes will have the same RNG state.
processes will have the same RNG state.
"""
"""
pass
pass
class
RNGDataFlow
(
DataFlow
):
class
RNGDataFlow
(
DataFlow
):
""" A
dataflow with rng
"""
""" A
DataFlow with RNG
"""
def
reset_state
(
self
):
def
reset_state
(
self
):
""" Reset the RNG """
self
.
rng
=
get_rng
(
self
)
self
.
rng
=
get_rng
(
self
)
...
@@ -54,13 +58,14 @@ class ProxyDataFlow(DataFlow):
...
@@ -54,13 +58,14 @@ class ProxyDataFlow(DataFlow):
def
__init__
(
self
,
ds
):
def
__init__
(
self
,
ds
):
"""
"""
:param ds: a :mod:`DataFlow` instance to proxy
Args:
ds (DataFlow): DataFlow to proxy.
"""
"""
self
.
ds
=
ds
self
.
ds
=
ds
def
reset_state
(
self
):
def
reset_state
(
self
):
"""
"""
Will reset state of the proxied DataFlow
Reset state of the proxied DataFlow.
"""
"""
self
.
ds
.
reset_state
()
self
.
ds
.
reset_state
()
...
...
tensorpack/dataflow/common.py
View file @
d8935ef3
This diff is collapsed.
Click to expand it.
tensorpack/dataflow/dataset/bsds500.py
View file @
d8935ef3
...
@@ -25,20 +25,20 @@ IMG_W, IMG_H = 481, 321
...
@@ -25,20 +25,20 @@ IMG_W, IMG_H = 481, 321
class
BSDS500
(
RNGDataFlow
):
class
BSDS500
(
RNGDataFlow
):
"""
"""
`Berkeley Segmentation Data Set and Benchmarks 500
`Berkeley Segmentation Data Set and Benchmarks 500
dataset
<http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/resources.html#bsds500>`_.
<http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/resources.html#bsds500>`_.
Produce (image, label) pair, where image has shape (321, 481, 3) and
Produce ``(image, label)`` pair, where ``image`` has shape (321, 481, 3(BGR)) and
ranges in [0,255]. Label is binary and has shape (321, 481).
ranges in [0,255].
Those pixels annotated as boundaries by <=2 annotators are set to 0.
``Label`` is a floating point image of shape (321, 481) in range [0, 1].
This is used in `Holistically-Nested Edge Detection
The value of each pixel is ``number of times it is annotated as edge / total number of annotators for this image``.
<http://arxiv.org/abs/1504.06375>`_.
"""
"""
def
__init__
(
self
,
name
,
data_dir
=
None
,
shuffle
=
True
):
def
__init__
(
self
,
name
,
data_dir
=
None
,
shuffle
=
True
):
"""
"""
:param name: 'train', 'test', 'val'
Args:
:param data_dir: a directory containing the original 'BSR' directory.
name (str): 'train', 'test', 'val'
data_dir (str): a directory containing the original 'BSR' directory.
"""
"""
# check and download data
# check and download data
if
data_dir
is
None
:
if
data_dir
is
None
:
...
...
tensorpack/dataflow/dataset/cifar.py
View file @
d8935ef3
...
@@ -80,17 +80,7 @@ def get_filenames(dir, cifar_classnum):
...
@@ -80,17 +80,7 @@ def get_filenames(dir, cifar_classnum):
class
CifarBase
(
RNGDataFlow
):
class
CifarBase
(
RNGDataFlow
):
"""
Return [image, label],
image is 32x32x3 in the range [0,255]
"""
def
__init__
(
self
,
train_or_test
,
shuffle
=
True
,
dir
=
None
,
cifar_classnum
=
10
):
def
__init__
(
self
,
train_or_test
,
shuffle
=
True
,
dir
=
None
,
cifar_classnum
=
10
):
"""
Args:
train_or_test: string either 'train' or 'test'
shuffle: default to True
"""
assert
train_or_test
in
[
'train'
,
'test'
]
assert
train_or_test
in
[
'train'
,
'test'
]
assert
cifar_classnum
==
10
or
cifar_classnum
==
100
assert
cifar_classnum
==
10
or
cifar_classnum
==
100
self
.
cifar_classnum
=
cifar_classnum
self
.
cifar_classnum
=
cifar_classnum
...
@@ -139,13 +129,22 @@ class CifarBase(RNGDataFlow):
...
@@ -139,13 +129,22 @@ class CifarBase(RNGDataFlow):
class
Cifar10
(
CifarBase
):
class
Cifar10
(
CifarBase
):
"""
Produces [image, label] in Cifar10 dataset,
image is 32x32x3 in the range [0,255].
label is an int.
"""
def
__init__
(
self
,
train_or_test
,
shuffle
=
True
,
dir
=
None
):
def
__init__
(
self
,
train_or_test
,
shuffle
=
True
,
dir
=
None
):
"""
Args:
train_or_test (str): either 'train' or 'test'.
shuffle (bool): shuffle the dataset.
"""
super
(
Cifar10
,
self
)
.
__init__
(
train_or_test
,
shuffle
,
dir
,
10
)
super
(
Cifar10
,
self
)
.
__init__
(
train_or_test
,
shuffle
,
dir
,
10
)
class
Cifar100
(
CifarBase
):
class
Cifar100
(
CifarBase
):
""" Similar to Cifar10"""
def
__init__
(
self
,
train_or_test
,
shuffle
=
True
,
dir
=
None
):
def
__init__
(
self
,
train_or_test
,
shuffle
=
True
,
dir
=
None
):
super
(
Cifar100
,
self
)
.
__init__
(
train_or_test
,
shuffle
,
dir
,
100
)
super
(
Cifar100
,
self
)
.
__init__
(
train_or_test
,
shuffle
,
dir
,
100
)
...
...
tensorpack/dataflow/dataset/ilsvrc.py
View file @
d8935ef3
...
@@ -22,7 +22,7 @@ CAFFE_ILSVRC12_URL = "http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz"
...
@@ -22,7 +22,7 @@ CAFFE_ILSVRC12_URL = "http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz"
class
ILSVRCMeta
(
object
):
class
ILSVRCMeta
(
object
):
"""
"""
Some
metadata for ILSVRC dataset.
Provide methods to access
metadata for ILSVRC dataset.
"""
"""
def
__init__
(
self
,
dir
=
None
):
def
__init__
(
self
,
dir
=
None
):
...
@@ -37,7 +37,8 @@ class ILSVRCMeta(object):
...
@@ -37,7 +37,8 @@ class ILSVRCMeta(object):
def
get_synset_words_1000
(
self
):
def
get_synset_words_1000
(
self
):
"""
"""
:returns a dict of {cls_number: cls_name}
Returns:
dict: {cls_number: cls_name}
"""
"""
fname
=
os
.
path
.
join
(
self
.
dir
,
'synset_words.txt'
)
fname
=
os
.
path
.
join
(
self
.
dir
,
'synset_words.txt'
)
assert
os
.
path
.
isfile
(
fname
)
assert
os
.
path
.
isfile
(
fname
)
...
@@ -46,7 +47,8 @@ class ILSVRCMeta(object):
...
@@ -46,7 +47,8 @@ class ILSVRCMeta(object):
def
get_synset_1000
(
self
):
def
get_synset_1000
(
self
):
"""
"""
:returns a dict of {cls_number: synset_id}
Returns:
dict: {cls_number: synset_id}
"""
"""
fname
=
os
.
path
.
join
(
self
.
dir
,
'synsets.txt'
)
fname
=
os
.
path
.
join
(
self
.
dir
,
'synsets.txt'
)
assert
os
.
path
.
isfile
(
fname
)
assert
os
.
path
.
isfile
(
fname
)
...
@@ -59,8 +61,10 @@ class ILSVRCMeta(object):
...
@@ -59,8 +61,10 @@ class ILSVRCMeta(object):
def
get_image_list
(
self
,
name
):
def
get_image_list
(
self
,
name
):
"""
"""
:param name: 'train' or 'val' or 'test'
Args:
:returns: list of (image filename, cls)
name (str): 'train' or 'val' or 'test'
Returns:
list: list of (image filename, label)
"""
"""
assert
name
in
[
'train'
,
'val'
,
'test'
]
assert
name
in
[
'train'
,
'val'
,
'test'
]
fname
=
os
.
path
.
join
(
self
.
dir
,
name
+
'.txt'
)
fname
=
os
.
path
.
join
(
self
.
dir
,
name
+
'.txt'
)
...
@@ -75,8 +79,10 @@ class ILSVRCMeta(object):
...
@@ -75,8 +79,10 @@ class ILSVRCMeta(object):
def
get_per_pixel_mean
(
self
,
size
=
None
):
def
get_per_pixel_mean
(
self
,
size
=
None
):
"""
"""
:param size: return image size in [h, w]. default to (256, 256)
Args:
:returns: per-pixel mean as an array of shape (h, w, 3) in range [0, 255]
size (tuple): image size in (h, w). Defaults to (256, 256).
Returns:
np.ndarray: per-pixel mean of shape (h, w, 3 (BGR)) in range [0, 255].
"""
"""
obj
=
self
.
caffepb
.
BlobProto
()
obj
=
self
.
caffepb
.
BlobProto
()
...
@@ -91,18 +97,26 @@ class ILSVRCMeta(object):
...
@@ -91,18 +97,26 @@ class ILSVRCMeta(object):
class
ILSVRC12
(
RNGDataFlow
):
class
ILSVRC12
(
RNGDataFlow
):
"""
Produces ILSVRC12 images of shape [h, w, 3(BGR)], and a label between [0, 999],
and optionally a bounding box of [xmin, ymin, xmax, ymax].
"""
def
__init__
(
self
,
dir
,
name
,
meta_dir
=
None
,
shuffle
=
True
,
def
__init__
(
self
,
dir
,
name
,
meta_dir
=
None
,
shuffle
=
True
,
dir_structure
=
'original'
,
include_bb
=
False
):
dir_structure
=
'original'
,
include_bb
=
False
):
"""
"""
:param dir: A directory containing a subdir named `name`, where the
Args:
original ILSVRC12_`name`.tar gets decompressed.
dir (str): A directory containing a subdir named ``name``, where the
:param name: 'train' or 'val' or 'test'
original ``ILSVRC12_img_{name}.tar`` gets decompressed.
:param dir_structure: The dir structure of 'val' and 'test'.
name (str): 'train' or 'val' or 'test'.
If is 'original' then keep the original decompressed directory with list
shuffle (bool): shuffle the dataset.
of image files (as below). If set to 'train', use the same
dir_structure (str): The dir structure of 'val' and 'test' directory.
directory structure as 'train/', with class name as subdirectories.
If is 'original', it expects the original decompressed
:param include_bb: Include the bounding box. Maybe useful in training.
directory, which only has list of image files (as below).
If set to 'train', it expects the same two-level
directory structure similar to 'train/'.
include_bb (bool): Include the bounding box. May be useful in training.
Examples:
When `dir_structure=='original'`, `dir` should have the following structure:
When `dir_structure=='original'`, `dir` should have the following structure:
...
@@ -120,22 +134,16 @@ class ILSVRC12(RNGDataFlow):
...
@@ -120,22 +134,16 @@ class ILSVRC12(RNGDataFlow):
test/
test/
ILSVRC2012_test_00000001.JPEG
ILSVRC2012_test_00000001.JPEG
...
...
bbox/
n02134418/
n02134418_198.xml
...
...
After decompress ILSVRC12_img_train
.tar, you can use the following
With ILSVRC12_img_*
.tar, you can use the following
command to build the above structure
for `train/`
:
command to build the above structure:
.. code-block:: none
.. code-block:: none
tar xvf ILSVRC12_img_train.tar -C train && cd train
mkdir val && tar xvf ILSVRC12_img_val.tar -C val
mkdir test && tar xvf ILSVRC12_img_test.tar -C test
mkdir train && tar xvf ILSVRC12_img_train.tar -C train && cd train
find -type f -name '*.tar' | parallel -P 10 'echo {} && mkdir -p {/.} && tar xf {} -C {/.}'
find -type f -name '*.tar' | parallel -P 10 'echo {} && mkdir -p {/.} && tar xf {} -C {/.}'
Or:
for i in *.tar; do dir=${i
%
.tar}; echo $dir; mkdir -p $dir; tar xf $i -C $dir; done
"""
"""
assert
name
in
[
'train'
,
'test'
,
'val'
]
assert
name
in
[
'train'
,
'test'
,
'val'
]
self
.
full_dir
=
os
.
path
.
join
(
dir
,
name
)
self
.
full_dir
=
os
.
path
.
join
(
dir
,
name
)
...
@@ -158,10 +166,6 @@ class ILSVRC12(RNGDataFlow):
...
@@ -158,10 +166,6 @@ class ILSVRC12(RNGDataFlow):
return
len
(
self
.
imglist
)
return
len
(
self
.
imglist
)
def
get_data
(
self
):
def
get_data
(
self
):
"""
Produce original images of shape [h, w, 3(BGR)], and label,
and optionally a bbox of [xmin, ymin, xmax, ymax]
"""
idxs
=
np
.
arange
(
len
(
self
.
imglist
))
idxs
=
np
.
arange
(
len
(
self
.
imglist
))
add_label_to_fname
=
(
self
.
name
!=
'train'
and
self
.
dir_structure
!=
'original'
)
add_label_to_fname
=
(
self
.
name
!=
'train'
and
self
.
dir_structure
!=
'original'
)
if
self
.
shuffle
:
if
self
.
shuffle
:
...
...
tensorpack/dataflow/dataset/mnist.py
View file @
d8935ef3
...
@@ -65,14 +65,15 @@ def extract_labels(filename):
...
@@ -65,14 +65,15 @@ def extract_labels(filename):
class
Mnist
(
RNGDataFlow
):
class
Mnist
(
RNGDataFlow
):
"""
"""
Return [image, label]
,
Produces [image, label] in MNIST dataset
,
image is 28x28 in the range [0,1]
image is 28x28 in the range [0,1], label is an int.
"""
"""
def
__init__
(
self
,
train_or_test
,
shuffle
=
True
,
dir
=
None
):
def
__init__
(
self
,
train_or_test
,
shuffle
=
True
,
dir
=
None
):
"""
"""
Args:
Args:
train_or_test: string either 'train' or 'test'
train_or_test (str): either 'train' or 'test'
shuffle (bool): shuffle the dataset
"""
"""
if
dir
is
None
:
if
dir
is
None
:
dir
=
get_dataset_path
(
'mnist_data'
)
dir
=
get_dataset_path
(
'mnist_data'
)
...
...
tensorpack/dataflow/dataset/svhn.py
View file @
d8935ef3
...
@@ -21,15 +21,17 @@ SVHN_URL = "http://ufldl.stanford.edu/housenumbers/"
...
@@ -21,15 +21,17 @@ SVHN_URL = "http://ufldl.stanford.edu/housenumbers/"
class
SVHNDigit
(
RNGDataFlow
):
class
SVHNDigit
(
RNGDataFlow
):
"""
"""
SVHN
Cropped Digit Dataset.
`SVHN <http://ufldl.stanford.edu/housenumbers/>`_
Cropped Digit Dataset.
return img of 32x32x3
, label of 0-9
Produces [img, label], img of 32x32x3 in range [0,255]
, label of 0-9
"""
"""
_Cache
=
{}
_Cache
=
{}
def
__init__
(
self
,
name
,
data_dir
=
None
,
shuffle
=
True
):
def
__init__
(
self
,
name
,
data_dir
=
None
,
shuffle
=
True
):
"""
"""
:param name: 'train', 'test', or 'extra'
Args:
:param data_dir: a directory containing the original {train,test,extra}_32x32.mat
name (str): 'train', 'test', or 'extra'.
data_dir (str): a directory containing the original {train,test,extra}_32x32.mat.
shuffle (bool): shuffle the dataset.
"""
"""
self
.
shuffle
=
shuffle
self
.
shuffle
=
shuffle
...
...
tensorpack/dataflow/dataset/visualqa.py
View file @
d8935ef3
...
@@ -18,13 +18,11 @@ def read_json(fname):
...
@@ -18,13 +18,11 @@ def read_json(fname):
f
.
close
()
f
.
close
()
return
ret
return
ret
# TODO shuffle
class
VisualQA
(
DataFlow
):
class
VisualQA
(
DataFlow
):
"""
"""
Visual QA dataset. See http://visualqa.org/
`Visual QA <http://visualqa.org/>`_ dataset.
Simply read
q/a json file and produce q/a pairs in their original format.
It simply reads
q/a json file and produce q/a pairs in their original format.
"""
"""
def
__init__
(
self
,
question_file
,
annotation_file
):
def
__init__
(
self
,
question_file
,
annotation_file
):
...
...
tensorpack/dataflow/dftools.py
View file @
d8935ef3
...
@@ -23,17 +23,17 @@ except ImportError:
...
@@ -23,17 +23,17 @@ except ImportError:
else
:
else
:
__all__
.
extend
([
'dump_dataflow_to_lmdb'
])
__all__
.
extend
([
'dump_dataflow_to_lmdb'
])
# TODO pass a name_func to write label as filename?
def
dump_dataset_images
(
ds
,
dirname
,
max_count
=
None
,
index
=
0
):
def
dump_dataset_images
(
ds
,
dirname
,
max_count
=
None
,
index
=
0
):
""" Dump images from a
`DataFlow`
to a directory.
""" Dump images from a
DataFlow
to a directory.
:param ds: a `DataFlow` instance.
Args:
:param dirname: name of the directory.
ds (DataFlow): the DataFlow to dump.
:param max_count: max number of images to dump
dirname (str): name of the directory.
:param index: the index of the image component in a data point.
max_count (int): limit max number of images to dump. Defaults to unlimited.
index (int): the index of the image component in the data point.
"""
"""
# TODO pass a name_func to write label as filename?
mkdir_p
(
dirname
)
mkdir_p
(
dirname
)
if
max_count
is
None
:
if
max_count
is
None
:
max_count
=
sys
.
maxint
max_count
=
sys
.
maxint
...
@@ -48,9 +48,15 @@ def dump_dataset_images(ds, dirname, max_count=None, index=0):
...
@@ -48,9 +48,15 @@ def dump_dataset_images(ds, dirname, max_count=None, index=0):
def
dump_dataflow_to_lmdb
(
ds
,
lmdb_path
):
def
dump_dataflow_to_lmdb
(
ds
,
lmdb_path
):
""" Dump a `Dataflow` ds to a lmdb database, where the key is the index
"""
and the data is the serialized datapoint.
Dump a Dataflow to a lmdb database, where the keys are indices and values
The output database can be read directly by `LMDBDataPoint`
are serialized datapoints.
The output database can be read directly by
:class:`tensorpack.dataflow.LMDBDataPoint`.
Args:
ds (DataFlow): the DataFlow to dump.
lmdb_path (str): output path. Either a directory or a mdb file.
"""
"""
assert
isinstance
(
ds
,
DataFlow
),
type
(
ds
)
assert
isinstance
(
ds
,
DataFlow
),
type
(
ds
)
isdir
=
os
.
path
.
isdir
(
lmdb_path
)
isdir
=
os
.
path
.
isdir
(
lmdb_path
)
...
@@ -80,15 +86,20 @@ def dump_dataflow_to_lmdb(ds, lmdb_path):
...
@@ -80,15 +86,20 @@ def dump_dataflow_to_lmdb(ds, lmdb_path):
def
dataflow_to_process_queue
(
ds
,
size
,
nr_consumer
):
def
dataflow_to_process_queue
(
ds
,
size
,
nr_consumer
):
"""
"""
Convert a `DataFlow` to a multiprocessing.Queue.
Convert a DataFlow to a :class:`multiprocessing.Queue`.
The dataflow will only be reset in the spawned process.
The DataFlow will only be reset in the spawned process.
:param ds: a `DataFlow`
Args:
:param size: size of the queue
ds (DataFlow): the DataFlow to dump.
:param nr_consumer: number of consumer of the queue.
size (int): size of the queue
will add this many of `DIE` sentinel to the end of the queue.
nr_consumer (int): number of consumers of the queue.
:returns: (queue, process). The process will take data from `ds` to fill
The producer will add this many ``DIE`` sentinels to the end of the queue.
the queue once you start it. Each element is (task_id, dp).
Returns:
tuple(queue, process):
The process will take data from ``ds`` and fill
the queue, once you start it. Each element in the queue is (idx,
dp). idx can be the ``DIE`` sentinel when ``ds`` is exhausted.
"""
"""
q
=
mp
.
Queue
(
size
)
q
=
mp
.
Queue
(
size
)
...
...
tensorpack/dataflow/format.py
View file @
d8935ef3
...
@@ -26,7 +26,7 @@ try:
...
@@ -26,7 +26,7 @@ try:
except
ImportError
:
except
ImportError
:
logger
.
warn_dependency
(
"LMDBData"
,
'lmdb'
)
logger
.
warn_dependency
(
"LMDBData"
,
'lmdb'
)
else
:
else
:
__all__
.
extend
([
'LMDBData'
,
'
CaffeLMDB'
,
'LMDBDataDecoder'
,
'LMDBDataPoint
'
])
__all__
.
extend
([
'LMDBData'
,
'
LMDBDataDecoder'
,
'LMDBDataPoint'
,
'CaffeLMDB
'
])
try
:
try
:
import
sklearn.datasets
import
sklearn.datasets
...
@@ -40,19 +40,23 @@ else:
...
@@ -40,19 +40,23 @@ else:
Adapters for different data format.
Adapters for different data format.
"""
"""
# TODO lazy load
class
HDF5Data
(
RNGDataFlow
):
class
HDF5Data
(
RNGDataFlow
):
"""
"""
Zip data from different paths in an HDF5 file. Will load all data into memory.
Zip data from different paths in an HDF5 file.
Warning:
The current implementation will load all data into memory.
"""
"""
# TODO lazy load
def
__init__
(
self
,
filename
,
data_paths
,
shuffle
=
True
):
def
__init__
(
self
,
filename
,
data_paths
,
shuffle
=
True
):
"""
"""
:param filename: h5 data file.
Args:
:param data_paths: list of h5 paths to zipped. For example ['images', 'labels']
filename (str): h5 data file.
:param shuffle: shuffle the order of all data.
data_paths (list): list of h5 paths to be zipped.
For example `['images', 'labels']`.
shuffle (bool): shuffle all data.
"""
"""
self
.
f
=
h5py
.
File
(
filename
,
'r'
)
self
.
f
=
h5py
.
File
(
filename
,
'r'
)
logger
.
info
(
"Loading {} to memory..."
.
format
(
filename
))
logger
.
info
(
"Loading {} to memory..."
.
format
(
filename
))
...
@@ -74,9 +78,13 @@ class HDF5Data(RNGDataFlow):
...
@@ -74,9 +78,13 @@ class HDF5Data(RNGDataFlow):
class
LMDBData
(
RNGDataFlow
):
class
LMDBData
(
RNGDataFlow
):
""" Read a lmdb and produce k,v pair """
""" Read a LMDB database and produce (k,v) pairs """
def
__init__
(
self
,
lmdb_path
,
shuffle
=
True
):
def
__init__
(
self
,
lmdb_path
,
shuffle
=
True
):
"""
Args:
lmdb_path (str): a directory or a file.
shuffle (bool): shuffle the keys or not.
"""
self
.
_lmdb_path
=
lmdb_path
self
.
_lmdb_path
=
lmdb_path
self
.
_shuffle
=
shuffle
self
.
_shuffle
=
shuffle
self
.
open_lmdb
()
self
.
open_lmdb
()
...
@@ -122,11 +130,14 @@ class LMDBData(RNGDataFlow):
...
@@ -122,11 +130,14 @@ class LMDBData(RNGDataFlow):
class
LMDBDataDecoder
(
LMDBData
):
class
LMDBDataDecoder
(
LMDBData
):
""" Read a LMDB database and produce a decoded output."""
def
__init__
(
self
,
lmdb_path
,
decoder
,
shuffle
=
True
):
def
__init__
(
self
,
lmdb_path
,
decoder
,
shuffle
=
True
):
"""
"""
:param decoder: a function taking k, v and return a data point,
Args:
or return None to skip
lmdb_path (str): a directory or a file.
decoder (k,v -> dp | None): a function taking k, v and returning a datapoint,
or return None to discard.
shuffle (bool): shuffle the keys or not.
"""
"""
super
(
LMDBDataDecoder
,
self
)
.
__init__
(
lmdb_path
,
shuffle
)
super
(
LMDBDataDecoder
,
self
)
.
__init__
(
lmdb_path
,
shuffle
)
self
.
decoder
=
decoder
self
.
decoder
=
decoder
...
@@ -139,17 +150,31 @@ class LMDBDataDecoder(LMDBData):
...
@@ -139,17 +150,31 @@ class LMDBDataDecoder(LMDBData):
class
LMDBDataPoint
(
LMDBDataDecoder
):
class
LMDBDataPoint
(
LMDBDataDecoder
):
""" Read a LMDB file where each value is a serialized datapoint"""
""" Read a LMDB file and produce deserialized values.
This can work with :func:`tensorpack.dataflow.dftools.dump_dataflow_to_lmdb`. """
def
__init__
(
self
,
lmdb_path
,
shuffle
=
True
):
def
__init__
(
self
,
lmdb_path
,
shuffle
=
True
):
"""
Args:
lmdb_path (str): a directory or a file.
shuffle (bool): shuffle the keys or not.
"""
super
(
LMDBDataPoint
,
self
)
.
__init__
(
super
(
LMDBDataPoint
,
self
)
.
__init__
(
lmdb_path
,
decoder
=
lambda
k
,
v
:
loads
(
v
),
shuffle
=
shuffle
)
lmdb_path
,
decoder
=
lambda
k
,
v
:
loads
(
v
),
shuffle
=
shuffle
)
class
CaffeLMDB
(
LMDBDataDecoder
):
class
CaffeLMDB
(
LMDBDataDecoder
):
""" Read a Caffe LMDB file where each value contains a caffe.Datum protobuf """
"""
Read a Caffe LMDB file where each value contains a ``caffe.Datum`` protobuf.
Produces datapoints of the format: [HWC image, label].
"""
def
__init__
(
self
,
lmdb_path
,
shuffle
=
True
):
def
__init__
(
self
,
lmdb_path
,
shuffle
=
True
):
"""
Args:
lmdb_path (str): a directory or a file.
shuffle (bool): shuffle the keys or not.
"""
cpb
=
get_caffe_pb
()
cpb
=
get_caffe_pb
()
def
decoder
(
k
,
v
):
def
decoder
(
k
,
v
):
...
@@ -168,9 +193,14 @@ class CaffeLMDB(LMDBDataDecoder):
...
@@ -168,9 +193,14 @@ class CaffeLMDB(LMDBDataDecoder):
class
SVMLightData
(
RNGDataFlow
):
class
SVMLightData
(
RNGDataFlow
):
""" Read X,y from a svmlight file """
""" Read X,y from a svmlight file
, and produce [X_i, y_i] pairs.
"""
def
__init__
(
self
,
filename
,
shuffle
=
True
):
def
__init__
(
self
,
filename
,
shuffle
=
True
):
"""
Args:
filename (str): input file
shuffle (bool): shuffle the data
"""
self
.
X
,
self
.
y
=
sklearn
.
datasets
.
load_svmlight_file
(
filename
)
self
.
X
,
self
.
y
=
sklearn
.
datasets
.
load_svmlight_file
(
filename
)
self
.
X
=
np
.
asarray
(
self
.
X
.
todense
())
self
.
X
=
np
.
asarray
(
self
.
X
.
todense
())
self
.
shuffle
=
shuffle
self
.
shuffle
=
shuffle
...
...
tensorpack/dataflow/image.py
View file @
d8935ef3
...
@@ -12,13 +12,13 @@ __all__ = ['ImageFromFile', 'AugmentImageComponent', 'AugmentImageComponents']
...
@@ -12,13 +12,13 @@ __all__ = ['ImageFromFile', 'AugmentImageComponent', 'AugmentImageComponents']
class
ImageFromFile
(
RNGDataFlow
):
class
ImageFromFile
(
RNGDataFlow
):
""" Produce images read from a list of files. """
def
__init__
(
self
,
files
,
channel
=
3
,
resize
=
None
,
shuffle
=
False
):
def
__init__
(
self
,
files
,
channel
=
3
,
resize
=
None
,
shuffle
=
False
):
"""
"""
Generate images of 1 channel or 3 channels (in RGB order) from list of files.
Args:
:param files: list of file paths
files (list): list of file paths.
:param channel: 1 or 3 channel
channel (int): 1 or 3. Produce RGB images if channel==3.
:param resize: a (h, w) tuple. If given, will force a resize
resize (tuple): (h, w). If given, resize the image.
"""
"""
assert
len
(
files
),
"No image files given to ImageFromFile!"
assert
len
(
files
),
"No image files given to ImageFromFile!"
self
.
files
=
files
self
.
files
=
files
...
@@ -45,14 +45,15 @@ class ImageFromFile(RNGDataFlow):
...
@@ -45,14 +45,15 @@ class ImageFromFile(RNGDataFlow):
class
AugmentImageComponent
(
MapDataComponent
):
class
AugmentImageComponent
(
MapDataComponent
):
"""
Apply image augmentors on 1 component.
"""
def
__init__
(
self
,
ds
,
augmentors
,
index
=
0
):
def
__init__
(
self
,
ds
,
augmentors
,
index
=
0
):
"""
"""
Augment the image component of datapoints
Args:
:param ds: a `DataFlow` instance.
ds (DataFlow): input DataFlow.
:param augmentors: a list of `ImageAugmentor` instance to be applied in order.
augmentors (AugmentorList): a list of :class:`imgaug.ImageAugmentor` to be applied in order.
:param index: the index (or list of indices) of the image component
index (int): the index of the image component to be augmented.
in the produced datapoints by `ds`. default to be 0
"""
"""
if
isinstance
(
augmentors
,
AugmentorList
):
if
isinstance
(
augmentors
,
AugmentorList
):
self
.
augs
=
augmentors
self
.
augs
=
augmentors
...
@@ -67,12 +68,16 @@ class AugmentImageComponent(MapDataComponent):
...
@@ -67,12 +68,16 @@ class AugmentImageComponent(MapDataComponent):
class
AugmentImageComponents
(
MapData
):
class
AugmentImageComponents
(
MapData
):
"""
Apply image augmentors on several components, with shared augmentation parameters.
"""
def
__init__
(
self
,
ds
,
augmentors
,
index
=
(
0
,
1
)):
def
__init__
(
self
,
ds
,
augmentors
,
index
=
(
0
,
1
)):
""" Augment a list of images of the same shape, with the same parameters
"""
:param ds: a `DataFlow` instance.
Args:
:param augmentors: a list of `ImageAugmentor` instance to be applied in order.
ds (DataFlow): input DataFlow.
:param index: tuple of indices of the image components
augmentors (AugmentorList): a list of :class:`imgaug.ImageAugmentor` instance to be applied in order.
index: tuple of indices of components.
"""
"""
self
.
augs
=
AugmentorList
(
augmentors
)
self
.
augs
=
AugmentorList
(
augmentors
)
self
.
ds
=
ds
self
.
ds
=
ds
...
...
tensorpack/dataflow/imgaug/base.py
View file @
d8935ef3
...
@@ -24,6 +24,7 @@ class Augmentor(object):
...
@@ -24,6 +24,7 @@ class Augmentor(object):
setattr
(
self
,
k
,
v
)
setattr
(
self
,
k
,
v
)
def
reset_state
(
self
):
def
reset_state
(
self
):
""" reset rng and other state """
self
.
rng
=
get_rng
(
self
)
self
.
rng
=
get_rng
(
self
)
def
augment
(
self
,
d
):
def
augment
(
self
,
d
):
...
@@ -64,9 +65,13 @@ class ImageAugmentor(Augmentor):
...
@@ -64,9 +65,13 @@ class ImageAugmentor(Augmentor):
def
augment
(
self
,
img
):
def
augment
(
self
,
img
):
"""
"""
Perform augmentation on the image in-place.
Perform augmentation on the image (possibly) in-place.
:param img: an [h,w] or [h,w,c] image
:returns: the augmented image, always of type 'float32'
Args:
img (np.ndarray): an [h,w] or [h,w,c] image.
Returns:
np.ndarray: the augmented image, always of type float32.
"""
"""
img
,
params
=
self
.
_augment_return_params
(
img
)
img
,
params
=
self
.
_augment_return_params
(
img
)
return
img
return
img
...
@@ -82,7 +87,8 @@ class AugmentorList(ImageAugmentor):
...
@@ -82,7 +87,8 @@ class AugmentorList(ImageAugmentor):
def
__init__
(
self
,
augmentors
):
def
__init__
(
self
,
augmentors
):
"""
"""
:param augmentors: list of `ImageAugmentor` instance to be applied
Args:
augmentors (list): list of :class:`ImageAugmentor` instance to be applied.
"""
"""
self
.
augs
=
augmentors
self
.
augs
=
augmentors
super
(
AugmentorList
,
self
)
.
__init__
()
super
(
AugmentorList
,
self
)
.
__init__
()
...
...
tensorpack/dataflow/imgaug/crop.py
View file @
d8935ef3
...
@@ -10,7 +10,7 @@ from six.moves import range
...
@@ -10,7 +10,7 @@ from six.moves import range
import
numpy
as
np
import
numpy
as
np
__all__
=
[
'RandomCrop'
,
'CenterCrop'
,
'FixedCrop'
,
__all__
=
[
'RandomCrop'
,
'CenterCrop'
,
'FixedCrop'
,
'
RandomCropRandomShape'
,
'perturb_BB'
,
'RandomCropAroundBox
'
]
'
perturb_BB'
,
'RandomCropAroundBox'
,
'RandomCropRandomShape
'
]
class
RandomCrop
(
ImageAugmentor
):
class
RandomCrop
(
ImageAugmentor
):
...
@@ -18,7 +18,8 @@ class RandomCrop(ImageAugmentor):
...
@@ -18,7 +18,8 @@ class RandomCrop(ImageAugmentor):
def
__init__
(
self
,
crop_shape
):
def
__init__
(
self
,
crop_shape
):
"""
"""
:param crop_shape: a shape like (h, w)
Args:
crop_shape: (h, w) tuple or a int
"""
"""
crop_shape
=
shape2d
(
crop_shape
)
crop_shape
=
shape2d
(
crop_shape
)
super
(
RandomCrop
,
self
)
.
__init__
()
super
(
RandomCrop
,
self
)
.
__init__
()
...
@@ -47,7 +48,8 @@ class CenterCrop(ImageAugmentor):
...
@@ -47,7 +48,8 @@ class CenterCrop(ImageAugmentor):
def
__init__
(
self
,
crop_shape
):
def
__init__
(
self
,
crop_shape
):
"""
"""
:param crop_shape: a shape like (h, w)
Args:
crop_shape: (h, w) tuple or a int
"""
"""
crop_shape
=
shape2d
(
crop_shape
)
crop_shape
=
shape2d
(
crop_shape
)
self
.
_init
(
locals
())
self
.
_init
(
locals
())
...
@@ -67,9 +69,8 @@ class FixedCrop(ImageAugmentor):
...
@@ -67,9 +69,8 @@ class FixedCrop(ImageAugmentor):
def
__init__
(
self
,
rect
):
def
__init__
(
self
,
rect
):
"""
"""
Two arguments defined the range in both axes to crop, min inclued, max excluded.
Args:
rect(Rect): min included, max excluded.
:param rect: a `Rect` instance
"""
"""
self
.
_init
(
locals
())
self
.
_init
(
locals
())
...
@@ -86,12 +87,15 @@ def perturb_BB(image_shape, bb, max_pertub_pixel,
...
@@ -86,12 +87,15 @@ def perturb_BB(image_shape, bb, max_pertub_pixel,
max_try
=
100
):
max_try
=
100
):
"""
"""
Perturb a bounding box.
Perturb a bounding box.
:param image_shape: [h, w]
:param bb: a `Rect` instance
Args:
:param max_pertub_pixel: pertubation on each coordinate
image_shape: [h, w]
:param max_aspect_ratio_diff: result can't have an aspect ratio too different from the original
bb (Rect): original bounding box
:param max_try: if cannot find a valid bounding box, return the original
max_pertub_pixel: pertubation on each coordinate
:returns: new bounding box
max_aspect_ratio_diff: result can't have an aspect ratio too different from the original
max_try: if cannot find a valid bounding box, return the original
Returns:
new bounding box
"""
"""
orig_ratio
=
bb
.
h
*
1.0
/
bb
.
w
orig_ratio
=
bb
.
h
*
1.0
/
bb
.
w
if
rng
is
None
:
if
rng
is
None
:
...
@@ -117,13 +121,15 @@ def perturb_BB(image_shape, bb, max_pertub_pixel,
...
@@ -117,13 +121,15 @@ def perturb_BB(image_shape, bb, max_pertub_pixel,
class
RandomCropAroundBox
(
ImageAugmentor
):
class
RandomCropAroundBox
(
ImageAugmentor
):
"""
"""
Crop a box around a bounding box
Crop a box around a bounding box
by some random pertubation
"""
"""
def
__init__
(
self
,
perturb_ratio
,
max_aspect_ratio_diff
=
0.3
):
def
__init__
(
self
,
perturb_ratio
,
max_aspect_ratio_diff
=
0.3
):
"""
"""
:param perturb_ratio: perturb distance will be in [0, perturb_ratio * sqrt(w * h)]
Args:
:param max_aspect_ratio_diff: keep aspect ratio within the range
perturb_ratio (float): perturb distance will be in
``[0, perturb_ratio * sqrt(w * h)]``
max_aspect_ratio_diff (float): keep aspect ratio difference within the range
"""
"""
super
(
RandomCropAroundBox
,
self
)
.
__init__
()
super
(
RandomCropAroundBox
,
self
)
.
__init__
()
self
.
_init
(
locals
())
self
.
_init
(
locals
())
...
@@ -144,14 +150,18 @@ class RandomCropAroundBox(ImageAugmentor):
...
@@ -144,14 +150,18 @@ class RandomCropAroundBox(ImageAugmentor):
class
RandomCropRandomShape
(
ImageAugmentor
):
class
RandomCropRandomShape
(
ImageAugmentor
):
""" Random crop with a random shape"""
def
__init__
(
self
,
wmin
,
hmin
,
def
__init__
(
self
,
wmin
,
hmin
,
wmax
=
None
,
hmax
=
None
,
wmax
=
None
,
hmax
=
None
,
max_aspect_ratio
=
None
):
max_aspect_ratio
=
None
):
"""
"""
Randomly crop a box of shape (h, w), sampled from [min, max]
(
inclusive).
Randomly crop a box of shape (h, w), sampled from [min, max]
(both
inclusive).
If max is None, will use the input image shape.
If max is None, will use the input image shape.
max_aspect_ratio is the upper bound of max(w,h)/min(w,h)
Args:
wmin, hmin, wmax, hmax: range to sample shape.
max_aspect_ratio (float): the upper bound of ``max(w,h)/min(w,h)``.
"""
"""
if
max_aspect_ratio
is
None
:
if
max_aspect_ratio
is
None
:
max_aspect_ratio
=
9999999
max_aspect_ratio
=
9999999
...
...
tensorpack/dataflow/imgaug/deform.py
View file @
d8935ef3
...
@@ -6,13 +6,12 @@ from .base import ImageAugmentor
...
@@ -6,13 +6,12 @@ from .base import ImageAugmentor
from
...utils
import
logger
from
...utils
import
logger
import
numpy
as
np
import
numpy
as
np
__all__
=
[
'GaussianDeform'
,
'GaussianMap'
]
__all__
=
[
'GaussianDeform'
]
# TODO really needs speedup
class
GaussianMap
(
object
):
class
GaussianMap
(
object
):
""" Generate gaussian weighted deformation map"""
""" Generate gaussian weighted deformation map"""
# TODO really needs speedup
def
__init__
(
self
,
image_shape
,
sigma
=
0.5
):
def
__init__
(
self
,
image_shape
,
sigma
=
0.5
):
assert
len
(
image_shape
)
==
2
assert
len
(
image_shape
)
==
2
...
@@ -20,6 +19,10 @@ class GaussianMap(object):
...
@@ -20,6 +19,10 @@ class GaussianMap(object):
self
.
sigma
=
sigma
self
.
sigma
=
sigma
def
get_gaussian_weight
(
self
,
anchor
):
def
get_gaussian_weight
(
self
,
anchor
):
"""
Args:
anchor: coordinate of the center
"""
ret
=
np
.
zeros
(
self
.
shape
,
dtype
=
'float32'
)
ret
=
np
.
zeros
(
self
.
shape
,
dtype
=
'float32'
)
y
,
x
=
np
.
mgrid
[:
self
.
shape
[
0
],
:
self
.
shape
[
1
]]
y
,
x
=
np
.
mgrid
[:
self
.
shape
[
0
],
:
self
.
shape
[
1
]]
...
@@ -55,20 +58,20 @@ def np_sample(img, coords):
...
@@ -55,20 +58,20 @@ def np_sample(img, coords):
img
[
ucoory
,
lcoorx
,
:]
*
diffy
*
ndiffx
img
[
ucoory
,
lcoorx
,
:]
*
diffy
*
ndiffx
return
ret
[:,
:,
0
,
:]
return
ret
[:,
:,
0
,
:]
# TODO input/output with different shape
class
GaussianDeform
(
ImageAugmentor
):
class
GaussianDeform
(
ImageAugmentor
):
"""
"""
Some kind of
deformation. Quite slow
.
Some kind of
slow deformation
.
"""
"""
# TODO input/output with different shape
def
__init__
(
self
,
anchors
,
shape
,
sigma
=
0.5
,
randrange
=
None
):
def
__init__
(
self
,
anchors
,
shape
,
sigma
=
0.5
,
randrange
=
None
):
"""
"""
:param anchors: in [0,1] coordinate
Args:
:param shape: image shape in [h, w]
anchors (list): list of center coordinates in range [0,1].
:param sigma: sigma for Gaussian weight
shape(list or tuple): image shape in [h, w].
:param randrange: default to shape[0] / 8
sigma (float): sigma for Gaussian weight
randrange (int): offset range. Defaults to shape[0] / 8
"""
"""
logger
.
warn
(
"GaussianDeform is slow. Consider using it with 4 or more prefetching processes."
)
logger
.
warn
(
"GaussianDeform is slow. Consider using it with 4 or more prefetching processes."
)
super
(
GaussianDeform
,
self
)
.
__init__
()
super
(
GaussianDeform
,
self
)
.
__init__
()
...
...
tensorpack/dataflow/imgaug/geometry.py
View file @
d8935ef3
...
@@ -17,8 +17,11 @@ class Rotation(ImageAugmentor):
...
@@ -17,8 +17,11 @@ class Rotation(ImageAugmentor):
interp
=
cv2
.
INTER_CUBIC
,
interp
=
cv2
.
INTER_CUBIC
,
border
=
cv2
.
BORDER_REPLICATE
):
border
=
cv2
.
BORDER_REPLICATE
):
"""
"""
:param max_deg: max abs value of the rotation degree
Args:
:param center_range: the location of the rotation center
max_deg (float): max abs value of the rotation degree (in angle).
center_range (tuple): (min, max) range of the random rotation center.
interp: cv2 interpolation method
border: cv2 border method
"""
"""
super
(
Rotation
,
self
)
.
__init__
()
super
(
Rotation
,
self
)
.
__init__
()
self
.
_init
(
locals
())
self
.
_init
(
locals
())
...
@@ -36,11 +39,16 @@ class Rotation(ImageAugmentor):
...
@@ -36,11 +39,16 @@ class Rotation(ImageAugmentor):
class
RotationAndCropValid
(
ImageAugmentor
):
class
RotationAndCropValid
(
ImageAugmentor
):
""" Random rotate and
crop the largest possible rect without the border
""" Random rotate and
then crop the largest possible rectangle.
T
his will produce images of different shapes.
Note that t
his will produce images of different shapes.
"""
"""
def
__init__
(
self
,
max_deg
,
interp
=
cv2
.
INTER_CUBIC
):
def
__init__
(
self
,
max_deg
,
interp
=
cv2
.
INTER_CUBIC
):
"""
Args:
max_deg (float): max abs value of the rotation degree (in angle).
interp: cv2 interpolation method
"""
super
(
RotationAndCropValid
,
self
)
.
__init__
()
super
(
RotationAndCropValid
,
self
)
.
__init__
()
self
.
_init
(
locals
())
self
.
_init
(
locals
())
...
@@ -63,7 +71,10 @@ class RotationAndCropValid(ImageAugmentor):
...
@@ -63,7 +71,10 @@ class RotationAndCropValid(ImageAugmentor):
@
staticmethod
@
staticmethod
def
largest_rotated_rect
(
w
,
h
,
angle
):
def
largest_rotated_rect
(
w
,
h
,
angle
):
""" http://stackoverflow.com/questions/16702966/rotate-image-and-crop-out-black-borders """
"""
Get largest rectangle after rotation.
http://stackoverflow.com/questions/16702966/rotate-image-and-crop-out-black-borders
"""
angle
=
angle
/
180.0
*
math
.
pi
angle
=
angle
/
180.0
*
math
.
pi
if
w
<=
0
or
h
<=
0
:
if
w
<=
0
or
h
<=
0
:
return
0
,
0
return
0
,
0
...
...
tensorpack/dataflow/imgaug/imgproc.py
View file @
d8935ef3
...
@@ -12,9 +12,8 @@ __all__ = ['Brightness', 'Contrast', 'MeanVarianceNormalize', 'GaussianBlur',
...
@@ -12,9 +12,8 @@ __all__ = ['Brightness', 'Contrast', 'MeanVarianceNormalize', 'GaussianBlur',
class
Brightness
(
ImageAugmentor
):
class
Brightness
(
ImageAugmentor
):
"""
"""
Random adjust brightness.
Random
ly
adjust brightness.
"""
"""
def
__init__
(
self
,
delta
,
clip
=
True
):
def
__init__
(
self
,
delta
,
clip
=
True
):
"""
"""
Randomly add a value within [-delta,delta], and clip in [0,255] if clip is True.
Randomly add a value within [-delta,delta], and clip in [0,255] if clip is True.
...
@@ -36,14 +35,14 @@ class Brightness(ImageAugmentor):
...
@@ -36,14 +35,14 @@ class Brightness(ImageAugmentor):
class
Contrast
(
ImageAugmentor
):
class
Contrast
(
ImageAugmentor
):
"""
"""
Apply x = (x - mean) * contrast_factor + mean to each channel
Apply ``x = (x - mean) * contrast_factor + mean`` to each channel.
and clip to [0, 255]
"""
"""
def
__init__
(
self
,
factor_range
,
clip
=
True
):
def
__init__
(
self
,
factor_range
,
clip
=
True
):
"""
"""
:param factor_range: an interval to random sample the `contrast_factor`.
Args:
:param clip: boolean.
factor_range (list or tuple): an interval to randomly sample the `contrast_factor`.
clip (bool): clip to [0, 255] if True.
"""
"""
super
(
Contrast
,
self
)
.
__init__
()
super
(
Contrast
,
self
)
.
__init__
()
self
.
_init
(
locals
())
self
.
_init
(
locals
())
...
@@ -61,14 +60,15 @@ class Contrast(ImageAugmentor):
...
@@ -61,14 +60,15 @@ class Contrast(ImageAugmentor):
class
MeanVarianceNormalize
(
ImageAugmentor
):
class
MeanVarianceNormalize
(
ImageAugmentor
):
"""
"""
Linearly scales image to have zero mean and unit norm.
Linearly scales
the
image to have zero mean and unit norm.
x = (x - mean) / adjusted_stddev
``x = (x - mean) / adjusted_stddev``
where
adjusted_stddev = max(stddev, 1.0/sqrt(num_pixels * channels))
where
``adjusted_stddev = max(stddev, 1.0/sqrt(num_pixels * channels))``
"""
"""
def
__init__
(
self
,
all_channel
=
True
):
def
__init__
(
self
,
all_channel
=
True
):
"""
"""
:param all_channel: if True, normalize all channels together. else separately.
Args:
all_channel (bool): if True, normalize all channels together. else separately.
"""
"""
self
.
all_channel
=
all_channel
self
.
all_channel
=
all_channel
...
@@ -85,9 +85,13 @@ class MeanVarianceNormalize(ImageAugmentor):
...
@@ -85,9 +85,13 @@ class MeanVarianceNormalize(ImageAugmentor):
class
GaussianBlur
(
ImageAugmentor
):
class
GaussianBlur
(
ImageAugmentor
):
""" Gaussian blur the image with random window size"""
def
__init__
(
self
,
max_size
=
3
):
def
__init__
(
self
,
max_size
=
3
):
""":params max_size: (maximum kernel size-1)/2"""
"""
Args:
max_size (int): max possible Gaussian window size would be 2 * max_size + 1
"""
super
(
GaussianBlur
,
self
)
.
__init__
()
super
(
GaussianBlur
,
self
)
.
__init__
()
self
.
_init
(
locals
())
self
.
_init
(
locals
())
...
@@ -103,8 +107,12 @@ class GaussianBlur(ImageAugmentor):
...
@@ -103,8 +107,12 @@ class GaussianBlur(ImageAugmentor):
class
Gamma
(
ImageAugmentor
):
class
Gamma
(
ImageAugmentor
):
""" Randomly adjust gamma """
def
__init__
(
self
,
range
=
(
-
0.5
,
0.5
)):
def
__init__
(
self
,
range
=
(
-
0.5
,
0.5
)):
"""
Args:
range(list or tuple): gamma range
"""
super
(
Gamma
,
self
)
.
__init__
()
super
(
Gamma
,
self
)
.
__init__
()
self
.
_init
(
locals
())
self
.
_init
(
locals
())
...
@@ -119,8 +127,13 @@ class Gamma(ImageAugmentor):
...
@@ -119,8 +127,13 @@ class Gamma(ImageAugmentor):
class
Clip
(
ImageAugmentor
):
class
Clip
(
ImageAugmentor
):
""" Clip the pixel values """
def
__init__
(
self
,
min
=
0
,
max
=
255
):
def
__init__
(
self
,
min
=
0
,
max
=
255
):
"""
Args:
min, max: the clip range
"""
self
.
_init
(
locals
())
self
.
_init
(
locals
())
def
_augment
(
self
,
img
,
_
):
def
_augment
(
self
,
img
,
_
):
...
@@ -129,10 +142,15 @@ class Clip(ImageAugmentor):
...
@@ -129,10 +142,15 @@ class Clip(ImageAugmentor):
class
Saturation
(
ImageAugmentor
):
class
Saturation
(
ImageAugmentor
):
""" Randomly adjust saturation.
Follows the implementation in `fb.resnet.torch
<https://github.com/facebook/fb.resnet.torch/blob/master/datasets/transforms.lua#L218>`_
"""
def
__init__
(
self
,
alpha
=
0.4
):
def
__init__
(
self
,
alpha
=
0.4
):
""" Saturation,
"""
see 'fb.resnet.torch' https://github.com/facebook/fb.resnet.torch/blob/master/datasets/transforms.lua#L218
Args:
alpha(float): maximum saturation change.
"""
"""
super
(
Saturation
,
self
)
.
__init__
()
super
(
Saturation
,
self
)
.
__init__
()
assert
alpha
<
1
assert
alpha
<
1
...
@@ -147,14 +165,19 @@ class Saturation(ImageAugmentor):
...
@@ -147,14 +165,19 @@ class Saturation(ImageAugmentor):
class
Lighting
(
ImageAugmentor
):
class
Lighting
(
ImageAugmentor
):
""" Lighting noise, as in the paper
`ImageNet Classification with Deep Convolutional Neural Networks
<https://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf>`_.
The implementation follows `fb.resnet.torch
<https://github.com/facebook/fb.resnet.torch/blob/master/datasets/transforms.lua#L184>`_.
"""
def
__init__
(
self
,
std
,
eigval
,
eigvec
):
def
__init__
(
self
,
std
,
eigval
,
eigvec
):
""" Lighting noise.
"""
See `ImageNet Classification with Deep Convolutional Neural Networks - Alex`
Args:
The implementation follows 'fb.resnet.torch':
std (float): maximum standard deviation
https://github.com/facebook/fb.resnet.torch/blob/master/datasets/transforms.lua#L184
eigval: a vector of (3,). The eigenvalues of 3 channels.
eigvec: a 3x3 matrix. Each column is one eigen vector.
:param eigvec: each column is one eigen vector
"""
"""
eigval
=
np
.
asarray
(
eigval
)
eigval
=
np
.
asarray
(
eigval
)
eigvec
=
np
.
asarray
(
eigvec
)
eigvec
=
np
.
asarray
(
eigvec
)
...
...
tensorpack/dataflow/imgaug/meta.py
View file @
d8935ef3
...
@@ -11,15 +11,22 @@ __all__ = ['RandomChooseAug', 'MapImage', 'Identity', 'RandomApplyAug',
...
@@ -11,15 +11,22 @@ __all__ = ['RandomChooseAug', 'MapImage', 'Identity', 'RandomApplyAug',
class
Identity
(
ImageAugmentor
):
class
Identity
(
ImageAugmentor
):
""" A no-op augmentor """
def
_augment
(
self
,
img
,
_
):
def
_augment
(
self
,
img
,
_
):
return
img
return
img
class
RandomApplyAug
(
ImageAugmentor
):
class
RandomApplyAug
(
ImageAugmentor
):
""" Randomly apply the augmentor with a prob. Otherwise do nothing"""
""" Randomly apply the augmentor with a probability.
Otherwise do nothing
"""
def
__init__
(
self
,
aug
,
prob
):
def
__init__
(
self
,
aug
,
prob
):
"""
Args:
aug (ImageAugmentor): an augmentor
prob (float): the probability
"""
self
.
_init
(
locals
())
self
.
_init
(
locals
())
super
(
RandomApplyAug
,
self
)
.
__init__
()
super
(
RandomApplyAug
,
self
)
.
__init__
()
...
@@ -43,10 +50,11 @@ class RandomApplyAug(ImageAugmentor):
...
@@ -43,10 +50,11 @@ class RandomApplyAug(ImageAugmentor):
class
RandomChooseAug
(
ImageAugmentor
):
class
RandomChooseAug
(
ImageAugmentor
):
""" Randomly choose one from a list of augmentors """
def
__init__
(
self
,
aug_lists
):
def
__init__
(
self
,
aug_lists
):
"""
"""
:param aug_lists: list of augmentor, or list of (augmentor, probability) tuple
Args:
aug_lists (list): list of augmentors, or list of (augmentor, probability) tuples
"""
"""
if
isinstance
(
aug_lists
[
0
],
(
tuple
,
list
)):
if
isinstance
(
aug_lists
[
0
],
(
tuple
,
list
)):
prob
=
[
k
[
1
]
for
k
in
aug_lists
]
prob
=
[
k
[
1
]
for
k
in
aug_lists
]
...
@@ -73,11 +81,15 @@ class RandomChooseAug(ImageAugmentor):
...
@@ -73,11 +81,15 @@ class RandomChooseAug(ImageAugmentor):
class
RandomOrderAug
(
ImageAugmentor
):
class
RandomOrderAug
(
ImageAugmentor
):
"""
Apply the augmentors with randomized order.
"""
def
__init__
(
self
,
aug_lists
):
def
__init__
(
self
,
aug_lists
):
"""
"""
Shuffle the augmentors into random order.
Args:
:param aug_lists: list of augmentor, or list of (augmentor, probability) tuple
aug_lists (list): list of augmentors.
The augmentors are assumed to not change the shape of images.
"""
"""
self
.
_init
(
locals
())
self
.
_init
(
locals
())
super
(
RandomOrderAug
,
self
)
.
__init__
()
super
(
RandomOrderAug
,
self
)
.
__init__
()
...
@@ -109,7 +121,8 @@ class MapImage(ImageAugmentor):
...
@@ -109,7 +121,8 @@ class MapImage(ImageAugmentor):
def
__init__
(
self
,
func
):
def
__init__
(
self
,
func
):
"""
"""
:param func: a function which takes a image array and return a augmented one
Args:
func: a function which takes an image array and return an augmented one
"""
"""
self
.
func
=
func
self
.
func
=
func
...
...
tensorpack/dataflow/imgaug/noise.py
View file @
d8935ef3
...
@@ -11,8 +11,13 @@ __all__ = ['JpegNoise', 'GaussianNoise', 'SaltPepperNoise']
...
@@ -11,8 +11,13 @@ __all__ = ['JpegNoise', 'GaussianNoise', 'SaltPepperNoise']
class
JpegNoise
(
ImageAugmentor
):
class
JpegNoise
(
ImageAugmentor
):
""" Random Jpeg noise. """
def
__init__
(
self
,
quality_range
=
(
40
,
100
)):
def
__init__
(
self
,
quality_range
=
(
40
,
100
)):
"""
Args:
quality_range (tuple): range to sample Jpeg quality
"""
super
(
JpegNoise
,
self
)
.
__init__
()
super
(
JpegNoise
,
self
)
.
__init__
()
self
.
_init
(
locals
())
self
.
_init
(
locals
())
...
@@ -25,10 +30,14 @@ class JpegNoise(ImageAugmentor):
...
@@ -25,10 +30,14 @@ class JpegNoise(ImageAugmentor):
class
GaussianNoise
(
ImageAugmentor
):
class
GaussianNoise
(
ImageAugmentor
):
"""
Add random gaussian noise N(0, sigma^2) of the same shape to img.
"""
def
__init__
(
self
,
sigma
=
1
,
clip
=
True
):
def
__init__
(
self
,
sigma
=
1
,
clip
=
True
):
"""
"""
Add a gaussian noise N(0, sigma^2) of the same shape to img.
Args:
sigma (float): stddev of the Gaussian distribution.
clip (bool): clip the result to [0,255] in the end.
"""
"""
super
(
GaussianNoise
,
self
)
.
__init__
()
super
(
GaussianNoise
,
self
)
.
__init__
()
self
.
_init
(
locals
())
self
.
_init
(
locals
())
...
@@ -44,10 +53,14 @@ class GaussianNoise(ImageAugmentor):
...
@@ -44,10 +53,14 @@ class GaussianNoise(ImageAugmentor):
class
SaltPepperNoise
(
ImageAugmentor
):
class
SaltPepperNoise
(
ImageAugmentor
):
""" Salt and pepper noise.
Randomly set some elements in img to 0 or 255, regardless of its channels.
"""
def
__init__
(
self
,
white_prob
=
0.05
,
black_prob
=
0.05
):
def
__init__
(
self
,
white_prob
=
0.05
,
black_prob
=
0.05
):
""" Salt and pepper noise.
"""
Randomly set some elements in img to 0 or 255, regardless of its channels.
Args:
white_prob (float), black_prob (float): probabilities setting an element to 255 or 0.
"""
"""
assert
white_prob
+
black_prob
<=
1
,
"Sum of probabilities cannot be greater than 1"
assert
white_prob
+
black_prob
<=
1
,
"Sum of probabilities cannot be greater than 1"
super
(
SaltPepperNoise
,
self
)
.
__init__
()
super
(
SaltPepperNoise
,
self
)
.
__init__
()
...
...
tensorpack/dataflow/imgaug/noname.py
View file @
d8935ef3
...
@@ -13,20 +13,18 @@ __all__ = ['Flip', 'Resize', 'RandomResize', 'ResizeShortestEdge']
...
@@ -13,20 +13,18 @@ __all__ = ['Flip', 'Resize', 'RandomResize', 'ResizeShortestEdge']
class
Flip
(
ImageAugmentor
):
class
Flip
(
ImageAugmentor
):
"""
"""
Random flip.
Random flip
the image either horizontally or vertically
.
"""
"""
def
__init__
(
self
,
horiz
=
False
,
vert
=
False
,
prob
=
0.5
):
def
__init__
(
self
,
horiz
=
False
,
vert
=
False
,
prob
=
0.5
):
"""
"""
Only one of horiz, vert can be set.
Args:
horiz (bool): use horizontal flip.
:param horiz: whether or not apply horizontal flip.
vert (bool): use vertical flip.
:param vert: whether or not apply vertical flip.
prob (float): probability of flip.
:param prob: probability of flip.
"""
"""
super
(
Flip
,
self
)
.
__init__
()
super
(
Flip
,
self
)
.
__init__
()
if
horiz
and
vert
:
if
horiz
and
vert
:
raise
ValueError
(
"Please use two Flip instead."
)
raise
ValueError
(
"
Cannot do both horiz and vert.
Please use two Flip instead."
)
elif
horiz
:
elif
horiz
:
self
.
code
=
1
self
.
code
=
1
elif
vert
:
elif
vert
:
...
@@ -53,7 +51,9 @@ class Resize(ImageAugmentor):
...
@@ -53,7 +51,9 @@ class Resize(ImageAugmentor):
def
__init__
(
self
,
shape
,
interp
=
cv2
.
INTER_CUBIC
):
def
__init__
(
self
,
shape
,
interp
=
cv2
.
INTER_CUBIC
):
"""
"""
:param shape: shape in (h, w)
Args:
shape: (h, w) tuple or a int
interp: cv2 interpolation method
"""
"""
shape
=
tuple
(
shape2d
(
shape
))
shape
=
tuple
(
shape2d
(
shape
))
self
.
_init
(
locals
())
self
.
_init
(
locals
())
...
@@ -68,11 +68,16 @@ class Resize(ImageAugmentor):
...
@@ -68,11 +68,16 @@ class Resize(ImageAugmentor):
class
ResizeShortestEdge
(
ImageAugmentor
):
class
ResizeShortestEdge
(
ImageAugmentor
):
""" Resize the shortest edge to a certain number while
"""
keeping the aspect ratio
Resize the shortest edge to a certain number while
keeping the aspect ratio.
"""
"""
def
__init__
(
self
,
size
):
def
__init__
(
self
,
size
):
"""
Args:
size (int): the size to resize the shortest edge to.
"""
size
=
size
*
1.0
size
=
size
*
1.0
self
.
_init
(
locals
())
self
.
_init
(
locals
())
...
@@ -87,15 +92,18 @@ class ResizeShortestEdge(ImageAugmentor):
...
@@ -87,15 +92,18 @@ class ResizeShortestEdge(ImageAugmentor):
class
RandomResize
(
ImageAugmentor
):
class
RandomResize
(
ImageAugmentor
):
"""
r
andomly rescale w and h of the image"""
"""
R
andomly rescale w and h of the image"""
def
__init__
(
self
,
xrange
,
yrange
,
minimum
=
(
0
,
0
),
aspect_ratio_thres
=
0.15
,
def
__init__
(
self
,
xrange
,
yrange
,
minimum
=
(
0
,
0
),
aspect_ratio_thres
=
0.15
,
interp
=
cv2
.
INTER_CUBIC
):
interp
=
cv2
.
INTER_CUBIC
):
"""
"""
:param xrange: (min, max) scaling ratio
Args:
:param yrange: (min, max) scaling ratio
xrange (tuple): (min, max) range of scaling ratio for w
:param minimum: (xmin, ymin). Avoid scaling down too much.
yrange (tuple): (min, max) range of scaling ratio for h
:param aspect_ratio_thres: at most change k=20
%
aspect ratio
minimum (tuple): (xmin, ymin). avoid scaling down too much.
aspect_ratio_thres (float): discard samples which change aspect ratio
larger than this threshold.
interp: cv2 interpolation method
"""
"""
super
(
RandomResize
,
self
)
.
__init__
()
super
(
RandomResize
,
self
)
.
__init__
()
self
.
_init
(
locals
())
self
.
_init
(
locals
())
...
...
tensorpack/dataflow/imgaug/paste.py
View file @
d8935ef3
...
@@ -19,9 +19,11 @@ class BackgroundFiller(object):
...
@@ -19,9 +19,11 @@ class BackgroundFiller(object):
"""
"""
Return a proper background image of background_shape, given img
Return a proper background image of background_shape, given img
:param background_shape: a shape of [h, w]
Args:
:param img: an image
background_shape: a shape of [h, w]
:returns: a background image
img: an image
Returns:
a background image
"""
"""
return
self
.
_fill
(
background_shape
,
img
)
return
self
.
_fill
(
background_shape
,
img
)
...
@@ -35,7 +37,8 @@ class ConstantBackgroundFiller(BackgroundFiller):
...
@@ -35,7 +37,8 @@ class ConstantBackgroundFiller(BackgroundFiller):
def
__init__
(
self
,
value
):
def
__init__
(
self
,
value
):
"""
"""
:param value: the value to fill the background.
Args:
value (float): the value to fill the background.
"""
"""
self
.
value
=
value
self
.
value
=
value
...
@@ -55,8 +58,9 @@ class CenterPaste(ImageAugmentor):
...
@@ -55,8 +58,9 @@ class CenterPaste(ImageAugmentor):
def
__init__
(
self
,
background_shape
,
background_filler
=
None
):
def
__init__
(
self
,
background_shape
,
background_filler
=
None
):
"""
"""
:param background_shape: shape of the background canvas.
Args:
:param background_filler: a `BackgroundFiller` instance. Default to zero-filler.
background_shape (tuple): shape of the background canvas.
background_filler (BackgroundFiller): How to fill the background. Defaults to zero-filler.
"""
"""
if
background_filler
is
None
:
if
background_filler
is
None
:
background_filler
=
ConstantBackgroundFiller
(
0
)
background_filler
=
ConstantBackgroundFiller
(
0
)
...
...
tensorpack/dataflow/prefetch.py
View file @
d8935ef3
...
@@ -48,15 +48,19 @@ class PrefetchProcess(mp.Process):
...
@@ -48,15 +48,19 @@ class PrefetchProcess(mp.Process):
class
PrefetchData
(
ProxyDataFlow
):
class
PrefetchData
(
ProxyDataFlow
):
"""
"""
Prefetch data from a `DataFlow` using multiprocessing
Prefetch data from a DataFlow using Python multiprocessing utilities.
Note:
This is significantly slower than :class:`PrefetchDataZMQ` when data
is large.
"""
"""
def
__init__
(
self
,
ds
,
nr_prefetch
,
nr_proc
=
1
):
def
__init__
(
self
,
ds
,
nr_prefetch
,
nr_proc
=
1
):
"""
"""
:param ds: a `DataFlow` instance.
Args:
:param nr_prefetch: size of the queue to hold prefetched datapoints
.
ds (DataFlow): input DataFlow
.
:param nr_proc: number of processes to use. When larger than 1, order
nr_prefetch (int): size of the queue to hold prefetched datapoints.
of data points will be random
.
nr_proc (int): number of processes to use
.
"""
"""
super
(
PrefetchData
,
self
)
.
__init__
(
ds
)
super
(
PrefetchData
,
self
)
.
__init__
(
ds
)
try
:
try
:
...
@@ -85,9 +89,8 @@ class PrefetchData(ProxyDataFlow):
...
@@ -85,9 +89,8 @@ class PrefetchData(ProxyDataFlow):
def
BlockParallel
(
ds
,
queue_size
):
def
BlockParallel
(
ds
,
queue_size
):
# TODO more doc
"""
"""
Insert `
BlockParallel` in dataflow pipeline to block parallelism on ds
Insert `
`BlockParallel`` in dataflow pipeline to block parallelism on ds.
:param ds: a `DataFlow`
:param ds: a `DataFlow`
:param queue_size: size of the queue used
:param queue_size: size of the queue used
...
@@ -96,7 +99,6 @@ def BlockParallel(ds, queue_size):
...
@@ -96,7 +99,6 @@ def BlockParallel(ds, queue_size):
class
PrefetchProcessZMQ
(
mp
.
Process
):
class
PrefetchProcessZMQ
(
mp
.
Process
):
def
__init__
(
self
,
ds
,
conn_name
):
def
__init__
(
self
,
ds
,
conn_name
):
"""
"""
:param ds: a `DataFlow` instance.
:param ds: a `DataFlow` instance.
...
@@ -118,15 +120,17 @@ class PrefetchProcessZMQ(mp.Process):
...
@@ -118,15 +120,17 @@ class PrefetchProcessZMQ(mp.Process):
class
PrefetchDataZMQ
(
ProxyDataFlow
):
class
PrefetchDataZMQ
(
ProxyDataFlow
):
""" Work the same as `PrefetchData`, but faster. """
"""
Prefetch data from a DataFlow using multiple processes, with ZMQ for
communication.
"""
def
__init__
(
self
,
ds
,
nr_proc
=
1
,
pipedir
=
None
):
def
__init__
(
self
,
ds
,
nr_proc
=
1
,
pipedir
=
None
):
"""
"""
:param ds: a `DataFlow` instance.
Args:
:param nr_proc: number of processes to use. When larger than 1, order
ds (DataFlow): input DataFlow.
of datapoints will be random
.
nr_proc (int): number of processes to use
.
:param pipedir: a local directory where the pipes would be
.
pipedir (str): a local directory where the pipes should be put
.
Useful if you're running on non-local FS such as N
FS.
Useful if you're running on non-local FS such as NFS or Gluster
FS.
"""
"""
super
(
PrefetchDataZMQ
,
self
)
.
__init__
(
ds
)
super
(
PrefetchDataZMQ
,
self
)
.
__init__
(
ds
)
try
:
try
:
...
@@ -185,10 +189,20 @@ class PrefetchDataZMQ(ProxyDataFlow):
...
@@ -185,10 +189,20 @@ class PrefetchDataZMQ(ProxyDataFlow):
class
PrefetchOnGPUs
(
PrefetchDataZMQ
):
class
PrefetchOnGPUs
(
PrefetchDataZMQ
):
""" Prefetch with each process having a specific CUDA_VISIBLE_DEVICES
"""
variable"""
Prefetch with each process having its own ``CUDA_VISIBLE_DEVICES`` variable
mapped to one GPU.
"""
def
__init__
(
self
,
ds
,
gpus
,
pipedir
=
None
):
def
__init__
(
self
,
ds
,
gpus
,
pipedir
=
None
):
"""
Args:
ds (DataFlow): input DataFlow.
gpus (list[int]): list of GPUs to use. Will also start this many
of processes.
pipedir (str): a local directory where the pipes should be put.
Useful if you're running on non-local FS such as NFS or GlusterFS.
"""
self
.
gpus
=
gpus
self
.
gpus
=
gpus
super
(
PrefetchOnGPUs
,
self
)
.
__init__
(
ds
,
len
(
gpus
),
pipedir
)
super
(
PrefetchOnGPUs
,
self
)
.
__init__
(
ds
,
len
(
gpus
),
pipedir
)
...
...
tensorpack/dataflow/raw.py
View file @
d8935ef3
...
@@ -7,26 +7,21 @@ import numpy as np
...
@@ -7,26 +7,21 @@ import numpy as np
import
copy
import
copy
from
six.moves
import
range
from
six.moves
import
range
from
.base
import
DataFlow
,
RNGDataFlow
from
.base
import
DataFlow
,
RNGDataFlow
from
..utils.serialize
import
loads
__all__
=
[
'FakeData'
,
'DataFromQueue'
,
'DataFromList'
]
__all__
=
[
'FakeData'
,
'DataFromQueue'
,
'DataFromList'
]
try
:
import
zmq
except
:
pass
else
:
__all__
.
append
(
'DataFromSocket'
)
class
FakeData
(
RNGDataFlow
):
class
FakeData
(
RNGDataFlow
):
""" Generate fake
fixed
data of given shapes"""
""" Generate fake data of given shapes"""
def
__init__
(
self
,
shapes
,
size
,
random
=
True
,
dtype
=
'float32'
):
def
__init__
(
self
,
shapes
,
size
=
1000
,
random
=
True
,
dtype
=
'float32'
):
"""
"""
:param shapes: a list of lists/tuples
Args:
:param size: size of this DataFlow
shapes (list): a list of lists/tuples. Shapes of each component.
:param random: whether to randomly generate data every iteration. note
size (int): size of this DataFlow.
that only generating the data could be time-consuming!
random (bool): whether to randomly generate data every iteration.
Note that merely generating the data could sometimes be time-consuming!
dtype (str): data type.
"""
"""
super
(
FakeData
,
self
)
.
__init__
()
super
(
FakeData
,
self
)
.
__init__
()
self
.
shapes
=
shapes
self
.
shapes
=
shapes
...
@@ -49,8 +44,11 @@ class FakeData(RNGDataFlow):
...
@@ -49,8 +44,11 @@ class FakeData(RNGDataFlow):
class
DataFromQueue
(
DataFlow
):
class
DataFromQueue
(
DataFlow
):
""" Produce data from a queue """
""" Produce data from a queue """
def
__init__
(
self
,
queue
):
def
__init__
(
self
,
queue
):
"""
Args:
queue (queue): a queue with ``get()`` method.
"""
self
.
queue
=
queue
self
.
queue
=
queue
def
get_data
(
self
):
def
get_data
(
self
):
...
@@ -62,6 +60,11 @@ class DataFromList(RNGDataFlow):
...
@@ -62,6 +60,11 @@ class DataFromList(RNGDataFlow):
""" Produce data from a list"""
""" Produce data from a list"""
def
__init__
(
self
,
lst
,
shuffle
=
True
):
def
__init__
(
self
,
lst
,
shuffle
=
True
):
"""
Args:
lst (list): input list.
shuffle (bool): shuffle data.
"""
super
(
DataFromList
,
self
)
.
__init__
()
super
(
DataFromList
,
self
)
.
__init__
()
self
.
lst
=
lst
self
.
lst
=
lst
self
.
shuffle
=
shuffle
self
.
shuffle
=
shuffle
...
@@ -78,22 +81,3 @@ class DataFromList(RNGDataFlow):
...
@@ -78,22 +81,3 @@ class DataFromList(RNGDataFlow):
self
.
rng
.
shuffle
(
idxs
)
self
.
rng
.
shuffle
(
idxs
)
for
k
in
idxs
:
for
k
in
idxs
:
yield
self
.
lst
[
k
]
yield
self
.
lst
[
k
]
class
DataFromSocket
(
DataFlow
):
""" Produce data from a zmq socket"""
def
__init__
(
self
,
socket_name
):
self
.
_name
=
socket_name
def
get_data
(
self
):
try
:
ctx
=
zmq
.
Context
()
socket
=
ctx
.
socket
(
zmq
.
PULL
)
socket
.
bind
(
self
.
_name
)
while
True
:
dp
=
loads
(
socket
.
recv
(
copy
=
False
))
yield
dp
finally
:
ctx
.
destroy
(
linger
=
0
)
tensorpack/dataflow/remote.py
View file @
d8935ef3
...
@@ -19,6 +19,14 @@ from ..utils.serialize import dumps, loads
...
@@ -19,6 +19,14 @@ from ..utils.serialize import dumps, loads
def
serve_data
(
ds
,
addr
):
def
serve_data
(
ds
,
addr
):
"""
Serve the DataFlow on a ZMQ socket addr.
It will dump and send each datapoint to this addr with a PUSH socket.
Args:
ds (DataFlow): DataFlow to serve. Will infinitely loop over the DataFlow.
addr: a ZMQ socket addr.
"""
ctx
=
zmq
.
Context
()
ctx
=
zmq
.
Context
()
socket
=
ctx
.
socket
(
zmq
.
PUSH
)
socket
=
ctx
.
socket
(
zmq
.
PUSH
)
socket
.
set_hwm
(
10
)
socket
.
set_hwm
(
10
)
...
@@ -27,7 +35,7 @@ def serve_data(ds, addr):
...
@@ -27,7 +35,7 @@ def serve_data(ds, addr):
try
:
try
:
ds
.
reset_state
()
ds
.
reset_state
()
logger
.
info
(
"Serving data at {}"
.
format
(
addr
))
logger
.
info
(
"Serving data at {}"
.
format
(
addr
))
# TODO print statistics
here
# TODO print statistics
such as speed
while
True
:
while
True
:
for
dp
in
ds
.
get_data
():
for
dp
in
ds
.
get_data
():
socket
.
send
(
dumps
(
dp
),
copy
=
False
)
socket
.
send
(
dumps
(
dp
),
copy
=
False
)
...
@@ -39,17 +47,25 @@ def serve_data(ds, addr):
...
@@ -39,17 +47,25 @@ def serve_data(ds, addr):
class
RemoteData
(
DataFlow
):
class
RemoteData
(
DataFlow
):
""" Produce data from a ZMQ socket. """
def
__init__
(
self
,
addr
):
def
__init__
(
self
,
addr
):
self
.
ctx
=
zmq
.
Context
()
"""
self
.
socket
=
self
.
ctx
.
socket
(
zmq
.
PULL
)
Args:
self
.
socket
.
set_hwm
(
10
)
addr (str): addr of the socket to connect to.
self
.
socket
.
connect
(
addr
)
"""
self
.
_addr
=
addr
def
get_data
(
self
):
def
get_data
(
self
):
while
True
:
try
:
dp
=
loads
(
self
.
socket
.
recv
(
copy
=
False
))
ctx
=
zmq
.
Context
()
yield
dp
socket
=
ctx
.
socket
(
zmq
.
PULL
)
socket
.
connect
(
self
.
_addr
)
while
True
:
dp
=
loads
(
socket
.
recv
(
copy
=
False
))
yield
dp
finally
:
ctx
.
destroy
(
linger
=
0
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
tensorpack/dataflow/tf_func.py
View file @
d8935ef3
...
@@ -12,7 +12,9 @@ except ImportError:
...
@@ -12,7 +12,9 @@ except ImportError:
logger
.
warn_dependency
(
'TFFuncMapper'
,
'tensorflow'
)
logger
.
warn_dependency
(
'TFFuncMapper'
,
'tensorflow'
)
__all__
=
[]
__all__
=
[]
else
:
else
:
__all__
=
[
'TFFuncMapper'
]
__all__
=
[]
""" This file was deprecated """
class
TFFuncMapper
(
ProxyDataFlow
):
class
TFFuncMapper
(
ProxyDataFlow
):
...
...
tensorpack/models/model_desc.py
View file @
d8935ef3
...
@@ -138,7 +138,7 @@ class ModelFromMetaGraph(ModelDesc):
...
@@ -138,7 +138,7 @@ class ModelFromMetaGraph(ModelDesc):
def
__init__
(
self
,
filename
):
def
__init__
(
self
,
filename
):
"""
"""
Args:
Args:
filename(str): file name of the saved meta graph.
filename
(str): file name of the saved meta graph.
"""
"""
tf
.
train
.
import_meta_graph
(
filename
)
tf
.
train
.
import_meta_graph
(
filename
)
all_coll
=
tf
.
get_default_graph
()
.
get_all_collection_keys
()
all_coll
=
tf
.
get_default_graph
()
.
get_all_collection_keys
()
...
...
tensorpack/models/regularize.py
View file @
d8935ef3
...
@@ -59,9 +59,9 @@ def Dropout(x, keep_prob=0.5, is_training=None):
...
@@ -59,9 +59,9 @@ def Dropout(x, keep_prob=0.5, is_training=None):
Neural Networks from Overfitting <http://dl.acm.org/citation.cfm?id=2670313>`_.
Neural Networks from Overfitting <http://dl.acm.org/citation.cfm?id=2670313>`_.
Args:
Args:
keep_prob: the probability that each element is kept. It is only used
keep_prob
(float)
: the probability that each element is kept. It is only used
when is_training=True.
when is_training=True.
is_training: If None, will use the current :class:`tensorpack.tfutils.TowerContext`
is_training
(bool)
: If None, will use the current :class:`tensorpack.tfutils.TowerContext`
to figure out.
to figure out.
"""
"""
if
is_training
is
None
:
if
is_training
is
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment