Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
d8935ef3
Commit
d8935ef3
authored
Jan 05, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
update sphinx doc for dataflow/
parent
edecca96
Changes
27
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
27 changed files
with
583 additions
and
373 deletions
+583
-373
docs/conf.py
docs/conf.py
+5
-6
tensorpack/dataflow/base.py
tensorpack/dataflow/base.py
+18
-13
tensorpack/dataflow/common.py
tensorpack/dataflow/common.py
+140
-91
tensorpack/dataflow/dataset/bsds500.py
tensorpack/dataflow/dataset/bsds500.py
+8
-8
tensorpack/dataflow/dataset/cifar.py
tensorpack/dataflow/dataset/cifar.py
+11
-12
tensorpack/dataflow/dataset/ilsvrc.py
tensorpack/dataflow/dataset/ilsvrc.py
+35
-31
tensorpack/dataflow/dataset/mnist.py
tensorpack/dataflow/dataset/mnist.py
+4
-3
tensorpack/dataflow/dataset/svhn.py
tensorpack/dataflow/dataset/svhn.py
+6
-4
tensorpack/dataflow/dataset/visualqa.py
tensorpack/dataflow/dataset/visualqa.py
+2
-4
tensorpack/dataflow/dftools.py
tensorpack/dataflow/dftools.py
+30
-19
tensorpack/dataflow/format.py
tensorpack/dataflow/format.py
+45
-15
tensorpack/dataflow/image.py
tensorpack/dataflow/image.py
+20
-15
tensorpack/dataflow/imgaug/base.py
tensorpack/dataflow/imgaug/base.py
+10
-4
tensorpack/dataflow/imgaug/crop.py
tensorpack/dataflow/imgaug/crop.py
+27
-17
tensorpack/dataflow/imgaug/deform.py
tensorpack/dataflow/imgaug/deform.py
+13
-10
tensorpack/dataflow/imgaug/geometry.py
tensorpack/dataflow/imgaug/geometry.py
+16
-5
tensorpack/dataflow/imgaug/imgproc.py
tensorpack/dataflow/imgaug/imgproc.py
+43
-20
tensorpack/dataflow/imgaug/meta.py
tensorpack/dataflow/imgaug/meta.py
+20
-7
tensorpack/dataflow/imgaug/noise.py
tensorpack/dataflow/imgaug/noise.py
+17
-4
tensorpack/dataflow/imgaug/noname.py
tensorpack/dataflow/imgaug/noname.py
+24
-16
tensorpack/dataflow/imgaug/paste.py
tensorpack/dataflow/imgaug/paste.py
+10
-6
tensorpack/dataflow/prefetch.py
tensorpack/dataflow/prefetch.py
+31
-17
tensorpack/dataflow/raw.py
tensorpack/dataflow/raw.py
+17
-33
tensorpack/dataflow/remote.py
tensorpack/dataflow/remote.py
+25
-9
tensorpack/dataflow/tf_func.py
tensorpack/dataflow/tf_func.py
+3
-1
tensorpack/models/model_desc.py
tensorpack/models/model_desc.py
+1
-1
tensorpack/models/regularize.py
tensorpack/models/regularize.py
+2
-2
No files found.
docs/conf.py
View file @
d8935ef3
...
...
@@ -69,9 +69,11 @@ extensions = [
#'sphinx.ext.coverage',
#'sphinx.ext.mathjax',
'sphinx.ext.mathbase'
,
'sphinx.ext.intersphinx'
,
'sphinx.ext.viewcode'
,
]
napoleon_google_docstring
=
True
napoleon_include_init_with_doc
=
True
napoleon_numpy_docstring
=
False
napoleon_use_rtype
=
False
...
...
@@ -332,11 +334,9 @@ texinfo_documents = [
# If true, do not generate a @detailmenu in the "Top" node's menu.
#texinfo_no_detailmenu = False
def
skip
(
app
,
what
,
name
,
obj
,
skip
,
options
):
# keep __init__
if
name
==
"__init__"
:
return
False
return
skip
intersphinx_timeout
=
0.1
intersphinx_mapping
=
{
'python'
:
(
'https://docs.python.org/3.4'
,
None
)}
def
process_signature
(
app
,
what
,
name
,
obj
,
options
,
signature
,
return_annotation
):
...
...
@@ -350,7 +350,6 @@ def process_signature(app, what, name, obj, options, signature,
def
setup
(
app
):
from
recommonmark.transform
import
AutoStructify
app
.
connect
(
'autodoc-process-signature'
,
process_signature
)
app
.
connect
(
"autodoc-skip-member"
,
skip
)
app
.
add_config_value
(
'recommonmark_config'
,
{
'url_resolver'
:
lambda
url
:
\
...
...
tensorpack/dataflow/base.py
View file @
d8935ef3
...
...
@@ -15,37 +15,41 @@ __all__ = ['DataFlow', 'ProxyDataFlow', 'RNGDataFlow']
class
DataFlow
(
object
):
""" Base class for all DataFlow """
class
Infinity
:
pass
@
abstractmethod
def
get_data
(
self
):
"""
A generator to generate data as a list.
Datapoint should be a mutable list.
Each component should be assumed immutable.
The method to generate datapoints.
Yields:
list: The datapoint, i.e. list of components.
"""
def
size
(
self
):
"""
Size of this data flow.
Returns:
int: size of this data flow.
Raises:
:class:`NotImplementedError` if this DataFlow doesn't have a size.
"""
raise
NotImplementedError
()
def
reset_state
(
self
):
"""
Reset state of the dataflow. Will always be called before consuming data points.
for example, RNG **HAS** to be reset here if used in the DataFlow.
Otherwise it may not work well with prefetching, because different
Reset state of the dataflow. It has to be called before producing datapoints.
For example, RNG **has to** be reset if used in the DataFlow,
otherwise it won't work well with prefetching, because different
processes will have the same RNG state.
"""
pass
class
RNGDataFlow
(
DataFlow
):
""" A
dataflow with rng
"""
""" A
DataFlow with RNG
"""
def
reset_state
(
self
):
""" Reset the RNG """
self
.
rng
=
get_rng
(
self
)
...
...
@@ -54,13 +58,14 @@ class ProxyDataFlow(DataFlow):
def
__init__
(
self
,
ds
):
"""
:param ds: a :mod:`DataFlow` instance to proxy
Args:
ds (DataFlow): DataFlow to proxy.
"""
self
.
ds
=
ds
def
reset_state
(
self
):
"""
Will reset state of the proxied DataFlow
Reset state of the proxied DataFlow.
"""
self
.
ds
.
reset_state
()
...
...
tensorpack/dataflow/common.py
View file @
d8935ef3
This diff is collapsed.
Click to expand it.
tensorpack/dataflow/dataset/bsds500.py
View file @
d8935ef3
...
...
@@ -25,20 +25,20 @@ IMG_W, IMG_H = 481, 321
class
BSDS500
(
RNGDataFlow
):
"""
`Berkeley Segmentation Data Set and Benchmarks 500
`Berkeley Segmentation Data Set and Benchmarks 500
dataset
<http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/resources.html#bsds500>`_.
Produce (image, label) pair, where image has shape (321, 481, 3) and
ranges in [0,255]. Label is binary and has shape (321, 481).
Those pixels annotated as boundaries by <=2 annotators are set to 0.
This is used in `Holistically-Nested Edge Detection
<http://arxiv.org/abs/1504.06375>`_.
Produce ``(image, label)`` pair, where ``image`` has shape (321, 481, 3(BGR)) and
ranges in [0,255].
``Label`` is a floating point image of shape (321, 481) in range [0, 1].
The value of each pixel is ``number of times it is annotated as edge / total number of annotators for this image``.
"""
def
__init__
(
self
,
name
,
data_dir
=
None
,
shuffle
=
True
):
"""
:param name: 'train', 'test', 'val'
:param data_dir: a directory containing the original 'BSR' directory.
Args:
name (str): 'train', 'test', 'val'
data_dir (str): a directory containing the original 'BSR' directory.
"""
# check and download data
if
data_dir
is
None
:
...
...
tensorpack/dataflow/dataset/cifar.py
View file @
d8935ef3
...
...
@@ -80,17 +80,7 @@ def get_filenames(dir, cifar_classnum):
class
CifarBase
(
RNGDataFlow
):
"""
Return [image, label],
image is 32x32x3 in the range [0,255]
"""
def
__init__
(
self
,
train_or_test
,
shuffle
=
True
,
dir
=
None
,
cifar_classnum
=
10
):
"""
Args:
train_or_test: string either 'train' or 'test'
shuffle: default to True
"""
assert
train_or_test
in
[
'train'
,
'test'
]
assert
cifar_classnum
==
10
or
cifar_classnum
==
100
self
.
cifar_classnum
=
cifar_classnum
...
...
@@ -139,13 +129,22 @@ class CifarBase(RNGDataFlow):
class
Cifar10
(
CifarBase
):
"""
Produces [image, label] in Cifar10 dataset,
image is 32x32x3 in the range [0,255].
label is an int.
"""
def
__init__
(
self
,
train_or_test
,
shuffle
=
True
,
dir
=
None
):
"""
Args:
train_or_test (str): either 'train' or 'test'.
shuffle (bool): shuffle the dataset.
"""
super
(
Cifar10
,
self
)
.
__init__
(
train_or_test
,
shuffle
,
dir
,
10
)
class
Cifar100
(
CifarBase
):
""" Similar to Cifar10"""
def
__init__
(
self
,
train_or_test
,
shuffle
=
True
,
dir
=
None
):
super
(
Cifar100
,
self
)
.
__init__
(
train_or_test
,
shuffle
,
dir
,
100
)
...
...
tensorpack/dataflow/dataset/ilsvrc.py
View file @
d8935ef3
...
...
@@ -22,7 +22,7 @@ CAFFE_ILSVRC12_URL = "http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz"
class
ILSVRCMeta
(
object
):
"""
Some
metadata for ILSVRC dataset.
Provide methods to access
metadata for ILSVRC dataset.
"""
def
__init__
(
self
,
dir
=
None
):
...
...
@@ -37,7 +37,8 @@ class ILSVRCMeta(object):
def
get_synset_words_1000
(
self
):
"""
:returns a dict of {cls_number: cls_name}
Returns:
dict: {cls_number: cls_name}
"""
fname
=
os
.
path
.
join
(
self
.
dir
,
'synset_words.txt'
)
assert
os
.
path
.
isfile
(
fname
)
...
...
@@ -46,7 +47,8 @@ class ILSVRCMeta(object):
def
get_synset_1000
(
self
):
"""
:returns a dict of {cls_number: synset_id}
Returns:
dict: {cls_number: synset_id}
"""
fname
=
os
.
path
.
join
(
self
.
dir
,
'synsets.txt'
)
assert
os
.
path
.
isfile
(
fname
)
...
...
@@ -59,8 +61,10 @@ class ILSVRCMeta(object):
def
get_image_list
(
self
,
name
):
"""
:param name: 'train' or 'val' or 'test'
:returns: list of (image filename, cls)
Args:
name (str): 'train' or 'val' or 'test'
Returns:
list: list of (image filename, label)
"""
assert
name
in
[
'train'
,
'val'
,
'test'
]
fname
=
os
.
path
.
join
(
self
.
dir
,
name
+
'.txt'
)
...
...
@@ -75,8 +79,10 @@ class ILSVRCMeta(object):
def
get_per_pixel_mean
(
self
,
size
=
None
):
"""
:param size: return image size in [h, w]. default to (256, 256)
:returns: per-pixel mean as an array of shape (h, w, 3) in range [0, 255]
Args:
size (tuple): image size in (h, w). Defaults to (256, 256).
Returns:
np.ndarray: per-pixel mean of shape (h, w, 3 (BGR)) in range [0, 255].
"""
obj
=
self
.
caffepb
.
BlobProto
()
...
...
@@ -91,18 +97,26 @@ class ILSVRCMeta(object):
class
ILSVRC12
(
RNGDataFlow
):
"""
Produces ILSVRC12 images of shape [h, w, 3(BGR)], and a label between [0, 999],
and optionally a bounding box of [xmin, ymin, xmax, ymax].
"""
def
__init__
(
self
,
dir
,
name
,
meta_dir
=
None
,
shuffle
=
True
,
dir_structure
=
'original'
,
include_bb
=
False
):
"""
:param dir: A directory containing a subdir named `name`, where the
original ILSVRC12_`name`.tar gets decompressed.
:param name: 'train' or 'val' or 'test'
:param dir_structure: The dir structure of 'val' and 'test'.
If is 'original' then keep the original decompressed directory with list
of image files (as below). If set to 'train', use the the same
directory structure as 'train/', with class name as subdirectories.
:param include_bb: Include the bounding box. Maybe useful in training.
Args:
dir (str): A directory containing a subdir named ``name``, where the
original ``ILSVRC12_img_{name}.tar`` gets decompressed.
name (str): 'train' or 'val' or 'test'.
shuffle (bool): shuffle the dataset.
dir_structure (str): The dir structure of 'val' and 'test' directory.
If is 'original', it expects the original decompressed
directory, which only has list of image files (as below).
If set to 'train', it expects the same two-level
directory structure similar to 'train/'.
include_bb (bool): Include the bounding box. Maybe useful in training.
Examples:
When `dir_structure=='original'`, `dir` should have the following structure:
...
...
@@ -120,22 +134,16 @@ class ILSVRC12(RNGDataFlow):
test/
ILSVRC2012_test_00000001.JPEG
...
bbox/
n02134418/
n02134418_198.xml
...
...
After decompress ILSVRC12_img_train
.tar, you can use the following
command to build the above structure
for `train/`
:
With ILSVRC12_img_*
.tar, you can use the following
command to build the above structure:
.. code-block:: none
tar xvf ILSVRC12_img_train.tar -C train && cd train
mkdir val && tar xvf ILSVRC12_img_val.tar -C val
mkdir test && tar xvf ILSVRC12_img_test.tar -C test
mkdir train && tar xvf ILSVRC12_img_train.tar -C train && cd train
find -type f -name '*.tar' | parallel -P 10 'echo {} && mkdir -p {/.} && tar xf {} -C {/.}'
Or:
for i in *.tar; do dir=${i
%
.tar}; echo $dir; mkdir -p $dir; tar xf $i -C $dir; done
"""
assert
name
in
[
'train'
,
'test'
,
'val'
]
self
.
full_dir
=
os
.
path
.
join
(
dir
,
name
)
...
...
@@ -158,10 +166,6 @@ class ILSVRC12(RNGDataFlow):
return
len
(
self
.
imglist
)
def
get_data
(
self
):
"""
Produce original images of shape [h, w, 3(BGR)], and label,
and optionally a bbox of [xmin, ymin, xmax, ymax]
"""
idxs
=
np
.
arange
(
len
(
self
.
imglist
))
add_label_to_fname
=
(
self
.
name
!=
'train'
and
self
.
dir_structure
!=
'original'
)
if
self
.
shuffle
:
...
...
tensorpack/dataflow/dataset/mnist.py
View file @
d8935ef3
...
...
@@ -65,14 +65,15 @@ def extract_labels(filename):
class
Mnist
(
RNGDataFlow
):
"""
Return [image, label]
,
image is 28x28 in the range [0,1]
Produces [image, label] in MNIST dataset
,
image is 28x28 in the range [0,1], label is an int.
"""
def
__init__
(
self
,
train_or_test
,
shuffle
=
True
,
dir
=
None
):
"""
Args:
train_or_test: string either 'train' or 'test'
train_or_test (str): either 'train' or 'test'
shuffle (bool): shuffle the dataset
"""
if
dir
is
None
:
dir
=
get_dataset_path
(
'mnist_data'
)
...
...
tensorpack/dataflow/dataset/svhn.py
View file @
d8935ef3
...
...
@@ -21,15 +21,17 @@ SVHN_URL = "http://ufldl.stanford.edu/housenumbers/"
class
SVHNDigit
(
RNGDataFlow
):
"""
SVHN
Cropped Digit Dataset.
return img of 32x32x3
, label of 0-9
`SVHN <http://ufldl.stanford.edu/housenumbers/>`_
Cropped Digit Dataset.
Produces [img, label], img of 32x32x3 in range [0,255]
, label of 0-9
"""
_Cache
=
{}
def
__init__
(
self
,
name
,
data_dir
=
None
,
shuffle
=
True
):
"""
:param name: 'train', 'test', or 'extra'
:param data_dir: a directory containing the original {train,test,extra}_32x32.mat
Args:
name (str): 'train', 'test', or 'extra'.
data_dir (str): a directory containing the original {train,test,extra}_32x32.mat.
shuffle (bool): shuffle the dataset.
"""
self
.
shuffle
=
shuffle
...
...
tensorpack/dataflow/dataset/visualqa.py
View file @
d8935ef3
...
...
@@ -18,13 +18,11 @@ def read_json(fname):
f
.
close
()
return
ret
# TODO shuffle
class
VisualQA
(
DataFlow
):
"""
Visual QA dataset. See http://visualqa.org/
Simply read
q/a json file and produce q/a pairs in their original format.
`Visual QA <http://visualqa.org/>`_ dataset.
It simply reads
q/a json file and produces q/a pairs in their original format.
"""
def
__init__
(
self
,
question_file
,
annotation_file
):
...
...
tensorpack/dataflow/dftools.py
View file @
d8935ef3
...
...
@@ -23,17 +23,17 @@ except ImportError:
else
:
__all__
.
extend
([
'dump_dataflow_to_lmdb'
])
# TODO pass a name_func to write label as filename?
def
dump_dataset_images
(
ds
,
dirname
,
max_count
=
None
,
index
=
0
):
""" Dump images from a
`DataFlow`
to a directory.
""" Dump images from a
DataFlow
to a directory.
:param ds: a `DataFlow` instance.
:param dirname: name of the directory.
:param max_count: max number of images to dump
:param index: the index of the image component in a data point.
Args:
ds (DataFlow): the DataFlow to dump.
dirname (str): name of the directory.
max_count (int): limit max number of images to dump. Defaults to unlimited.
index (int): the index of the image component in the data point.
"""
# TODO pass a name_func to write label as filename?
mkdir_p
(
dirname
)
if
max_count
is
None
:
max_count
=
sys
.
maxint
...
...
@@ -48,9 +48,15 @@ def dump_dataset_images(ds, dirname, max_count=None, index=0):
def
dump_dataflow_to_lmdb
(
ds
,
lmdb_path
):
""" Dump a `Dataflow` ds to a lmdb database, where the key is the index
and the data is the serialized datapoint.
The output database can be read directly by `LMDBDataPoint`
"""
Dump a DataFlow to an LMDB database, where the keys are indices and values
are serialized datapoints.
The output database can be read directly by
:class:`tensorpack.dataflow.LMDBDataPoint`.
Args:
ds (DataFlow): the DataFlow to dump.
lmdb_path (str): output path. Either a directory or a mdb file.
"""
assert
isinstance
(
ds
,
DataFlow
),
type
(
ds
)
isdir
=
os
.
path
.
isdir
(
lmdb_path
)
...
...
@@ -80,15 +86,20 @@ def dump_dataflow_to_lmdb(ds, lmdb_path):
def
dataflow_to_process_queue
(
ds
,
size
,
nr_consumer
):
"""
Convert a `DataFlow` to a multiprocessing.Queue.
The dataflow will only be reset in the spawned process.
:param ds: a `DataFlow`
:param size: size of the queue
:param nr_consumer: number of consumer of the queue.
will add this many of `DIE` sentinel to the end of the queue.
:returns: (queue, process). The process will take data from `ds` to fill
the queue once you start it. Each element is (task_id, dp).
Convert a DataFlow to a :class:`multiprocessing.Queue`.
The DataFlow will only be reset in the spawned process.
Args:
ds (DataFlow): the DataFlow to dump.
size (int): size of the queue
nr_consumer (int): number of consumers of the queue.
The producer will add this many ``DIE`` sentinels to the end of the queue.
Returns:
tuple(queue, process):
The process will take data from ``ds`` and fill
the queue, once you start it. Each element in the queue is (idx,
dp). idx can be the ``DIE`` sentinel when ``ds`` is exhausted.
"""
q
=
mp
.
Queue
(
size
)
...
...
tensorpack/dataflow/format.py
View file @
d8935ef3
...
...
@@ -26,7 +26,7 @@ try:
except
ImportError
:
logger
.
warn_dependency
(
"LMDBData"
,
'lmdb'
)
else
:
__all__
.
extend
([
'LMDBData'
,
'
CaffeLMDB'
,
'LMDBDataDecoder'
,
'LMDBDataPoint
'
])
__all__
.
extend
([
'LMDBData'
,
'
LMDBDataDecoder'
,
'LMDBDataPoint'
,
'CaffeLMDB
'
])
try
:
import
sklearn.datasets
...
...
@@ -40,19 +40,23 @@ else:
Adapters for different data format.
"""
# TODO lazy load
class
HDF5Data
(
RNGDataFlow
):
"""
Zip data from different paths in an HDF5 file. Will load all data into memory.
Zip data from different paths in an HDF5 file.
Warning:
The current implementation will load all data into memory.
"""
# TODO lazy load
def
__init__
(
self
,
filename
,
data_paths
,
shuffle
=
True
):
"""
:param filename: h5 data file.
:param data_paths: list of h5 paths to zipped. For example ['images', 'labels']
:param shuffle: shuffle the order of all data.
Args:
filename (str): h5 data file.
data_paths (list): list of h5 paths to be zipped.
For example `['images', 'labels']`.
shuffle (bool): shuffle all data.
"""
self
.
f
=
h5py
.
File
(
filename
,
'r'
)
logger
.
info
(
"Loading {} to memory..."
.
format
(
filename
))
...
...
@@ -74,9 +78,13 @@ class HDF5Data(RNGDataFlow):
class
LMDBData
(
RNGDataFlow
):
""" Read a lmdb and produce k,v pair """
""" Read a LMDB database and produce (k,v) pairs """
def
__init__
(
self
,
lmdb_path
,
shuffle
=
True
):
"""
Args:
lmdb_path (str): a directory or a file.
shuffle (bool): shuffle the keys or not.
"""
self
.
_lmdb_path
=
lmdb_path
self
.
_shuffle
=
shuffle
self
.
open_lmdb
()
...
...
@@ -122,11 +130,14 @@ class LMDBData(RNGDataFlow):
class
LMDBDataDecoder
(
LMDBData
):
""" Read a LMDB database and produce a decoded output."""
def
__init__
(
self
,
lmdb_path
,
decoder
,
shuffle
=
True
):
"""
:param decoder: a function taking k, v and return a data point,
or return None to skip
Args:
lmdb_path (str): a directory or a file.
decoder (k,v -> dp | None): a function taking k, v and returning a datapoint,
or return None to discard.
shuffle (bool): shuffle the keys or not.
"""
super
(
LMDBDataDecoder
,
self
)
.
__init__
(
lmdb_path
,
shuffle
)
self
.
decoder
=
decoder
...
...
@@ -139,17 +150,31 @@ class LMDBDataDecoder(LMDBData):
class
LMDBDataPoint
(
LMDBDataDecoder
):
""" Read a LMDB file where each value is a serialized datapoint"""
""" Read a LMDB file and produce deserialized values.
This can work with :func:`tensorpack.dataflow.dftools.dump_dataflow_to_lmdb`. """
def
__init__
(
self
,
lmdb_path
,
shuffle
=
True
):
"""
Args:
lmdb_path (str): a directory or a file.
shuffle (bool): shuffle the keys or not.
"""
super
(
LMDBDataPoint
,
self
)
.
__init__
(
lmdb_path
,
decoder
=
lambda
k
,
v
:
loads
(
v
),
shuffle
=
shuffle
)
class
CaffeLMDB
(
LMDBDataDecoder
):
""" Read a Caffe LMDB file where each value contains a caffe.Datum protobuf """
"""
Read a Caffe LMDB file where each value contains a ``caffe.Datum`` protobuf.
Produces datapoints of the format: [HWC image, label].
"""
def
__init__
(
self
,
lmdb_path
,
shuffle
=
True
):
"""
Args:
lmdb_path (str): a directory or a file.
shuffle (bool): shuffle the keys or not.
"""
cpb
=
get_caffe_pb
()
def
decoder
(
k
,
v
):
...
...
@@ -168,9 +193,14 @@ class CaffeLMDB(LMDBDataDecoder):
class
SVMLightData
(
RNGDataFlow
):
""" Read X,y from a svmlight file """
""" Read X,y from a svmlight file
, and produce [X_i, y_i] pairs.
"""
def
__init__
(
self
,
filename
,
shuffle
=
True
):
"""
Args:
filename (str): input file
shuffle (bool): shuffle the data
"""
self
.
X
,
self
.
y
=
sklearn
.
datasets
.
load_svmlight_file
(
filename
)
self
.
X
=
np
.
asarray
(
self
.
X
.
todense
())
self
.
shuffle
=
shuffle
...
...
tensorpack/dataflow/image.py
View file @
d8935ef3
...
...
@@ -12,13 +12,13 @@ __all__ = ['ImageFromFile', 'AugmentImageComponent', 'AugmentImageComponents']
class
ImageFromFile
(
RNGDataFlow
):
""" Produce images read from a list of files. """
def
__init__
(
self
,
files
,
channel
=
3
,
resize
=
None
,
shuffle
=
False
):
"""
Generate images of 1 channel or 3 channels (in RGB order) from list of files.
:param files: list of file paths
:param channel: 1 or 3 channel
:param resize: a (h, w) tuple. If given, will force a resize
Args:
files (list): list of file paths.
channel (int): 1 or 3. Produce RGB images if channel==3.
resize (tuple): (h, w). If given, resize the image.
"""
assert
len
(
files
),
"No image files given to ImageFromFile!"
self
.
files
=
files
...
...
@@ -45,14 +45,15 @@ class ImageFromFile(RNGDataFlow):
class
AugmentImageComponent
(
MapDataComponent
):
"""
Apply image augmentors on 1 component.
"""
def
__init__
(
self
,
ds
,
augmentors
,
index
=
0
):
"""
Augment the image component of datapoints
:param ds: a `DataFlow` instance.
:param augmentors: a list of `ImageAugmentor` instance to be applied in order.
:param index: the index (or list of indices) of the image component
in the produced datapoints by `ds`. default to be 0
Args:
ds (DataFlow): input DataFlow.
augmentors (AugmentorList): a list of :class:`imgaug.ImageAugmentor` to be applied in order.
index (int): the index of the image component to be augmented.
"""
if
isinstance
(
augmentors
,
AugmentorList
):
self
.
augs
=
augmentors
...
...
@@ -67,12 +68,16 @@ class AugmentImageComponent(MapDataComponent):
class
AugmentImageComponents
(
MapData
):
"""
Apply image augmentors on several components, with shared augmentation parameters.
"""
def
__init__
(
self
,
ds
,
augmentors
,
index
=
(
0
,
1
)):
""" Augment a list of images of the same shape, with the same parameters
:param ds: a `DataFlow` instance.
:param augmentors: a list of `ImageAugmentor` instance to be applied in order.
:param index: tuple of indices of the image components
"""
Args:
ds (DataFlow): input DataFlow.
augmentors (AugmentorList): a list of :class:`imgaug.ImageAugmentor` instance to be applied in order.
index: tuple of indices of components.
"""
self
.
augs
=
AugmentorList
(
augmentors
)
self
.
ds
=
ds
...
...
tensorpack/dataflow/imgaug/base.py
View file @
d8935ef3
...
...
@@ -24,6 +24,7 @@ class Augmentor(object):
setattr
(
self
,
k
,
v
)
def
reset_state
(
self
):
""" reset rng and other state """
self
.
rng
=
get_rng
(
self
)
def
augment
(
self
,
d
):
...
...
@@ -64,9 +65,13 @@ class ImageAugmentor(Augmentor):
def
augment
(
self
,
img
):
"""
Perform augmentation on the image in-place.
:param img: an [h,w] or [h,w,c] image
:returns: the augmented image, always of type 'float32'
Perform augmentation on the image (possibly) in-place.
Args:
img (np.ndarray): an [h,w] or [h,w,c] image.
Returns:
np.ndarray: the augmented image, always of type float32.
"""
img
,
params
=
self
.
_augment_return_params
(
img
)
return
img
...
...
@@ -82,7 +87,8 @@ class AugmentorList(ImageAugmentor):
def
__init__
(
self
,
augmentors
):
"""
:param augmentors: list of `ImageAugmentor` instance to be applied
Args:
augmentors (list): list of :class:`ImageAugmentor` instances to be applied.
"""
self
.
augs
=
augmentors
super
(
AugmentorList
,
self
)
.
__init__
()
...
...
tensorpack/dataflow/imgaug/crop.py
View file @
d8935ef3
...
...
@@ -10,7 +10,7 @@ from six.moves import range
import
numpy
as
np
__all__
=
[
'RandomCrop'
,
'CenterCrop'
,
'FixedCrop'
,
'
RandomCropRandomShape'
,
'perturb_BB'
,
'RandomCropAroundBox
'
]
'
perturb_BB'
,
'RandomCropAroundBox'
,
'RandomCropRandomShape
'
]
class
RandomCrop
(
ImageAugmentor
):
...
...
@@ -18,7 +18,8 @@ class RandomCrop(ImageAugmentor):
def
__init__
(
self
,
crop_shape
):
"""
:param crop_shape: a shape like (h, w)
Args:
crop_shape: (h, w) tuple or a int
"""
crop_shape
=
shape2d
(
crop_shape
)
super
(
RandomCrop
,
self
)
.
__init__
()
...
...
@@ -47,7 +48,8 @@ class CenterCrop(ImageAugmentor):
def
__init__
(
self
,
crop_shape
):
"""
:param crop_shape: a shape like (h, w)
Args:
crop_shape: (h, w) tuple or a int
"""
crop_shape
=
shape2d
(
crop_shape
)
self
.
_init
(
locals
())
...
...
@@ -67,9 +69,8 @@ class FixedCrop(ImageAugmentor):
def
__init__
(
self
,
rect
):
"""
Two arguments defined the range in both axes to crop, min inclued, max excluded.
:param rect: a `Rect` instance
Args:
rect(Rect): min included, max excluded.
"""
self
.
_init
(
locals
())
...
...
@@ -86,12 +87,15 @@ def perturb_BB(image_shape, bb, max_pertub_pixel,
max_try
=
100
):
"""
Perturb a bounding box.
:param image_shape: [h, w]
:param bb: a `Rect` instance
:param max_pertub_pixel: pertubation on each coordinate
:param max_aspect_ratio_diff: result can't have an aspect ratio too different from the original
:param max_try: if cannot find a valid bounding box, return the original
:returns: new bounding box
Args:
image_shape: [h, w]
bb (Rect): original bounding box
max_pertub_pixel: perturbation on each coordinate
max_aspect_ratio_diff: result can't have an aspect ratio too different from the original
max_try: if cannot find a valid bounding box, return the original
Returns:
new bounding box
"""
orig_ratio
=
bb
.
h
*
1.0
/
bb
.
w
if
rng
is
None
:
...
...
@@ -117,13 +121,15 @@ def perturb_BB(image_shape, bb, max_pertub_pixel,
class
RandomCropAroundBox
(
ImageAugmentor
):
"""
Crop a box around a bounding box
Crop a box around a bounding box
by some random perturbation
"""
def
__init__
(
self
,
perturb_ratio
,
max_aspect_ratio_diff
=
0.3
):
"""
:param perturb_ratio: perturb distance will be in [0, perturb_ratio * sqrt(w * h)]
:param max_aspect_ratio_diff: keep aspect ratio within the range
Args:
perturb_ratio (float): perturb distance will be in
``[0, perturb_ratio * sqrt(w * h)]``
max_aspect_ratio_diff (float): keep aspect ratio difference within the range
"""
super
(
RandomCropAroundBox
,
self
)
.
__init__
()
self
.
_init
(
locals
())
...
...
@@ -144,14 +150,18 @@ class RandomCropAroundBox(ImageAugmentor):
class
RandomCropRandomShape
(
ImageAugmentor
):
""" Random crop with a random shape"""
def
__init__
(
self
,
wmin
,
hmin
,
wmax
=
None
,
hmax
=
None
,
max_aspect_ratio
=
None
):
"""
Randomly crop a box of shape (h, w), sampled from [min, max]
(
inclusive).
Randomly crop a box of shape (h, w), sampled from [min, max]
(both
inclusive).
If max is None, will use the input image shape.
max_aspect_ratio is the upper bound of max(w,h)/min(w,h)
Args:
wmin, hmin, wmax, hmax: range to sample shape.
max_aspect_ratio (float): the upper bound of ``max(w,h)/min(w,h)``.
"""
if
max_aspect_ratio
is
None
:
max_aspect_ratio
=
9999999
...
...
tensorpack/dataflow/imgaug/deform.py
View file @
d8935ef3
...
...
@@ -6,13 +6,12 @@ from .base import ImageAugmentor
from
...utils
import
logger
import
numpy
as
np
__all__
=
[
'GaussianDeform'
,
'GaussianMap'
]
# TODO really needs speedup
__all__
=
[
'GaussianDeform'
]
class
GaussianMap
(
object
):
""" Generate gaussian weighted deformation map"""
# TODO really needs speedup
def
__init__
(
self
,
image_shape
,
sigma
=
0.5
):
assert
len
(
image_shape
)
==
2
...
...
@@ -20,6 +19,10 @@ class GaussianMap(object):
self
.
sigma
=
sigma
def
get_gaussian_weight
(
self
,
anchor
):
"""
Args:
anchor: coordinate of the center
"""
ret
=
np
.
zeros
(
self
.
shape
,
dtype
=
'float32'
)
y
,
x
=
np
.
mgrid
[:
self
.
shape
[
0
],
:
self
.
shape
[
1
]]
...
...
@@ -55,20 +58,20 @@ def np_sample(img, coords):
img
[
ucoory
,
lcoorx
,
:]
*
diffy
*
ndiffx
return
ret
[:,
:,
0
,
:]
# TODO input/output with different shape
class
GaussianDeform
(
ImageAugmentor
):
"""
Some kind of
deformation. Quite slow
.
Some kind of
slow deformation
.
"""
# TODO input/output with different shape
def
__init__
(
self
,
anchors
,
shape
,
sigma
=
0.5
,
randrange
=
None
):
"""
:param anchors: in [0,1] coordinate
:param shape: image shape in [h, w]
:param sigma: sigma for Gaussian weight
:param randrange: default to shape[0] / 8
Args:
anchors (list): list of center coordinates in range [0,1].
shape(list or tuple): image shape in [h, w].
sigma (float): sigma for Gaussian weight
randrange (int): offset range. Defaults to shape[0] / 8
"""
logger
.
warn
(
"GaussianDeform is slow. Consider using it with 4 or more prefetching processes."
)
super
(
GaussianDeform
,
self
)
.
__init__
()
...
...
tensorpack/dataflow/imgaug/geometry.py
View file @
d8935ef3
...
...
@@ -17,8 +17,11 @@ class Rotation(ImageAugmentor):
interp
=
cv2
.
INTER_CUBIC
,
border
=
cv2
.
BORDER_REPLICATE
):
"""
:param max_deg: max abs value of the rotation degree
:param center_range: the location of the rotation center
Args:
max_deg (float): max abs value of the rotation degree (in angle).
center_range (tuple): (min, max) range of the random rotation center.
interp: cv2 interpolation method
border: cv2 border method
"""
super
(
Rotation
,
self
)
.
__init__
()
self
.
_init
(
locals
())
...
...
@@ -36,11 +39,16 @@ class Rotation(ImageAugmentor):
class
RotationAndCropValid
(
ImageAugmentor
):
""" Random rotate and
crop the largest possible rect without the border
T
his will produce images of different shapes.
""" Random rotate and
then crop the largest possible rectangle.
Note that t
his will produce images of different shapes.
"""
def
__init__
(
self
,
max_deg
,
interp
=
cv2
.
INTER_CUBIC
):
"""
Args:
max_deg (float): max abs value of the rotation degree (in angle).
interp: cv2 interpolation method
"""
super
(
RotationAndCropValid
,
self
)
.
__init__
()
self
.
_init
(
locals
())
...
...
@@ -63,7 +71,10 @@ class RotationAndCropValid(ImageAugmentor):
@
staticmethod
def
largest_rotated_rect
(
w
,
h
,
angle
):
""" http://stackoverflow.com/questions/16702966/rotate-image-and-crop-out-black-borders """
"""
Get largest rectangle after rotation.
http://stackoverflow.com/questions/16702966/rotate-image-and-crop-out-black-borders
"""
angle
=
angle
/
180.0
*
math
.
pi
if
w
<=
0
or
h
<=
0
:
return
0
,
0
...
...
tensorpack/dataflow/imgaug/imgproc.py
View file @
d8935ef3
...
...
@@ -12,9 +12,8 @@ __all__ = ['Brightness', 'Contrast', 'MeanVarianceNormalize', 'GaussianBlur',
class
Brightness
(
ImageAugmentor
):
"""
Random adjust brightness.
Random
ly
adjust brightness.
"""
def
__init__
(
self
,
delta
,
clip
=
True
):
"""
Randomly add a value within [-delta,delta], and clip in [0,255] if clip is True.
...
...
@@ -36,14 +35,14 @@ class Brightness(ImageAugmentor):
class
Contrast
(
ImageAugmentor
):
"""
Apply x = (x - mean) * contrast_factor + mean to each channel
and clip to [0, 255]
Apply ``x = (x - mean) * contrast_factor + mean`` to each channel.
"""
def
__init__
(
self
,
factor_range
,
clip
=
True
):
"""
:param factor_range: an interval to random sample the `contrast_factor`.
:param clip: boolean.
Args:
factor_range (list or tuple): an interval to randomly sample the `contrast_factor`.
clip (bool): clip to [0, 255] if True.
"""
super
(
Contrast
,
self
)
.
__init__
()
self
.
_init
(
locals
())
...
...
@@ -61,14 +60,15 @@ class Contrast(ImageAugmentor):
class
MeanVarianceNormalize
(
ImageAugmentor
):
"""
Linearly scales image to have zero mean and unit norm.
x = (x - mean) / adjusted_stddev
where
adjusted_stddev = max(stddev, 1.0/sqrt(num_pixels * channels))
Linearly scales
the
image to have zero mean and unit variance.
``x = (x - mean) / adjusted_stddev``
where
``adjusted_stddev = max(stddev, 1.0/sqrt(num_pixels * channels))``
"""
def
__init__
(
self
,
all_channel
=
True
):
"""
:param all_channel: if True, normalize all channels together. else separately.
Args:
all_channel (bool): if True, normalize all channels together; otherwise normalize each channel separately.
"""
self
.
all_channel
=
all_channel
...
...
@@ -85,9 +85,13 @@ class MeanVarianceNormalize(ImageAugmentor):
class
GaussianBlur
(
ImageAugmentor
):
""" Gaussian blur the image with random window size"""
def
__init__
(
self
,
max_size
=
3
):
""":params max_size: (maximum kernel size-1)/2"""
"""
Args:
max_size (int): max possible Gaussian window size would be 2 * max_size + 1
"""
super
(
GaussianBlur
,
self
)
.
__init__
()
self
.
_init
(
locals
())
...
...
@@ -103,8 +107,12 @@ class GaussianBlur(ImageAugmentor):
class
Gamma
(
ImageAugmentor
):
""" Randomly adjust gamma """
def
__init__
(
self
,
range
=
(
-
0.5
,
0.5
)):
"""
Args:
range(list or tuple): gamma range
"""
super
(
Gamma
,
self
)
.
__init__
()
self
.
_init
(
locals
())
...
...
@@ -119,8 +127,13 @@ class Gamma(ImageAugmentor):
class
Clip
(
ImageAugmentor
):
""" Clip the pixel values """
def
__init__
(
self
,
min
=
0
,
max
=
255
):
"""
Args:
min, max: the clip range
"""
self
.
_init
(
locals
())
def
_augment
(
self
,
img
,
_
):
...
...
@@ -129,10 +142,15 @@ class Clip(ImageAugmentor):
class
Saturation
(
ImageAugmentor
):
""" Randomly adjust saturation.
Follows the implementation in `fb.resnet.torch
<https://github.com/facebook/fb.resnet.torch/blob/master/datasets/transforms.lua#L218>`_
"""
def
__init__
(
self
,
alpha
=
0.4
):
""" Saturation,
see 'fb.resnet.torch' https://github.com/facebook/fb.resnet.torch/blob/master/datasets/transforms.lua#L218
"""
Args:
alpha(float): maximum saturation change.
"""
super
(
Saturation
,
self
)
.
__init__
()
assert
alpha
<
1
...
...
@@ -147,14 +165,19 @@ class Saturation(ImageAugmentor):
class
Lighting
(
ImageAugmentor
):
""" Lighting noise, as in the paper
`ImageNet Classification with Deep Convolutional Neural Networks
<https://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf>`_.
The implementation follows `fb.resnet.torch
<https://github.com/facebook/fb.resnet.torch/blob/master/datasets/transforms.lua#L184>`_.
"""
def
__init__
(
self
,
std
,
eigval
,
eigvec
):
""" Lighting noise.
See `ImageNet Classification with Deep Convolutional Neural Networks - Alex`
The implementation follows 'fb.resnet.torch':
https://github.com/facebook/fb.resnet.torch/blob/master/datasets/transforms.lua#L184
:param eigvec: each column is one eigen vector
"""
Args:
std (float): maximum standard deviation
eigval: a vector of shape (3,). The eigenvalues of the 3 channels.
eigvec: a 3x3 matrix. Each column is one eigenvector.
"""
eigval
=
np
.
asarray
(
eigval
)
eigvec
=
np
.
asarray
(
eigvec
)
...
...
tensorpack/dataflow/imgaug/meta.py
View file @
d8935ef3
...
...
@@ -11,15 +11,22 @@ __all__ = ['RandomChooseAug', 'MapImage', 'Identity', 'RandomApplyAug',
class
Identity
(
ImageAugmentor
):
""" A no-op augmentor """
def
_augment
(
self
,
img
,
_
):
return
img
class
RandomApplyAug
(
ImageAugmentor
):
""" Randomly apply the augmentor with a prob. Otherwise do nothing"""
""" Randomly apply the augmentor with a probability.
Otherwise do nothing
"""
def
__init__
(
self
,
aug
,
prob
):
"""
Args:
aug (ImageAugmentor): an augmentor
prob (float): the probability
"""
self
.
_init
(
locals
())
super
(
RandomApplyAug
,
self
)
.
__init__
()
...
...
@@ -43,10 +50,11 @@ class RandomApplyAug(ImageAugmentor):
class
RandomChooseAug
(
ImageAugmentor
):
""" Randomly choose one from a list of augmentors """
def
__init__
(
self
,
aug_lists
):
"""
:param aug_lists: list of augmentor, or list of (augmentor, probability) tuple
Args:
aug_lists (list): list of augmentors, or list of (augmentor, probability) tuples
"""
if
isinstance
(
aug_lists
[
0
],
(
tuple
,
list
)):
prob
=
[
k
[
1
]
for
k
in
aug_lists
]
...
...
@@ -73,11 +81,15 @@ class RandomChooseAug(ImageAugmentor):
class
RandomOrderAug
(
ImageAugmentor
):
"""
Apply the augmentors with randomized order.
"""
def
__init__
(
self
,
aug_lists
):
"""
Shuffle the augmentors into random order.
:param aug_lists: list of augmentor, or list of (augmentor, probability) tuple
Args:
aug_lists (list): list of augmentors.
The augmentors are assumed to not change the shape of images.
"""
self
.
_init
(
locals
())
super
(
RandomOrderAug
,
self
)
.
__init__
()
...
...
@@ -109,7 +121,8 @@ class MapImage(ImageAugmentor):
def
__init__
(
self
,
func
):
"""
:param func: a function which takes a image array and return a augmented one
Args:
func: a function which takes an image array and returns an augmented one
"""
self
.
func
=
func
...
...
tensorpack/dataflow/imgaug/noise.py
View file @
d8935ef3
...
...
@@ -11,8 +11,13 @@ __all__ = ['JpegNoise', 'GaussianNoise', 'SaltPepperNoise']
class
JpegNoise
(
ImageAugmentor
):
""" Random Jpeg noise. """
def
__init__
(
self
,
quality_range
=
(
40
,
100
)):
"""
Args:
quality_range (tuple): range to sample Jpeg quality
"""
super
(
JpegNoise
,
self
)
.
__init__
()
self
.
_init
(
locals
())
...
...
@@ -25,10 +30,14 @@ class JpegNoise(ImageAugmentor):
class
GaussianNoise
(
ImageAugmentor
):
"""
Add random gaussian noise N(0, sigma^2) of the same shape to img.
"""
def
__init__
(
self
,
sigma
=
1
,
clip
=
True
):
"""
Add a gaussian noise N(0, sigma^2) of the same shape to img.
Args:
sigma (float): stddev of the Gaussian distribution.
clip (bool): clip the result to [0,255] in the end.
"""
super
(
GaussianNoise
,
self
)
.
__init__
()
self
.
_init
(
locals
())
...
...
@@ -44,10 +53,14 @@ class GaussianNoise(ImageAugmentor):
class
SaltPepperNoise
(
ImageAugmentor
):
""" Salt and pepper noise.
Randomly set some elements in img to 0 or 255, regardless of its channels.
"""
def
__init__
(
self
,
white_prob
=
0.05
,
black_prob
=
0.05
):
""" Salt and pepper noise.
Randomly set some elements in img to 0 or 255, regardless of its channels.
"""
Args:
white_prob (float), black_prob (float): probabilities of setting an element to 255 or 0, respectively.
"""
assert
white_prob
+
black_prob
<=
1
,
"Sum of probabilities cannot be greater than 1"
super
(
SaltPepperNoise
,
self
)
.
__init__
()
...
...
tensorpack/dataflow/imgaug/noname.py
View file @
d8935ef3
...
...
@@ -13,20 +13,18 @@ __all__ = ['Flip', 'Resize', 'RandomResize', 'ResizeShortestEdge']
class
Flip
(
ImageAugmentor
):
"""
Random flip.
Random flip
the image either horizontally or vertically
.
"""
def
__init__
(
self
,
horiz
=
False
,
vert
=
False
,
prob
=
0.5
):
"""
Only one of horiz, vert can be set.
:param horiz: whether or not apply horizontal flip.
:param vert: whether or not apply vertical flip.
:param prob: probability of flip.
Args:
horiz (bool): use horizontal flip.
vert (bool): use vertical flip.
prob (float): probability of flip.
"""
super
(
Flip
,
self
)
.
__init__
()
if
horiz
and
vert
:
raise
ValueError
(
"Please use two Flip instead."
)
raise
ValueError
(
"
Cannot do both horiz and vert.
Please use two Flip instead."
)
elif
horiz
:
self
.
code
=
1
elif
vert
:
...
...
@@ -53,7 +51,9 @@ class Resize(ImageAugmentor):
def
__init__
(
self
,
shape
,
interp
=
cv2
.
INTER_CUBIC
):
"""
:param shape: shape in (h, w)
Args:
shape: (h, w) tuple or an int
interp: cv2 interpolation method
"""
shape
=
tuple
(
shape2d
(
shape
))
self
.
_init
(
locals
())
...
...
@@ -68,11 +68,16 @@ class Resize(ImageAugmentor):
class
ResizeShortestEdge
(
ImageAugmentor
):
""" Resize the shortest edge to a certain number while
keeping the aspect ratio
"""
Resize the shortest edge to a certain number while
keeping the aspect ratio.
"""
def
__init__
(
self
,
size
):
"""
Args:
size (int): the size to resize the shortest edge to.
"""
size
=
size
*
1.0
self
.
_init
(
locals
())
...
...
@@ -87,15 +92,18 @@ class ResizeShortestEdge(ImageAugmentor):
class
RandomResize
(
ImageAugmentor
):
"""
r
andomly rescale w and h of the image"""
"""
R
andomly rescale w and h of the image"""
def
__init__
(
self
,
xrange
,
yrange
,
minimum
=
(
0
,
0
),
aspect_ratio_thres
=
0.15
,
interp
=
cv2
.
INTER_CUBIC
):
"""
:param xrange: (min, max) scaling ratio
:param yrange: (min, max) scaling ratio
:param minimum: (xmin, ymin). Avoid scaling down too much.
:param aspect_ratio_thres: at most change k=20
%
aspect ratio
Args:
xrange (tuple): (min, max) range of scaling ratio for w
yrange (tuple): (min, max) range of scaling ratio for h
minimum (tuple): (xmin, ymin). avoid scaling down too much.
aspect_ratio_thres (float): discard samples whose aspect ratio changes
by more than this threshold.
interp: cv2 interpolation method
"""
super
(
RandomResize
,
self
)
.
__init__
()
self
.
_init
(
locals
())
...
...
tensorpack/dataflow/imgaug/paste.py
View file @
d8935ef3
...
...
@@ -19,9 +19,11 @@ class BackgroundFiller(object):
"""
Return a proper background image of background_shape, given img
:param background_shape: a shape of [h, w]
:param img: an image
:returns: a background image
Args:
background_shape: a shape of [h, w]
img: an image
Returns:
a background image
"""
return
self
.
_fill
(
background_shape
,
img
)
...
...
@@ -35,7 +37,8 @@ class ConstantBackgroundFiller(BackgroundFiller):
def
__init__
(
self
,
value
):
"""
:param value: the value to fill the background.
Args:
value (float): the value to fill the background.
"""
self
.
value
=
value
...
...
@@ -55,8 +58,9 @@ class CenterPaste(ImageAugmentor):
def
__init__
(
self
,
background_shape
,
background_filler
=
None
):
"""
:param background_shape: shape of the background canvas.
:param background_filler: a `BackgroundFiller` instance. Default to zero-filler.
Args:
background_shape (tuple): shape of the background canvas.
background_filler (BackgroundFiller): How to fill the background. Defaults to zero-filler.
"""
if
background_filler
is
None
:
background_filler
=
ConstantBackgroundFiller
(
0
)
...
...
tensorpack/dataflow/prefetch.py
View file @
d8935ef3
...
...
@@ -48,15 +48,19 @@ class PrefetchProcess(mp.Process):
class
PrefetchData
(
ProxyDataFlow
):
"""
Prefetch data from a `DataFlow` using multiprocessing
Prefetch data from a DataFlow using Python multiprocessing utilities.
Note:
This is significantly slower than :class:`PrefetchDataZMQ` when data
is large.
"""
def
__init__
(
self
,
ds
,
nr_prefetch
,
nr_proc
=
1
):
"""
:param ds: a `DataFlow` instance.
:param nr_prefetch: size of the queue to hold prefetched datapoints
.
:param nr_proc: number of processes to use. When larger than 1, order
of data points will be random
.
Args:
ds (DataFlow): input DataFlow
.
nr_prefetch (int): size of the queue to hold prefetched datapoints.
nr_proc (int): number of processes to use
.
"""
super
(
PrefetchData
,
self
)
.
__init__
(
ds
)
try
:
...
...
@@ -85,9 +89,8 @@ class PrefetchData(ProxyDataFlow):
def
BlockParallel
(
ds
,
queue_size
):
# TODO more doc
"""
Insert `
BlockParallel` in dataflow pipeline to block parallelism on ds
Insert `
`BlockParallel`` in dataflow pipeline to block parallelism on ds.
:param ds: a `DataFlow`
:param queue_size: size of the queue used
...
...
@@ -96,7 +99,6 @@ def BlockParallel(ds, queue_size):
class
PrefetchProcessZMQ
(
mp
.
Process
):
def
__init__
(
self
,
ds
,
conn_name
):
"""
:param ds: a `DataFlow` instance.
...
...
@@ -118,15 +120,17 @@ class PrefetchProcessZMQ(mp.Process):
class
PrefetchDataZMQ
(
ProxyDataFlow
):
""" Work the same as `PrefetchData`, but faster. """
"""
Prefetch data from a DataFlow using multiple processes, with ZMQ for
communication.
"""
def
__init__
(
self
,
ds
,
nr_proc
=
1
,
pipedir
=
None
):
"""
:param ds: a `DataFlow` instance.
:param nr_proc: number of processes to use. When larger than 1, order
of datapoints will be random
.
:param pipedir: a local directory where the pipes would be
.
Useful if you're running on non-local FS such as N
FS.
Args:
ds (DataFlow): input DataFlow.
nr_proc (int): number of processes to use
.
pipedir (str): a local directory where the pipes should be put
.
Useful if you're running on non-local FS such as NFS or Gluster
FS.
"""
super
(
PrefetchDataZMQ
,
self
)
.
__init__
(
ds
)
try
:
...
...
@@ -185,10 +189,20 @@ class PrefetchDataZMQ(ProxyDataFlow):
class
PrefetchOnGPUs
(
PrefetchDataZMQ
):
""" Prefetch with each process having a specific CUDA_VISIBLE_DEVICES
variable"""
"""
Prefetch with each process having its own ``CUDA_VISIBLE_DEVICES`` variable
mapped to one GPU.
"""
def
__init__
(
self
,
ds
,
gpus
,
pipedir
=
None
):
"""
Args:
ds (DataFlow): input DataFlow.
gpus (list[int]): list of GPUs to use. One process will be started
for each GPU in the list.
pipedir (str): a local directory where the pipes should be put.
Useful if you're running on non-local FS such as NFS or GlusterFS.
"""
self
.
gpus
=
gpus
super
(
PrefetchOnGPUs
,
self
)
.
__init__
(
ds
,
len
(
gpus
),
pipedir
)
...
...
tensorpack/dataflow/raw.py
View file @
d8935ef3
...
...
@@ -7,26 +7,21 @@ import numpy as np
import
copy
from
six.moves
import
range
from
.base
import
DataFlow
,
RNGDataFlow
from
..utils.serialize
import
loads
__all__
=
[
'FakeData'
,
'DataFromQueue'
,
'DataFromList'
]
try
:
import
zmq
except
:
pass
else
:
__all__
.
append
(
'DataFromSocket'
)
class
FakeData
(
RNGDataFlow
):
""" Generate fake
fixed
data of given shapes"""
""" Generate fake data of given shapes"""
def
__init__
(
self
,
shapes
,
size
,
random
=
True
,
dtype
=
'float32'
):
def
__init__
(
self
,
shapes
,
size
=
1000
,
random
=
True
,
dtype
=
'float32'
):
"""
:param shapes: a list of lists/tuples
:param size: size of this DataFlow
:param random: whether to randomly generate data every iteration. note
that only generating the data could be time-consuming!
Args:
shapes (list): a list of lists/tuples. Shapes of each component.
size (int): size of this DataFlow.
random (bool): whether to randomly generate data every iteration.
Note that merely generating the data could sometimes be time-consuming!
dtype (str): data type.
"""
super
(
FakeData
,
self
)
.
__init__
()
self
.
shapes
=
shapes
...
...
@@ -49,8 +44,11 @@ class FakeData(RNGDataFlow):
class
DataFromQueue
(
DataFlow
):
""" Produce data from a queue """
def
__init__
(
self
,
queue
):
"""
Args:
queue (queue): a queue with ``get()`` method.
"""
self
.
queue
=
queue
def
get_data
(
self
):
...
...
@@ -62,6 +60,11 @@ class DataFromList(RNGDataFlow):
""" Produce data from a list"""
def
__init__
(
self
,
lst
,
shuffle
=
True
):
"""
Args:
lst (list): input list.
shuffle (bool): shuffle data.
"""
super
(
DataFromList
,
self
)
.
__init__
()
self
.
lst
=
lst
self
.
shuffle
=
shuffle
...
...
@@ -78,22 +81,3 @@ class DataFromList(RNGDataFlow):
self
.
rng
.
shuffle
(
idxs
)
for
k
in
idxs
:
yield
self
.
lst
[
k
]
class
DataFromSocket
(
DataFlow
):
""" Produce data from a zmq socket"""
def
__init__
(
self
,
socket_name
):
self
.
_name
=
socket_name
def
get_data
(
self
):
try
:
ctx
=
zmq
.
Context
()
socket
=
ctx
.
socket
(
zmq
.
PULL
)
socket
.
bind
(
self
.
_name
)
while
True
:
dp
=
loads
(
socket
.
recv
(
copy
=
False
))
yield
dp
finally
:
ctx
.
destroy
(
linger
=
0
)
tensorpack/dataflow/remote.py
View file @
d8935ef3
...
...
@@ -19,6 +19,14 @@ from ..utils.serialize import dumps, loads
def
serve_data
(
ds
,
addr
):
"""
Serve the DataFlow on a ZMQ socket addr.
It will dump and send each datapoint to this addr with a PUSH socket.
Args:
ds (DataFlow): DataFlow to serve. Will infinitely loop over the DataFlow.
addr: a ZMQ socket addr.
"""
ctx
=
zmq
.
Context
()
socket
=
ctx
.
socket
(
zmq
.
PUSH
)
socket
.
set_hwm
(
10
)
...
...
@@ -27,7 +35,7 @@ def serve_data(ds, addr):
try
:
ds
.
reset_state
()
logger
.
info
(
"Serving data at {}"
.
format
(
addr
))
# TODO print statistics
here
# TODO print statistics
such as speed
while
True
:
for
dp
in
ds
.
get_data
():
socket
.
send
(
dumps
(
dp
),
copy
=
False
)
...
...
@@ -39,17 +47,25 @@ def serve_data(ds, addr):
class
RemoteData
(
DataFlow
):
""" Produce data from a ZMQ socket. """
def
__init__
(
self
,
addr
):
self
.
ctx
=
zmq
.
Context
()
self
.
socket
=
self
.
ctx
.
socket
(
zmq
.
PULL
)
self
.
socket
.
set_hwm
(
10
)
self
.
socket
.
connect
(
addr
)
"""
Args:
addr (str): addr of the socket to connect to.
"""
self
.
_addr
=
addr
def
get_data
(
self
):
while
True
:
dp
=
loads
(
self
.
socket
.
recv
(
copy
=
False
))
yield
dp
try
:
ctx
=
zmq
.
Context
()
socket
=
ctx
.
socket
(
zmq
.
PULL
)
socket
.
connect
(
self
.
_addr
)
while
True
:
dp
=
loads
(
socket
.
recv
(
copy
=
False
))
yield
dp
finally
:
ctx
.
destroy
(
linger
=
0
)
if
__name__
==
'__main__'
:
...
...
tensorpack/dataflow/tf_func.py
View file @
d8935ef3
...
...
@@ -12,7 +12,9 @@ except ImportError:
logger
.
warn_dependency
(
'TFFuncMapper'
,
'tensorflow'
)
__all__
=
[]
else
:
__all__
=
[
'TFFuncMapper'
]
__all__
=
[]
""" This file was deprecated """
class
TFFuncMapper
(
ProxyDataFlow
):
...
...
tensorpack/models/model_desc.py
View file @
d8935ef3
...
...
@@ -138,7 +138,7 @@ class ModelFromMetaGraph(ModelDesc):
def
__init__
(
self
,
filename
):
"""
Args:
filename(str): file name of the saved meta graph.
filename
(str): file name of the saved meta graph.
"""
tf
.
train
.
import_meta_graph
(
filename
)
all_coll
=
tf
.
get_default_graph
()
.
get_all_collection_keys
()
...
...
tensorpack/models/regularize.py
View file @
d8935ef3
...
...
@@ -59,9 +59,9 @@ def Dropout(x, keep_prob=0.5, is_training=None):
Neural Networks from Overfitting <http://dl.acm.org/citation.cfm?id=2670313>`_.
Args:
keep_prob: the probability that each element is kept. It is only used
keep_prob
(float)
: the probability that each element is kept. It is only used
when is_training=True.
is_training: If None, will use the current :class:`tensorpack.tfutils.TowerContext`
is_training
(bool)
: If None, will use the current :class:`tensorpack.tfutils.TowerContext`
to figure out.
"""
if
is_training
is
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment