Commit 8553816b authored by Yuxin Wu's avatar Yuxin Wu

use MapData to implement LMDBDataDecoder and LMDBDataPoint

parent 7c7f6e85
...@@ -14,17 +14,21 @@ from tensorpack.dataflow import DataFlow ...@@ -14,17 +14,21 @@ from tensorpack.dataflow import DataFlow
class GANModelDesc(ModelDesc): class GANModelDesc(ModelDesc):
def collect_variables(self, g_scope='gen', d_scope='discrim'):
    """
    Assign ``self.g_vars`` to the trainable parameters under scope ``g_scope``,
    and ``self.d_vars`` to those under ``d_scope``.

    Args:
        g_scope (str): variable scope prefix of the generator.
        d_scope (str): variable scope prefix of the discriminator.
    """
    all_trainable = tf.trainable_variables()

    def _under(scope):
        # Variables created under a scope have names of the form "<scope>/...".
        prefix = scope + '/'
        return [v for v in all_trainable if v.name.startswith(prefix)]

    self.g_vars = _under(g_scope)
    self.d_vars = _under(d_scope)
    # TODO after TF1.0.0rc1:
    # self.g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, g_scope)
    # self.d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, d_scope)
def build_losses(self, logits_real, logits_fake): def build_losses(self, logits_real, logits_fake):
"""D and G play two-player minimax game with value function V(G,D) """D and G play two-player minimax game with value function V(G,D)
min_G max _D V(D, G) = IE_{x ~ p_data} [log D(x)] + IE_{z ~ p_fake} [log (1 - D(G(z)))] min_G max _D V(D, G) = IE_{x ~ p_data} [log D(x)] + IE_{z ~ p_fake} [log (1 - D(G(z)))]
Note, we swap 0, 1 labels as suggested in "Improving GANs". Note, we swap 0, 1 labels as suggested in "Improving GANs".
......
...@@ -86,13 +86,14 @@ class ProgressBar(Callback): ...@@ -86,13 +86,14 @@ class ProgressBar(Callback):
def _before_train(self):
    """Set up the progress bar: total step count and tqdm formatting."""
    self._total = self.trainer.config.steps_per_epoch
    self._tqdm_args = get_tqdm_kwargs(leave=True)
    # Reserve room for a postfix only when there are stats to display.
    # Truthiness check instead of `len(...)` — an empty list is falsy (PEP 8).
    if self._names:
        self._tqdm_args['bar_format'] = self._tqdm_args['bar_format'] + "{postfix} "
def _trigger_step(self, *args): def _trigger_step(self, *args):
if self.local_step == 1: if self.local_step == 1:
self._bar = tqdm.trange(self._total, **self._tqdm_args) self._bar = tqdm.trange(self._total, **self._tqdm_args)
self._bar.set_postfix(zip(self._tags, args)) if len(self._names):
self._bar.set_postfix(zip(self._tags, args))
self._bar.update() self._bar.update()
if self.local_step == self._total: if self.local_step == self._total:
......
...@@ -13,6 +13,7 @@ from ..utils.loadcaffe import get_caffe_pb ...@@ -13,6 +13,7 @@ from ..utils.loadcaffe import get_caffe_pb
from ..utils.serialize import loads from ..utils.serialize import loads
from ..utils.argtools import log_once from ..utils.argtools import log_once
from .base import RNGDataFlow from .base import RNGDataFlow
from .common import MapData
__all__ = ['HDF5Data', 'LMDBData', 'LMDBDataDecoder', 'LMDBDataPoint', __all__ = ['HDF5Data', 'LMDBData', 'LMDBDataDecoder', 'LMDBDataPoint',
'CaffeLMDB', 'SVMLightData'] 'CaffeLMDB', 'SVMLightData']
...@@ -133,69 +134,66 @@ class LMDBData(RNGDataFlow): ...@@ -133,69 +134,66 @@ class LMDBData(RNGDataFlow):
yield [k, v] yield [k, v]
class LMDBDataDecoder(MapData):
    """ Read a LMDB database and produce a decoded output."""

    def __init__(self, lmdb_data, decoder):
        """
        Args:
            lmdb_data: a :class:`LMDBData` instance.
            decoder (k,v -> dp | None): a function taking k, v and returning a datapoint,
                or return None to discard.
        """
        # Adapt the (k, v) decoder to MapData's single-datapoint interface:
        # each datapoint from LMDBData is a [key, value] pair.
        super(LMDBDataDecoder, self).__init__(
            lmdb_data, lambda dp: decoder(dp[0], dp[1]))
class LMDBDataPoint(MapData):
    """ Read a LMDB file and produce deserialized values.
        This can work with :func:`tensorpack.dataflow.dftools.dump_dataflow_to_lmdb`. """

    def __init__(self, lmdb_data):
        """
        Args:
            lmdb_data: a :class:`LMDBData` instance.
        """
        # Each datapoint from LMDBData is [key, value]; only the value is
        # deserialized, the key is dropped.
        super(LMDBDataPoint, self).__init__(lmdb_data, lambda dp: loads(dp[1]))
def CaffeLMDB(lmdb_path, shuffle=True, keys=None):
    """
    Read a Caffe LMDB file where each value contains a ``caffe.Datum`` protobuf.
    Produces datapoints of the format: [HWC image, label].

    Args:
        lmdb_path, shuffle, keys: same as :class:`LMDBData`.

    Returns:
        a :class:`LMDBDataDecoder` instance.

    Example:
        ds = CaffeLMDB("/tmp/validation", keys='{:0>8d}')
    """
    cpb = get_caffe_pb()
    lmdb_data = LMDBData(lmdb_path, shuffle, keys)

    def decoder(k, v):
        # Best-effort decode: a corrupt or unreadable entry is logged once
        # and discarded (returning None tells LMDBDataDecoder to drop it).
        try:
            datum = cpb.Datum()
            datum.ParseFromString(v)
            # np.frombuffer replaces the deprecated np.fromstring; it returns
            # a read-only view, which is fine since the array is not mutated.
            img = np.frombuffer(datum.data, dtype=np.uint8)
            img = img.reshape(datum.channels, datum.height, datum.width)
        except Exception:
            log_once("Cannot read key {}".format(k), 'warn')
            return None
        # Datum stores images as CHW; transpose to HWC for the datapoint.
        return [img.transpose(1, 2, 0), datum.label]
    return LMDBDataDecoder(lmdb_data, decoder)
class SVMLightData(RNGDataFlow): class SVMLightData(RNGDataFlow):
......
...@@ -121,14 +121,8 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5): ...@@ -121,14 +121,8 @@ def BatchNorm(x, use_local_stat=None, decay=0.9, epsilon=1e-5):
* ``variance/EMA``: the moving average of variance. * ``variance/EMA``: the moving average of variance.
Note: Note:
* In multi-tower training, only the first training tower maintains a moving average. In multi-tower training, only the first training tower maintains a moving average.
This is consistent with most frameworks. This is consistent with most frameworks.
* It automatically selects :meth:`BatchNormV1` or :meth:`BatchNormV2`
according to availability.
* This is a slightly faster but equivalent version of BatchNormV1. It uses
``fused_batch_norm`` in training.
""" """
shape = x.get_shape().as_list() shape = x.get_shape().as_list()
assert len(shape) in [2, 4] assert len(shape) in [2, 4]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment