Commit 1126fa5c authored by Yuxin Wu's avatar Yuxin Wu

add rngdataflow as a base

parent c6de1746
...@@ -35,6 +35,15 @@ class DataFlow(object): ...@@ -35,6 +35,15 @@ class DataFlow(object):
""" """
pass pass
class RNGDataFlow(DataFlow):
    """Base class for a DataFlow that carries its own random number generator.

    Subclasses read ``self.rng`` for any randomness they need; the rng is
    (re-)created by :func:`reset_state` so each process gets its own stream.
    """

    def __init__(self):
        # Seed an rng immediately so the dataflow is usable even before
        # reset_state() has been called.
        self.rng = get_rng(self)

    def reset_state(self):
        # Obtain a fresh rng in the current process (e.g. after forking).
        self.rng = get_rng(self)
class ProxyDataFlow(DataFlow): class ProxyDataFlow(DataFlow):
""" Base class for DataFlow that proxies another""" """ Base class for DataFlow that proxies another"""
def __init__(self, ds): def __init__(self, ds):
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
from __future__ import division from __future__ import division
import copy import copy
import numpy as np import numpy as np
from collections import deque
from six.moves import range, map from six.moves import range, map
from .base import DataFlow, ProxyDataFlow from .base import DataFlow, ProxyDataFlow
from ..utils import * from ..utils import *
...@@ -12,7 +13,7 @@ from ..utils import * ...@@ -12,7 +13,7 @@ from ..utils import *
__all__ = ['BatchData', 'FixedSizeData', 'FakeData', 'MapData', __all__ = ['BatchData', 'FixedSizeData', 'FakeData', 'MapData',
'RepeatedData', 'MapDataComponent', 'RandomChooseData', 'RepeatedData', 'MapDataComponent', 'RandomChooseData',
'RandomMixData', 'JoinData', 'ConcatData', 'SelectComponent', 'RandomMixData', 'JoinData', 'ConcatData', 'SelectComponent',
'DataFromQueue'] 'DataFromQueue', 'LocallyShuffleData']
class BatchData(ProxyDataFlow): class BatchData(ProxyDataFlow):
def __init__(self, ds, batch_size, remainder=False): def __init__(self, ds, batch_size, remainder=False):
...@@ -134,16 +135,16 @@ class RepeatedData(ProxyDataFlow): ...@@ -134,16 +135,16 @@ class RepeatedData(ProxyDataFlow):
for dp in self.ds.get_data(): for dp in self.ds.get_data():
yield dp yield dp
class FakeData(DataFlow): class FakeData(RNGDataFlow):
""" Generate fake random data of given shapes""" """ Generate fake random data of given shapes"""
def __init__(self, shapes, size): def __init__(self, shapes, size):
""" """
:param shapes: a list of lists/tuples :param shapes: a list of lists/tuples
:param size: size of this DataFlow :param size: size of this DataFlow
""" """
super(FakeData, self).__init__()
self.shapes = shapes self.shapes = shapes
self._size = int(size) self._size = int(size)
self.rng = get_rng(self)
def size(self): def size(self):
return self._size return self._size
...@@ -191,7 +192,7 @@ class MapDataComponent(ProxyDataFlow): ...@@ -191,7 +192,7 @@ class MapDataComponent(ProxyDataFlow):
dp[self.index] = repl # NOTE modifying dp[self.index] = repl # NOTE modifying
yield dp yield dp
class RandomChooseData(DataFlow): class RandomChooseData(RNGDataFlow):
""" """
Randomly choose from several DataFlow. Stop producing when any of them is Randomly choose from several DataFlow. Stop producing when any of them is
exhausted. exhausted.
...@@ -200,21 +201,21 @@ class RandomChooseData(DataFlow): ...@@ -200,21 +201,21 @@ class RandomChooseData(DataFlow):
""" """
:param df_lists: list of dataflow, or list of (dataflow, probability) tuple :param df_lists: list of dataflow, or list of (dataflow, probability) tuple
""" """
super(RandomChooseData, self).__init__()
if isinstance(df_lists[0], (tuple, list)): if isinstance(df_lists[0], (tuple, list)):
assert sum([v[1] for v in df_lists]) == 1.0 assert sum([v[1] for v in df_lists]) == 1.0
self.df_lists = df_lists self.df_lists = df_lists
else: else:
prob = 1.0 / len(df_lists) prob = 1.0 / len(df_lists)
self.df_lists = [(k, prob) for k in df_lists] self.df_lists = [(k, prob) for k in df_lists]
self.rng = get_rng(self)
def reset_state(self): def reset_state(self):
super(RandomChooseData, self).reset_state()
for d in self.df_lists: for d in self.df_lists:
if isinstance(d, tuple): if isinstance(d, tuple):
d[0].reset_state() d[0].reset_state()
else: else:
d.reset_state() d.reset_state()
self.rng = get_rng(self)
def get_data(self): def get_data(self):
itrs = [v[0].get_data() for v in self.df_lists] itrs = [v[0].get_data() for v in self.df_lists]
...@@ -226,7 +227,7 @@ class RandomChooseData(DataFlow): ...@@ -226,7 +227,7 @@ class RandomChooseData(DataFlow):
except StopIteration: except StopIteration:
return return
class RandomMixData(DataFlow): class RandomMixData(RNGDataFlow):
""" """
Randomly choose from several dataflow, and will eventually exhaust all dataflow. So it's a perfect mix. Randomly choose from several dataflow, and will eventually exhaust all dataflow. So it's a perfect mix.
""" """
...@@ -235,14 +236,14 @@ class RandomMixData(DataFlow): ...@@ -235,14 +236,14 @@ class RandomMixData(DataFlow):
:param df_lists: list of dataflow. :param df_lists: list of dataflow.
All DataFlow in `df_lists` must have :func:`size()` implemented All DataFlow in `df_lists` must have :func:`size()` implemented
""" """
super(RandomMixData, self).__init__()
self.df_lists = df_lists self.df_lists = df_lists
self.sizes = [k.size() for k in self.df_lists] self.sizes = [k.size() for k in self.df_lists]
self.rng = get_rng(self)
def reset_state(self): def reset_state(self):
super(RandomMixData, self).reset_state()
for d in self.df_lists: for d in self.df_lists:
d.reset_state() d.reset_state()
self.rng = get_rng(self)
def size(self): def size(self):
return sum(self.sizes) return sum(self.sizes)
...@@ -318,9 +319,22 @@ class JoinData(DataFlow): ...@@ -318,9 +319,22 @@ class JoinData(DataFlow):
for itr in itrs: for itr in itrs:
del itr del itr
class LocallyShuffleData(ProxyDataFlow, RNGDataFlow):
    """Pass through a DataFlow while shuffling datapoints within a
    fixed-size local cache. Gives an approximate (local) shuffle without
    loading the whole dataset into memory.
    """

    def __init__(self, ds, cache_size):
        """
        :param ds: input DataFlow.
        :param cache_size: number of datapoints to keep in the shuffle cache.
        """
        ProxyDataFlow.__init__(self, ds)
        RNGDataFlow.__init__(self)
        self.cache_size = cache_size
        self.q = deque(maxlen=self.cache_size)

    def reset_state(self):
        # Must reset BOTH bases: the previous implementation only reset the
        # rng, leaving the proxied self.ds with stale state.
        ProxyDataFlow.reset_state(self)
        RNGDataFlow.reset_state(self)
        self.q.clear()

    def get_data(self):
        # Classic shuffle-buffer: once the cache is full, for each incoming
        # datapoint yield a randomly chosen cached one and store the new
        # datapoint in its place. (The old body was a TODO stub returning
        # None, which made the class unusable.)
        for dp in self.ds.get_data():
            if len(self.q) == self.q.maxlen:
                idx = self.rng.randint(len(self.q))
                yield self.q[idx]
                self.q[idx] = dp
            else:
                self.q.append(dp)
        # Drain whatever is left in the cache, in random order.
        remaining = list(self.q)
        self.q.clear()
        self.rng.shuffle(remaining)
        for dp in remaining:
            yield dp
