Commit 1126fa5c authored by Yuxin Wu

add rngdataflow as a base

parent c6de1746
@@ -35,6 +35,15 @@ class DataFlow(object):
"""
pass
class RNGDataFlow(DataFlow):
""" A dataflow with rng"""
def __init__(self):
self.rng = get_rng(self)
def reset_state(self):
self.rng = get_rng(self)
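For illustration only, a minimal sketch of how a subclass would use the new base (assuming, as elsewhere in tensorpack, that get_rng returns a seeded numpy RandomState):

class ToyData(RNGDataFlow):
    """ Hypothetical example: yield `size` random scalar datapoints """
    def __init__(self, size):
        super(ToyData, self).__init__()
        self._size = size
    def size(self):
        return self._size
    def get_data(self):
        for _ in range(self._size):
            # self.rng is the RandomState set up by RNGDataFlow;
            # reset_state() re-seeds it, e.g. after forking a process
            yield [self.rng.rand()]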
class ProxyDataFlow(DataFlow):
""" Base class for DataFlow that proxies another"""
def __init__(self, ds):
......
@@ -5,6 +5,7 @@
from __future__ import division
import copy
import numpy as np
from collections import deque
from six.moves import range, map
from .base import DataFlow, ProxyDataFlow
from ..utils import *
@@ -12,7 +13,7 @@ from ..utils import *
__all__ = ['BatchData', 'FixedSizeData', 'FakeData', 'MapData',
'RepeatedData', 'MapDataComponent', 'RandomChooseData',
'RandomMixData', 'JoinData', 'ConcatData', 'SelectComponent',
'DataFromQueue']
'DataFromQueue', 'LocallyShuffleData']
class BatchData(ProxyDataFlow):
def __init__(self, ds, batch_size, remainder=False):
@@ -134,16 +135,16 @@ class RepeatedData(ProxyDataFlow):
for dp in self.ds.get_data():
yield dp
class FakeData(DataFlow):
class FakeData(RNGDataFlow):
""" Generate fake random data of given shapes"""
def __init__(self, shapes, size):
"""
:param shapes: a list of lists/tuples
:param size: size of this DataFlow
"""
super(FakeData, self).__init__()
self.shapes = shapes
self._size = int(size)
self.rng = get_rng(self)
def size(self):
return self._size
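A usage sketch with hypothetical shapes (assuming, per the docstring, that get_data yields one random array per entry of `shapes`):

df = FakeData([[28, 28, 3], [1]], size=100)
df.reset_state()
for dp in df.get_data():
    img, label = dp   # img has shape (28, 28, 3), label has shape (1,)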
@@ -191,7 +192,7 @@ class MapDataComponent(ProxyDataFlow):
dp[self.index] = repl # NOTE modifying
yield dp
class RandomChooseData(DataFlow):
class RandomChooseData(RNGDataFlow):
"""
Randomly choose from several DataFlow. Stop producing when any of them is
exhausted.
@@ -200,21 +201,21 @@ class RandomChooseData(DataFlow):
"""
:param df_lists: list of dataflow, or list of (dataflow, probability) tuple
"""
super(RandomChooseData, self).__init__()
if isinstance(df_lists[0], (tuple, list)):
assert sum([v[1] for v in df_lists]) == 1.0
self.df_lists = df_lists
else:
prob = 1.0 / len(df_lists)
self.df_lists = [(k, prob) for k in df_lists]
self.rng = get_rng(self)
def reset_state(self):
super(RandomChooseData, self).reset_state()
for d in self.df_lists:
if isinstance(d, tuple):
d[0].reset_state()
else:
d.reset_state()
self.rng = get_rng(self)
def get_data(self):
itrs = [v[0].get_data() for v in self.df_lists]
@@ -226,7 +227,7 @@ class RandomChooseData(DataFlow):
except StopIteration:
return
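Construction follows the docstring: either bare DataFlows (chosen uniformly) or (dataflow, probability) tuples whose probabilities sum to 1.0. A hypothetical sketch, with df_a and df_b standing in for real DataFlows:

df = RandomChooseData([(df_a, 0.3), (df_b, 0.7)])   # weighted choice
df = RandomChooseData([df_a, df_b])                 # uniform: prob 0.5 each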
class RandomMixData(DataFlow):
class RandomMixData(RNGDataFlow):
"""
Randomly choose a datapoint from several DataFlow; every DataFlow will eventually be exhausted, so the result is a perfect mix.
"""
@@ -235,14 +236,14 @@ class RandomMixData(DataFlow):
:param df_lists: list of dataflow.
All DataFlow in `df_lists` must have :func:`size()` implemented
"""
super(RandomMixData, self).__init__()
self.df_lists = df_lists
self.sizes = [k.size() for k in self.df_lists]
self.rng = get_rng(self)
def reset_state(self):
super(RandomMixData, self).reset_state()
for d in self.df_lists:
d.reset_state()
self.rng = get_rng(self)
def size(self):
return sum(self.sizes)
@@ -318,9 +319,22 @@ class JoinData(DataFlow):
for itr in itrs:
del itr
class LocallyShuffleData(ProxyDataFlow, RNGDataFlow):
def __init__(self, ds, cache_size):
ProxyDataFlow.__init__(self, ds)
RNGDataFlow.__init__(self)
self.cache_size = cache_size
self.q = deque(maxlen=self.cache_size)
def reset_state(self):
RNGDataFlow.reset_state(self)
def get_data(self):
# TODO
pass
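get_data is left as a TODO in this commit. For reference, one way to implement local shuffling on top of the deque cache and the inherited rng might look like the sketch below; this is an assumption about the intent, not the author's eventual implementation:

def get_data(self):
    # keep up to cache_size datapoints buffered; once the buffer is
    # full, emit a uniformly chosen cached datapoint and replace it
    # with the next one from the underlying dataflow
    for dp in self.ds.get_data():
        if len(self.q) < self.cache_size:
            self.q.append(dp)
        else:
            idx = self.rng.randint(self.cache_size)
            yield self.q[idx]
            self.q[idx] = dp
    # drain whatever is left in the cache, in random order
    rest = list(self.q)
    self.q.clear()
    self.rng.shuffle(rest)
    for dp in rest:
        yield dp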
class DataFromQueue(DataFlow):
""" provide data from a queue
"""
""" Provide data from a queue """
def __init__(self, queue):
self.queue = queue
......
@@ -51,7 +51,7 @@ class SaverRestore(SessionInit):
"""
Restore an old model saved by `ModelSaver`.
"""
def __init__(self, model_path):
def __init__(self, model_path, prefix=None):
"""
:param model_path: a model file or a ``checkpoint`` file.
"""
@@ -61,12 +61,13 @@
os.path.dirname(model_path)).model_checkpoint_path
assert os.path.isfile(model_path)
self.set_path(model_path)
self.prefix = prefix
def _init(self, sess):
logger.info(
"Restoring checkpoint from {}.".format(self.path))
chkpt_vars = SaverRestore._read_checkpoint_vars(self.path)
vars_map = SaverRestore._get_vars_to_restore_multimap(chkpt_vars)
vars_map = self._get_vars_to_restore_multimap(chkpt_vars)
for dic in SaverRestore._produce_restore_dict(vars_map):
# multiple savers under the same name scope would cause an error:
# training/saver.py: assert restore_op.name.endswith("restore_all"), restore_op.name
@@ -93,6 +94,7 @@ class SaverRestore(SessionInit):
@staticmethod
def _read_checkpoint_vars(model_path):
""" return a set of strings """
reader = tf.train.NewCheckpointReader(model_path)
ckpt_vars = reader.get_variable_to_shape_map().keys()
for v in ckpt_vars:
@@ -100,11 +102,10 @@ class SaverRestore(SessionInit):
logger.warn("Found {} in checkpoint. Anything from prediction tower shouldn't be saved.".format(v))  # v is a name string, not a Variable
return set(ckpt_vars)
@staticmethod
def _get_vars_to_restore_multimap(vars_available):
def _get_vars_to_restore_multimap(self, vars_available):
"""
Get a dict of {var_name: [var, var]} to restore
:param vars_available: variables available in the checkpoint, for existence checking
:param vars_available: variable names available in the checkpoint, for existence checking
"""
vars_to_restore = tf.all_variables()
var_dict = defaultdict(list)
@@ -117,6 +118,8 @@
if 'tower' in name:
new_name = re.sub('tower[p0-9]+/', '', name)
name = new_name
if self.prefix and name.startswith(self.prefix):
name = name[len(self.prefix)+1:]
if name in vars_available:
var_dict[name].append(v)
vars_available.remove(name)
......
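The new prefix argument lets variables that live under an extra name scope in the graph match un-prefixed names in the checkpoint: the scope is stripped (including its trailing '/', hence len(self.prefix)+1) before the lookup. A hypothetical usage sketch, with 'mynet' and the checkpoint path made up for illustration:

# graph variable 'mynet/conv1/W' will be restored from checkpoint name 'conv1/W'
init = SaverRestore('train_log/run1/model-5000', prefix='mynet')
# typically handed to the trainer, e.g. TrainConfig(session_init=init, ...)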