Commit 8759e324 authored by Yuxin Wu

reset rng

parent f4507d45
......@@ -6,7 +6,7 @@
from abc import abstractmethod, ABCMeta
__all__ = ['DataFlow']
__all__ = ['DataFlow', 'ProxyDataFlow']
class DataFlow(object):
__metaclass__ = ABCMeta
......@@ -23,4 +23,18 @@ class DataFlow(object):
"""
raise NotImplementedError()
def reset_state(self):
"""
Reset state of the dataflow (usually the random seed)
"""
pass
class ProxyDataFlow(DataFlow):
def __init__(self, ds):
self.ds = ds
def reset_state(self):
self.ds.reset_state()
def size(self):
return self.ds.size()
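Below is a minimal sketch (not part of the commit) of how a concrete dataflow is expected to use the new reset_state hook, assuming the get_rng helper changed later in this commit and Python 2's xrange as used elsewhere in the codebase; the class and its name are hypothetical.

class RandomSquares(DataFlow):
    """ Hypothetical dataflow yielding random 2x2 arrays. """
    def __init__(self, size):
        self._size = size
        self.rng = get_rng(self)    # seeded once at construction time

    def size(self):
        return self._size

    def reset_state(self):
        # re-seed; called e.g. by a prefetch worker after forking
        self.rng = get_rng(self)

    def get_data(self):
        for _ in xrange(self._size):
            yield [self.rng.random_sample((2, 2))]

Wrappers such as BatchData inherit the forwarding behaviour from ProxyDataFlow, so they do not need to re-implement reset_state themselves.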
......@@ -5,14 +5,15 @@
import numpy as np
import copy
from .base import DataFlow
from .base import DataFlow, ProxyDataFlow
from .imgaug import AugmentorList, Image
from ..utils import *
__all__ = ['BatchData', 'FixedSizeData', 'FakeData', 'MapData',
'MapDataComponent', 'RandomChooseData',
'AugmentImageComponent']
class BatchData(DataFlow):
class BatchData(ProxyDataFlow):
def __init__(self, ds, batch_size, remainder=False):
"""
Group data in ds into batches
......@@ -20,7 +21,7 @@ class BatchData(DataFlow):
remainder: whether to return the remaining data smaller than a batch_size.
if set True, will possibly return a data point of a smaller 1st dimension
"""
self.ds = ds
super(BatchData, self).__init__(ds)
if not remainder:
assert batch_size <= ds.size()
self.batch_size = batch_size
......@@ -60,10 +61,10 @@ class BatchData(DataFlow):
np.array([x[k] for x in data_holder], dtype=tp))
return result
class FixedSizeData(DataFlow):
class FixedSizeData(ProxyDataFlow):
""" generate data from another dataflow, but with a fixed epoch size"""
def __init__(self, ds, size):
self.ds = ds
super(FixedSizeData, self).__init__(ds)
self._size = size
self.itr = None
......@@ -86,13 +87,13 @@ class FixedSizeData(DataFlow):
if cnt == self._size:
return
class RepeatedData(DataFlow):
class RepeatedData(ProxyDataFlow):
""" repeat another dataflow for certain times
if nr == -1, repeat infinitely many times
"""
def __init__(self, ds, nr):
self.nr = nr
self.ds = ds
super(RepeatedData, self).__init__(ds)
def size(self):
if self.nr == -1:
......@@ -117,37 +118,35 @@ class FakeData(DataFlow):
"""
self.shapes = shapes
self._size = size
self.rng = get_rng(self)
def size(self):
return self._size
def reset_state(self):
self.rng = get_rng(self)
def get_data(self):
for _ in xrange(self._size):
yield [np.random.random(k) for k in self.shapes]
yield [self.rng.random_sample(k) for k in self.shapes]
class MapData(DataFlow):
class MapData(ProxyDataFlow):
""" Map a function to the datapoint"""
def __init__(self, ds, func):
self.ds = ds
super(MapData, self).__init__(ds)
self.func = func
def size(self):
return self.ds.size()
def get_data(self):
for dp in self.ds.get_data():
yield self.func(dp)
class MapDataComponent(DataFlow):
class MapDataComponent(ProxyDataFlow):
""" Apply a function to the given index in the datapoint"""
def __init__(self, ds, func, index=0):
self.ds = ds
super(MapDataComponent, self).__init__(ds)
self.func = func
self.index = index
def size(self):
return self.ds.size()
def get_data(self):
for dp in self.ds.get_data():
dp = copy.deepcopy(dp) # avoid modifying the original dp
......@@ -169,6 +168,13 @@ class RandomChooseData(DataFlow):
prob = 1.0 / len(df_lists)
self.df_lists = [(k, prob) for k in df_lists]
def reset_state(self):
for d in self.df_lists:
if isinstance(d, tuple):
d[0].reset_state()
else:
d.reset_state()
def get_data(self):
itrs = [v[0].get_data() for v in self.df_lists]
probs = np.array([v[1] for v in self.df_lists])
......
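A hedged usage sketch (not from the commit) of why the switch to ProxyDataFlow matters: size() and reset_state() are now forwarded through a chain of wrappers, so re-seeding the outermost dataflow re-seeds the underlying source. FakeData(shapes, size) is assumed from the constructor body shown above.

ds = FakeData([(3, 32, 32)], 100)                 # source with its own self.rng
ds = MapDataComponent(ds, lambda img: img * 2)    # operates on component 0
ds = BatchData(ds, 10)
ds.reset_state()        # forwarded BatchData -> MapDataComponent -> FakeData
for batch in ds.get_data():
    pass                # batch[0] has shape (10, 3, 32, 32)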
......@@ -23,6 +23,7 @@ class PrefetchProcess(multiprocessing.Process):
self.queue = queue
def run(self):
self.ds.reset_state()
try:
for dp in self.ds.get_data():
self.queue.put(dp)
......
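The new reset_state() call happens inside run(), i.e. in the child process. A small illustration (hypothetical, not from the commit) of the failure mode this avoids: an RNG created in the parent is duplicated by fork(), so without re-seeding every prefetch worker would emit the same "random" datapoints.

import multiprocessing as mp
import numpy as np

rng = np.random.RandomState(0)          # created in the parent process

def sample(q):
    q.put(rng.randint(1000))            # each forked child holds an identical copy

if __name__ == '__main__':
    q = mp.Queue()
    procs = [mp.Process(target=sample, args=(q,)) for _ in range(4)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
    print([q.get() for _ in range(4)])  # all four values identical under fork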
......@@ -30,7 +30,7 @@ class TrainConfig(object):
initialize variables of a session. default to a new session.
model: a ModelDesc instance
step_per_epoch: the number of steps (parameter updates) to perform
in each epoch. default to dataset.size()
in each epoch.
max_epoch: maximum number of epoch to run training. default to 100
nr_tower: int. number of towers. default to 1.
"""
......@@ -49,7 +49,7 @@ class TrainConfig(object):
assert_type(self.session_config, tf.ConfigProto)
self.session_init = kwargs.pop('session_init', NewSession())
assert_type(self.session_init, SessionInit)
self.step_per_epoch = int(kwargs.pop('step_per_epoch', self.dataset.size()))
self.step_per_epoch = int(kwargs.pop('step_per_epoch'))
self.max_epoch = int(kwargs.pop('max_epoch', 100))
assert self.step_per_epoch > 0 and self.max_epoch > 0
self.nr_tower = int(kwargs.pop('nr_tower', 1))
......
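With the default removed, step_per_epoch must now be passed explicitly. A hedged construction sketch, assuming keyword arguments named as in the code above (dataset, model, step_per_epoch, max_epoch); MyModel and SomeDataFlow are placeholders:

ds = BatchData(SomeDataFlow(), 128)     # SomeDataFlow is a placeholder source
config = TrainConfig(
    dataset=ds,
    model=MyModel(),                    # a ModelDesc instance, per the docstring
    step_per_epoch=ds.size(),           # the old implicit default, now spelled out
    max_epoch=100,
)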
......@@ -87,4 +87,5 @@ def get_global_step():
get_global_step_var())
def get_rng(self):
return np.random.RandomState()
seed = (id(self) + os.getpid()) % 4294967295
return np.random.RandomState(seed)
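An illustration (not from the commit) of what the new seeding buys: the seed mixes the object identity with the process id, so distinct dataflow objects in one process, and the same object re-seeded after a fork, draw different random streams.

import os
import numpy as np

def get_rng(obj):
    # mirrors the change above
    seed = (id(obj) + os.getpid()) % 4294967295
    return np.random.RandomState(seed)

class _Dummy(object):
    pass

a, b = _Dummy(), _Dummy()
rng_a, rng_b = get_rng(a), get_rng(b)   # different id() -> different seeds
# after os.fork(), get_rng(a) in the child differs from get_rng(a) in the parent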