class DataFromQueue(DataFlow): class DataFromQueue(DataFlow):
""" provide data from a queue """ Provide data from a queue """
"""
def __init__(self, queue): def __init__(self, queue):
self.queue = queue self.queue = queue
......
...@@ -51,7 +51,7 @@ class SaverRestore(SessionInit): ...@@ -51,7 +51,7 @@ class SaverRestore(SessionInit):
""" """
Restore an old model saved by `ModelSaver`. Restore an old model saved by `ModelSaver`.
""" """
def __init__(self, model_path): def __init__(self, model_path, prefix=None):
""" """
:param model_path: a model file or a ``checkpoint`` file. :param model_path: a model file or a ``checkpoint`` file.
""" """
...@@ -61,12 +61,13 @@ class SaverRestore(SessionInit): ...@@ -61,12 +61,13 @@ class SaverRestore(SessionInit):
os.path.dirname(model_path)).model_checkpoint_path os.path.dirname(model_path)).model_checkpoint_path
assert os.path.isfile(model_path) assert os.path.isfile(model_path)
self.set_path(model_path) self.set_path(model_path)
self.prefix = prefix
def _init(self, sess): def _init(self, sess):
logger.info( logger.info(
"Restoring checkpoint from {}.".format(self.path)) "Restoring checkpoint from {}.".format(self.path))
chkpt_vars = SaverRestore._read_checkpoint_vars(self.path) chkpt_vars = SaverRestore._read_checkpoint_vars(self.path)
vars_map = SaverRestore._get_vars_to_restore_multimap(chkpt_vars) vars_map = self._get_vars_to_restore_multimap(chkpt_vars)
for dic in SaverRestore._produce_restore_dict(vars_map): for dic in SaverRestore._produce_restore_dict(vars_map):
# multiple saver under same name scope would cause error: # multiple saver under same name scope would cause error:
# training/saver.py: assert restore_op.name.endswith("restore_all"), restore_op.name # training/saver.py: assert restore_op.name.endswith("restore_all"), restore_op.name
...@@ -93,6 +94,7 @@ class SaverRestore(SessionInit): ...@@ -93,6 +94,7 @@ class SaverRestore(SessionInit):
@staticmethod @staticmethod
def _read_checkpoint_vars(model_path): def _read_checkpoint_vars(model_path):
""" return a set of strings """
reader = tf.train.NewCheckpointReader(model_path) reader = tf.train.NewCheckpointReader(model_path)
ckpt_vars = reader.get_variable_to_shape_map().keys() ckpt_vars = reader.get_variable_to_shape_map().keys()
for v in ckpt_vars: for v in ckpt_vars:
...@@ -100,11 +102,10 @@ class SaverRestore(SessionInit): ...@@ -100,11 +102,10 @@ class SaverRestore(SessionInit):
logger.warn("Found {} in checkpoint. Anything from prediction tower shouldn't be saved.".format(v.name)) logger.warn("Found {} in checkpoint. Anything from prediction tower shouldn't be saved.".format(v.name))
return set(ckpt_vars) return set(ckpt_vars)
@staticmethod def _get_vars_to_restore_multimap(self, vars_available):
def _get_vars_to_restore_multimap(vars_available):
""" """
Get a dict of {var_name: [var, var]} to restore Get a dict of {var_name: [var, var]} to restore
:param vars_available: variables available in the checkpoint, for existence checking :param vars_available: variable names available in the checkpoint, for existence checking
""" """
vars_to_restore = tf.all_variables() vars_to_restore = tf.all_variables()
var_dict = defaultdict(list) var_dict = defaultdict(list)
...@@ -117,6 +118,8 @@ class SaverRestore(SessionInit): ...@@ -117,6 +118,8 @@ class SaverRestore(SessionInit):
if 'tower' in name: if 'tower' in name:
new_name = re.sub('tower[p0-9]+/', '', name) new_name = re.sub('tower[p0-9]+/', '', name)
name = new_name name = new_name
if self.prefix and name.startswith(self.prefix):
name = name[len(self.prefix)+1:]
if name in vars_available: if name in vars_available:
var_dict[name].append(v) var_dict[name].append(v)
vars_available.remove(name) vars_available.remove(name)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment