Shashank Suhas / seminar-breakout · Commits

Commit 8553816b authored Feb 07, 2017 by Yuxin Wu

use MapData to implement LMDBDataDecoder and LMDBDataPoint

Parent: 7c7f6e85
Showing 4 changed files with 51 additions and 54 deletions:
examples/GAN/GAN.py              +9 -5
tensorpack/callbacks/steps.py    +3 -2
tensorpack/dataflow/format.py    +37 -39
tensorpack/models/batch_norm.py  +2 -8
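Background for the change (a minimal sketch, not part of the page): MapData in
tensorpack.dataflow wraps another DataFlow and applies a function to each
datapoint, discarding the datapoint when the function returns None. The commit
reuses this to express LMDBDataDecoder and LMDBDataPoint as thin wrappers
around an LMDBData instance instead of subclasses of it. Illustrative only; the
path below is a placeholder:

    # LMDBData yields [key, value] datapoints; MapData(df, func) yields func(dp)
    # for every datapoint dp of df, and drops datapoints where func returns None.
    from tensorpack.dataflow import LMDBData, MapData
    from tensorpack.utils.serialize import loads

    raw = LMDBData('/path/to/data.lmdb', shuffle=False)
    ds = MapData(raw, lambda dp: loads(dp[1]))   # deserialize only the value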
examples/GAN/GAN.py

@@ -14,17 +14,21 @@ from tensorpack.dataflow import DataFlow

 class GANModelDesc(ModelDesc):
-    def collect_variables(self):
-        """Extract variables by prefix
+    def collect_variables(self, g_scope='gen', d_scope='discrim'):
+        """
+        Assign self.g_vars to the parameters under scope `g_scope`,
+        and same with self.d_vars.
         """
         all_vars = tf.trainable_variables()
-        self.g_vars = [v for v in all_vars if v.name.startswith('gen/')]
-        self.d_vars = [v for v in all_vars if v.name.startswith('discrim/')]
+        self.g_vars = [v for v in all_vars if v.name.startswith(g_scope + '/')]
+        self.d_vars = [v for v in all_vars if v.name.startswith(d_scope + '/')]
+        # TODO after TF1.0.0rc1
+        # self.g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, g_scope)
+        # self.d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, d_scope)

     def build_losses(self, logits_real, logits_fake):
         """D and G play two-player minimax game with value function V(G,D)

         min_G max_D V(D, G) = IE_{x ~ p_data} [log D(x)] + IE_{z ~ p_fake} [log (1 - D(G(z)))]

         Note, we swap 0, 1 labels as suggested in "Improving GANs".
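A hedged usage sketch of the new collect_variables signature (the subclass,
scope names, and _build_graph hook below are illustrative, not from this
commit):

    class MyGAN(GANModelDesc):
        def _build_graph(self, inputs):
            # ... build the generator under tf.variable_scope('generator')
            # ... build the discriminator under tf.variable_scope('critic')
            self.collect_variables(g_scope='generator', d_scope='critic')
            # self.g_vars / self.d_vars now hold the trainable variables whose
            # names start with 'generator/' and 'critic/' respectively.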
tensorpack/callbacks/steps.py

@@ -86,12 +86,13 @@ class ProgressBar(Callback):

     def _before_train(self):
         self._total = self.trainer.config.steps_per_epoch
         self._tqdm_args = get_tqdm_kwargs(leave=True)
-        if self._names is not []:
+        if len(self._names):
             self._tqdm_args['bar_format'] = self._tqdm_args['bar_format'] + "{postfix} "

     def _trigger_step(self, *args):
         if self.local_step == 1:
             self._bar = tqdm.trange(self._total, **self._tqdm_args)
-        self._bar.set_postfix(zip(self._tags, args))
+        if len(self._names):
+            self._bar.set_postfix(zip(self._tags, args))
         self._bar.update()
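The condition that was replaced here is a genuine bug rather than a style fix:
`self._names is not []` compares object identity against a freshly created
empty list, so it is always True, whereas `len(self._names)` actually tests
whether any names were given. A quick illustration:

    names = []
    print(names is not [])   # True  -- identity check against a brand-new list
    print(len(names) > 0)    # False -- the emptiness test that was intended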
tensorpack/dataflow/format.py

@@ -13,6 +13,7 @@ from ..utils.loadcaffe import get_caffe_pb
 from ..utils.serialize import loads
 from ..utils.argtools import log_once
 from .base import RNGDataFlow
+from .common import MapData

 __all__ = ['HDF5Data', 'LMDBData', 'LMDBDataDecoder', 'LMDBDataPoint',
            'CaffeLMDB', 'SVMLightData']

@@ -133,55 +134,54 @@ class LMDBData(RNGDataFlow):
             yield [k, v]


-class LMDBDataDecoder(LMDBData):
+class LMDBDataDecoder(MapData):
     """ Read a LMDB database and produce a decoded output."""
-    def __init__(self, lmdb_path, decoder, shuffle=True, keys=None):
+    def __init__(self, lmdb_data, decoder):
         """
         Args:
-            lmdb_path, shuffle, keys: same as :class:`LMDBData`.
+            lmdb_data: a :class:`LMDBData` instance.
             decoder (k,v -> dp | None): a function taking k, v and returning a datapoint,
                 or return None to discard.
         """
-        super(LMDBDataDecoder, self).__init__(lmdb_path, shuffle=shuffle, keys=keys)
-        self.decoder = decoder
-
-    def get_data(self):
-        for dp in super(LMDBDataDecoder, self).get_data():
-            v = self.decoder(dp[0], dp[1])
-            if v:
-                yield v
+        def f(dp):
+            return decoder(dp[0], dp[1])
+        super(LMDBDataDecoder, self).__init__(lmdb_data, f)


-class LMDBDataPoint(LMDBDataDecoder):
+class LMDBDataPoint(MapData):
     """ Read a LMDB file and produce deserialized values.
         This can work with :func:`tensorpack.dataflow.dftools.dump_dataflow_to_lmdb`. """

-    def __init__(self, lmdb_path, shuffle=True, keys=None):
+    def __init__(self, lmdb_data):
         """
         Args:
-            lmdb_path (str): a directory or a file.
-            shuffle (bool): shuffle the keys or not.
-            keys (list): list of keys for lmdb file or the key format `'{:0>8d}'`
+            lmdb_data: a :class:`LMDBData` instance.
         """
-        super(LMDBDataPoint, self).__init__(
-            lmdb_path, decoder=lambda k, v: loads(v), shuffle=shuffle, keys=keys)
+        def f(dp):
+            return loads(dp[1])
+        super(LMDBDataPoint, self).__init__(lmdb_data, f)


-class CaffeLMDB(LMDBDataDecoder):
+def CaffeLMDB(lmdb_path, shuffle=True, keys=None):
     """
     Read a Caffe LMDB file where each value contains a ``caffe.Datum`` protobuf.
     Produces datapoints of the format: [HWC image, label].

-    Example:
-        ds = CaffeLMDB("/tmp/validation", keys='{:0>8d}')
-    """
-
-    def __init__(self, lmdb_path, shuffle=True, keys=None):
-        """
-        Args:
-            lmdb_path, shuffle, keys: same as :class:`LMDBData`.
-        """
-        cpb = get_caffe_pb()
+    Note that Caffe LMDB format is not efficient: it stores serialized raw
+    arrays rather than JPEG images.
+
+    Args:
+        lmdb_path, shuffle, keys: same as :class:`LMDBData`.
+
+    Returns:
+        a :class:`LMDBDataDecoder` instance.
+
+    Example:
+        ds = CaffeLMDB("/tmp/validation", keys='{:0>8d}')
+    """
+    cpb = get_caffe_pb()
+    lmdb_data = LMDBData(lmdb_path, shuffle, keys)

     def decoder(k, v):
         try:

@@ -193,9 +193,7 @@ class CaffeLMDB(LMDBDataDecoder):
                 log_once("Cannot read key {}".format(k), 'warn')
                 return None
             return [img.transpose(1, 2, 0), datum.label]

-        super(CaffeLMDB, self).__init__(
-            lmdb_path, decoder=decoder, shuffle=shuffle, keys=keys)
+    return LMDBDataDecoder(lmdb_data, decoder)


 class SVMLightData(RNGDataFlow):
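A hedged sketch of how the reworked dataflow classes compose after this commit
(paths are placeholders): LMDBDataDecoder and LMDBDataPoint now take an
already-constructed LMDBData rather than opening the LMDB file themselves, and
CaffeLMDB becomes a plain function that returns an LMDBDataDecoder.

    from tensorpack.dataflow import LMDBData, LMDBDataPoint, CaffeLMDB

    # LMDBDataPoint wraps an LMDBData instance and deserializes each value.
    raw = LMDBData('/path/to/dataset.lmdb', shuffle=True)
    ds = LMDBDataPoint(raw)

    # CaffeLMDB keeps its old call signature but now returns an LMDBDataDecoder.
    caffe_ds = CaffeLMDB("/tmp/validation", keys='{:0>8d}')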
tensorpack/models/batch_norm.py

@@ -121,14 +121,8 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
     * ``variance/EMA``: the moving average of variance.

     Note:
-        * In multi-tower training, only the first training tower maintains a moving average.
-          This is consistent with most frameworks.
-        * It automatically selects :meth:`BatchNormV1` or :meth:`BatchNormV2`
-          according to availability.
-
-    * This is a slightly faster but equivalent version of BatchNormV1. It uses
-        ``fused_batch_norm`` in training.
+        In multi-tower training, only the first training tower maintains a moving average.
+        This is consistent with most frameworks.
     """
     shape = x.get_shape().as_list()
     assert len(shape) in [2, 4]