Commit 8553816b authored by Yuxin Wu's avatar Yuxin Wu

use MapData to implement LMDBDataDecoder and LMDBDataPoint

parent 7c7f6e85
......@@ -14,17 +14,21 @@ from tensorpack.dataflow import DataFlow
class GANModelDesc(ModelDesc):
def collect_variables(self, g_scope='gen', d_scope='discrim'):
    """
    Assign ``self.g_vars`` to the trainable parameters under scope `g_scope`,
    and same with ``self.d_vars``.

    Args:
        g_scope (str): name scope of the generator variables.
        d_scope (str): name scope of the discriminator variables.
    """
    all_vars = tf.trainable_variables()
    # Match on the scope prefix with a trailing '/' so that e.g. 'gen'
    # does not accidentally match variables under 'gen2/'.
    self.g_vars = [v for v in all_vars if v.name.startswith(g_scope + '/')]
    self.d_vars = [v for v in all_vars if v.name.startswith(d_scope + '/')]
    # TODO after TF1.0.0rc1:
    # self.g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, g_scope)
    # self.d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, d_scope)
def build_losses(self, logits_real, logits_fake):
"""D and G play two-player minimax game with value function V(G,D)
min_G max _D V(D, G) = IE_{x ~ p_data} [log D(x)] + IE_{z ~ p_fake} [log (1 - D(G(z)))]
Note, we swap 0, 1 labels as suggested in "Improving GANs".
......
......@@ -86,13 +86,14 @@ class ProgressBar(Callback):
def _before_train(self):
    self._total = self.trainer.config.steps_per_epoch
    self._tqdm_args = get_tqdm_kwargs(leave=True)
    # Only reserve room in the bar format for a postfix when there are
    # extra tensors to display. (The old check `self._names is not []`
    # compared identity against a fresh list and was therefore always True.)
    if self._names:
        self._tqdm_args['bar_format'] = self._tqdm_args['bar_format'] + "{postfix} "
def _trigger_step(self, *args):
if self.local_step == 1:
self._bar = tqdm.trange(self._total, **self._tqdm_args)
self._bar.set_postfix(zip(self._tags, args))
if len(self._names):
self._bar.set_postfix(zip(self._tags, args))
self._bar.update()
if self.local_step == self._total:
......
......@@ -13,6 +13,7 @@ from ..utils.loadcaffe import get_caffe_pb
from ..utils.serialize import loads
from ..utils.argtools import log_once
from .base import RNGDataFlow
from .common import MapData
__all__ = ['HDF5Data', 'LMDBData', 'LMDBDataDecoder', 'LMDBDataPoint',
'CaffeLMDB', 'SVMLightData']
......@@ -133,69 +134,66 @@ class LMDBData(RNGDataFlow):
yield [k, v]
class LMDBDataDecoder(MapData):
    """ Read a LMDB database and produce a decoded output."""

    def __init__(self, lmdb_data, decoder):
        """
        Args:
            lmdb_data: a :class:`LMDBData` instance.
            decoder (k,v -> dp | None): a function taking k, v and returning a datapoint,
                or return None to discard.
        """
        def f(dp):
            # dp is the [key, value] pair yielded by the underlying LMDBData.
            return decoder(dp[0], dp[1])
        super(LMDBDataDecoder, self).__init__(lmdb_data, f)
class LMDBDataPoint(MapData):
    """ Read a LMDB file and produce deserialized values.
        This can work with :func:`tensorpack.dataflow.dftools.dump_dataflow_to_lmdb`. """

    def __init__(self, lmdb_data):
        """
        Args:
            lmdb_data: a :class:`LMDBData` instance.
        """
        def f(dp):
            # dp is [key, value]; only the serialized value is deserialized,
            # the key is discarded.
            return loads(dp[1])
        super(LMDBDataPoint, self).__init__(lmdb_data, f)
def CaffeLMDB(lmdb_path, shuffle=True, keys=None):
    """
    Read a Caffe LMDB file where each value contains a ``caffe.Datum`` protobuf.
    Produces datapoints of the format: [HWC image, label].

    Note that Caffe LMDB format is not efficient: it stores serialized raw
    arrays rather than JPEG images.

    Args:
        lmdb_path, shuffle, keys: same as :class:`LMDBData`.

    Returns:
        a :class:`LMDBDataDecoder` instance.

    Example:
        ds = CaffeLMDB("/tmp/validation", keys='{:0>8d}')
    """
    cpb = get_caffe_pb()
    lmdb_data = LMDBData(lmdb_path, shuffle, keys)

    def decoder(k, v):
        # Parse one Datum value; return [HWC image, label],
        # or None to have the datapoint discarded.
        try:
            datum = cpb.Datum()
            datum.ParseFromString(v)
            # NOTE(review): np.fromstring is deprecated in newer numpy;
            # np.frombuffer is the replacement but yields a read-only
            # view — confirm downstream never writes into the image
            # before switching.
            img = np.fromstring(datum.data, dtype=np.uint8)
            img = img.reshape(datum.channels, datum.height, datum.width)
        except Exception:
            # Best-effort read: log once per key and skip unreadable entries.
            log_once("Cannot read key {}".format(k), 'warn')
            return None
        # Datum stores CHW; transpose to the HWC layout promised above.
        return [img.transpose(1, 2, 0), datum.label]

    return LMDBDataDecoder(lmdb_data, decoder)
class SVMLightData(RNGDataFlow):
......
......@@ -121,14 +121,8 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
* ``variance/EMA``: the moving average of variance.
Note:
* In multi-tower training, only the first training tower maintains a moving average.
This is consistent with most frameworks.
* It automatically selects :meth:`BatchNormV1` or :meth:`BatchNormV2`
according to availability.
* This is a slightly faster but equivalent version of BatchNormV1. It uses
``fused_batch_norm`` in training.
In multi-tower training, only the first training tower maintains a moving average.
This is consistent with most frameworks.
"""
shape = x.get_shape().as_list()
assert len(shape) in [2, 4]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment