Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
S
seminar-breakout
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Analytics
Analytics
CI / CD
Repository
Value Stream
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
Shashank Suhas
seminar-breakout
Commits
8553816b
Commit
8553816b
authored
Feb 07, 2017
by
Yuxin Wu
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
use MapData to implement LMDBDataDecoder and LMDBDataPoint
parent
7c7f6e85
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
51 additions
and
54 deletions
+51
-54
examples/GAN/GAN.py
examples/GAN/GAN.py
+9
-5
tensorpack/callbacks/steps.py
tensorpack/callbacks/steps.py
+3
-2
tensorpack/dataflow/format.py
tensorpack/dataflow/format.py
+37
-39
tensorpack/models/batch_norm.py
tensorpack/models/batch_norm.py
+2
-8
No files found.
examples/GAN/GAN.py
View file @
8553816b
...
@@ -14,17 +14,21 @@ from tensorpack.dataflow import DataFlow
...
@@ -14,17 +14,21 @@ from tensorpack.dataflow import DataFlow
class
GANModelDesc
(
ModelDesc
):
class
GANModelDesc
(
ModelDesc
):
def
collect_variables
(
self
):
def
collect_variables
(
self
,
g_scope
=
'gen'
,
d_scope
=
'discrim'
):
"""Extract variables by prefix
"""
Assign self.g_vars to the parameters under scope `g_scope`,
and same with self.d_vars.
"""
"""
all_vars
=
tf
.
trainable_variables
()
all_vars
=
tf
.
trainable_variables
()
self
.
g_vars
=
[
v
for
v
in
all_vars
if
v
.
name
.
startswith
(
'gen/'
)]
self
.
g_vars
=
[
v
for
v
in
all_vars
if
v
.
name
.
startswith
(
g_scope
+
'/'
)]
self
.
d_vars
=
[
v
for
v
in
all_vars
if
v
.
name
.
startswith
(
'discrim/'
)]
self
.
d_vars
=
[
v
for
v
in
all_vars
if
v
.
name
.
startswith
(
d_scope
+
'/'
)]
# TODO after TF1.0.0rc1
# self.g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, g_scope)
# self.d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, d_scope)
def
build_losses
(
self
,
logits_real
,
logits_fake
):
def
build_losses
(
self
,
logits_real
,
logits_fake
):
"""D and G play two-player minimax game with value function V(G,D)
"""D and G play two-player minimax game with value function V(G,D)
min_G max _D V(D, G) = IE_{x ~ p_data} [log D(x)] + IE_{z ~ p_fake} [log (1 - D(G(z)))]
min_G max _D V(D, G) = IE_{x ~ p_data} [log D(x)] + IE_{z ~ p_fake} [log (1 - D(G(z)))]
Note, we swap 0, 1 labels as suggested in "Improving GANs".
Note, we swap 0, 1 labels as suggested in "Improving GANs".
...
...
tensorpack/callbacks/steps.py
View file @
8553816b
...
@@ -86,13 +86,14 @@ class ProgressBar(Callback):
...
@@ -86,13 +86,14 @@ class ProgressBar(Callback):
def
_before_train
(
self
):
def
_before_train
(
self
):
self
.
_total
=
self
.
trainer
.
config
.
steps_per_epoch
self
.
_total
=
self
.
trainer
.
config
.
steps_per_epoch
self
.
_tqdm_args
=
get_tqdm_kwargs
(
leave
=
True
)
self
.
_tqdm_args
=
get_tqdm_kwargs
(
leave
=
True
)
if
self
.
_names
is
not
[]
:
if
len
(
self
.
_names
)
:
self
.
_tqdm_args
[
'bar_format'
]
=
self
.
_tqdm_args
[
'bar_format'
]
+
"{postfix} "
self
.
_tqdm_args
[
'bar_format'
]
=
self
.
_tqdm_args
[
'bar_format'
]
+
"{postfix} "
def
_trigger_step
(
self
,
*
args
):
def
_trigger_step
(
self
,
*
args
):
if
self
.
local_step
==
1
:
if
self
.
local_step
==
1
:
self
.
_bar
=
tqdm
.
trange
(
self
.
_total
,
**
self
.
_tqdm_args
)
self
.
_bar
=
tqdm
.
trange
(
self
.
_total
,
**
self
.
_tqdm_args
)
self
.
_bar
.
set_postfix
(
zip
(
self
.
_tags
,
args
))
if
len
(
self
.
_names
):
self
.
_bar
.
set_postfix
(
zip
(
self
.
_tags
,
args
))
self
.
_bar
.
update
()
self
.
_bar
.
update
()
if
self
.
local_step
==
self
.
_total
:
if
self
.
local_step
==
self
.
_total
:
...
...
tensorpack/dataflow/format.py
View file @
8553816b
...
@@ -13,6 +13,7 @@ from ..utils.loadcaffe import get_caffe_pb
...
@@ -13,6 +13,7 @@ from ..utils.loadcaffe import get_caffe_pb
from
..utils.serialize
import
loads
from
..utils.serialize
import
loads
from
..utils.argtools
import
log_once
from
..utils.argtools
import
log_once
from
.base
import
RNGDataFlow
from
.base
import
RNGDataFlow
from
.common
import
MapData
__all__
=
[
'HDF5Data'
,
'LMDBData'
,
'LMDBDataDecoder'
,
'LMDBDataPoint'
,
__all__
=
[
'HDF5Data'
,
'LMDBData'
,
'LMDBDataDecoder'
,
'LMDBDataPoint'
,
'CaffeLMDB'
,
'SVMLightData'
]
'CaffeLMDB'
,
'SVMLightData'
]
...
@@ -133,69 +134,66 @@ class LMDBData(RNGDataFlow):
...
@@ -133,69 +134,66 @@ class LMDBData(RNGDataFlow):
yield
[
k
,
v
]
yield
[
k
,
v
]
class
LMDBDataDecoder
(
LMDB
Data
):
class
LMDBDataDecoder
(
Map
Data
):
""" Read a LMDB database and produce a decoded output."""
""" Read a LMDB database and produce a decoded output."""
def
__init__
(
self
,
lmdb_
path
,
decoder
,
shuffle
=
True
,
keys
=
None
):
def
__init__
(
self
,
lmdb_
data
,
decoder
):
"""
"""
Args:
Args:
lmdb_
path, shuffle, keys: same as :class:`LMDBData`
.
lmdb_
data: a :class:`LMDBData` instance
.
decoder (k,v -> dp | None): a function taking k, v and returning a datapoint,
decoder (k,v -> dp | None): a function taking k, v and returning a datapoint,
or return None to discard.
or return None to discard.
"""
"""
super
(
LMDBDataDecoder
,
self
)
.
__init__
(
lmdb_path
,
shuffle
=
shuffle
,
keys
=
keys
)
def
f
(
dp
):
self
.
decoder
=
decoder
return
decoder
(
dp
[
0
],
dp
[
1
])
super
(
LMDBDataDecoder
,
self
)
.
__init__
(
lmdb_data
,
f
)
def
get_data
(
self
):
for
dp
in
super
(
LMDBDataDecoder
,
self
)
.
get_data
():
v
=
self
.
decoder
(
dp
[
0
],
dp
[
1
])
if
v
:
yield
v
class
LMDBDataPoint
(
LMDBDataDecoder
):
class
LMDBDataPoint
(
MapData
):
""" Read a LMDB file and produce deserialized values.
""" Read a LMDB file and produce deserialized values.
This can work with :func:`tensorpack.dataflow.dftools.dump_dataflow_to_lmdb`. """
This can work with :func:`tensorpack.dataflow.dftools.dump_dataflow_to_lmdb`. """
def
__init__
(
self
,
lmdb_
path
,
shuffle
=
True
,
keys
=
None
):
def
__init__
(
self
,
lmdb_
data
):
"""
"""
Args:
Args:
lmdb_path (str): a directory or a file.
lmdb_data: a :class:`LMDBData` instance.
shuffle (bool): shuffle the keys or not.
keys (list): list of keys for lmdb file or the key format `'{:0>8d}'`
"""
"""
super
(
LMDBDataPoint
,
self
)
.
__init__
(
def
f
(
dp
):
lmdb_path
,
decoder
=
lambda
k
,
v
:
loads
(
v
),
shuffle
=
shuffle
,
keys
=
keys
)
return
loads
(
dp
[
1
])
super
(
LMDBDataPoint
,
self
)
.
__init__
(
lmdb_data
,
f
)
class
CaffeLMDB
(
LMDBDataDecoder
):
def
CaffeLMDB
(
lmdb_path
,
shuffle
=
True
,
keys
=
None
):
"""
"""
Read a Caffe LMDB file where each value contains a ``caffe.Datum`` protobuf.
Read a Caffe LMDB file where each value contains a ``caffe.Datum`` protobuf.
Produces datapoints of the format: [HWC image, label].
Produces datapoints of the format: [HWC image, label].
Note that Caffe LMDB format is not efficient: it stores serialized raw
arrays rather than JPEG images.
Args:
lmdb_path, shuffle, keys: same as :class:`LMDBData`.
Returns:
a :class:`LMDBDataDecoder` instance.
Example:
Example:
ds = CaffeLMDB("/tmp/validation", keys='{:0>8d}')
ds = CaffeLMDB("/tmp/validation", keys='{:0>8d}')
"""
"""
def
__init__
(
self
,
lmdb_path
,
shuffle
=
True
,
keys
=
None
):
cpb
=
get_caffe_pb
()
"""
lmdb_data
=
LMDBData
(
lmdb_path
,
shuffle
,
keys
)
Args:
lmdb_path, shuffle, keys: same as :class:`LMDBData`.
def
decoder
(
k
,
v
):
"""
try
:
cpb
=
get_caffe_pb
()
datum
=
cpb
.
Datum
()
datum
.
ParseFromString
(
v
)
def
decoder
(
k
,
v
):
img
=
np
.
fromstring
(
datum
.
data
,
dtype
=
np
.
uint8
)
try
:
img
=
img
.
reshape
(
datum
.
channels
,
datum
.
height
,
datum
.
width
)
datum
=
cpb
.
Datum
()
except
Exception
:
datum
.
ParseFromString
(
v
)
log_once
(
"Cannot read key {}"
.
format
(
k
),
'warn'
)
img
=
np
.
fromstring
(
datum
.
data
,
dtype
=
np
.
uint8
)
return
None
img
=
img
.
reshape
(
datum
.
channels
,
datum
.
height
,
datum
.
width
)
return
[
img
.
transpose
(
1
,
2
,
0
),
datum
.
label
]
except
Exception
:
return
LMDBDataDecoder
(
lmdb_data
,
decoder
)
log_once
(
"Cannot read key {}"
.
format
(
k
),
'warn'
)
return
None
return
[
img
.
transpose
(
1
,
2
,
0
),
datum
.
label
]
super
(
CaffeLMDB
,
self
)
.
__init__
(
lmdb_path
,
decoder
=
decoder
,
shuffle
=
shuffle
,
keys
=
keys
)
class
SVMLightData
(
RNGDataFlow
):
class
SVMLightData
(
RNGDataFlow
):
...
...
tensorpack/models/batch_norm.py
View file @
8553816b
...
@@ -121,14 +121,8 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
...
@@ -121,14 +121,8 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
* ``variance/EMA``: the moving average of variance.
* ``variance/EMA``: the moving average of variance.
Note:
Note:
* In multi-tower training, only the first training tower maintains a moving average.
In multi-tower training, only the first training tower maintains a moving average.
This is consistent with most frameworks.
This is consistent with most frameworks.
* It automatically selects :meth:`BatchNormV1` or :meth:`BatchNormV2`
according to availability.
* This is a slightly faster but equivalent version of BatchNormV1. It uses
``fused_batch_norm`` in training.
"""
"""
shape
=
x
.
get_shape
()
.
as_list
()
shape
=
x
.
get_shape
()
.
as_list
()
assert
len
(
shape
)
in
[
2
,
4
]
assert
len
(
shape
)
in
[
2
,
4
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment