Commit 8759e324 authored by Yuxin Wu's avatar Yuxin Wu

reset rng

parent f4507d45
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
from abc import abstractmethod, ABCMeta from abc import abstractmethod, ABCMeta
__all__ = ['DataFlow'] __all__ = ['DataFlow', 'ProxyDataFlow']
class DataFlow(object): class DataFlow(object):
__metaclass__ = ABCMeta __metaclass__ = ABCMeta
...@@ -23,4 +23,18 @@ class DataFlow(object): ...@@ -23,4 +23,18 @@ class DataFlow(object):
""" """
raise NotImplementedError() raise NotImplementedError()
def reset_state(self):
"""
Reset state of the dataflow (usually the random seed)
"""
pass
class ProxyDataFlow(DataFlow):
def __init__(self, ds):
self.ds = ds
def reset_state(self):
self.ds.reset_state()
def size(self):
return self.ds.size()
...@@ -5,14 +5,15 @@ ...@@ -5,14 +5,15 @@
import numpy as np import numpy as np
import copy import copy
from .base import DataFlow from .base import DataFlow, ProxyDataFlow
from .imgaug import AugmentorList, Image from .imgaug import AugmentorList, Image
from ..utils import *
__all__ = ['BatchData', 'FixedSizeData', 'FakeData', 'MapData', __all__ = ['BatchData', 'FixedSizeData', 'FakeData', 'MapData',
'MapDataComponent', 'RandomChooseData', 'MapDataComponent', 'RandomChooseData',
'AugmentImageComponent'] 'AugmentImageComponent']
class BatchData(DataFlow): class BatchData(ProxyDataFlow):
def __init__(self, ds, batch_size, remainder=False): def __init__(self, ds, batch_size, remainder=False):
""" """
Group data in ds into batches Group data in ds into batches
...@@ -20,7 +21,7 @@ class BatchData(DataFlow): ...@@ -20,7 +21,7 @@ class BatchData(DataFlow):
remainder: whether to return the remaining data smaller than a batch_size. remainder: whether to return the remaining data smaller than a batch_size.
if set True, will possibly return a data point of a smaller 1st dimension if set True, will possibly return a data point of a smaller 1st dimension
""" """
self.ds = ds super(BatchData, self).__init__(ds)
if not remainder: if not remainder:
assert batch_size <= ds.size() assert batch_size <= ds.size()
self.batch_size = batch_size self.batch_size = batch_size
...@@ -60,10 +61,10 @@ class BatchData(DataFlow): ...@@ -60,10 +61,10 @@ class BatchData(DataFlow):
np.array([x[k] for x in data_holder], dtype=tp)) np.array([x[k] for x in data_holder], dtype=tp))
return result return result
class FixedSizeData(DataFlow): class FixedSizeData(ProxyDataFlow):
""" generate data from another dataflow, but with a fixed epoch size""" """ generate data from another dataflow, but with a fixed epoch size"""
def __init__(self, ds, size): def __init__(self, ds, size):
self.ds = ds super(FixedSizeData, self).__init__(ds)
self._size = size self._size = size
self.itr = None self.itr = None
...@@ -86,13 +87,13 @@ class FixedSizeData(DataFlow): ...@@ -86,13 +87,13 @@ class FixedSizeData(DataFlow):
if cnt == self._size: if cnt == self._size:
return return
class RepeatedData(DataFlow): class RepeatedData(ProxyDataFlow):
""" repeat another dataflow for certain times """ repeat another dataflow for certain times
if nr == -1, repeat infinitely many times if nr == -1, repeat infinitely many times
""" """
def __init__(self, ds, nr): def __init__(self, ds, nr):
self.nr = nr self.nr = nr
self.ds = ds super(RepeatedData, self).__init__(ds)
def size(self): def size(self):
if self.nr == -1: if self.nr == -1:
...@@ -117,37 +118,35 @@ class FakeData(DataFlow): ...@@ -117,37 +118,35 @@ class FakeData(DataFlow):
""" """
self.shapes = shapes self.shapes = shapes
self._size = size self._size = size
self.rng = get_rng(self)
def size(self): def size(self):
return self._size return self._size
def reset_state(self):
self.rng = get_rng(self)
def get_data(self): def get_data(self):
for _ in xrange(self._size): for _ in xrange(self._size):
yield [np.random.random(k) for k in self.shapes] yield [self.rng.random_sample(k) for k in self.shapes]
class MapData(DataFlow): class MapData(ProxyDataFlow):
""" Map a function to the datapoint""" """ Map a function to the datapoint"""
def __init__(self, ds, func): def __init__(self, ds, func):
self.ds = ds super(MapData, self).__init__(ds)
self.func = func self.func = func
def size(self):
return self.ds.size()
def get_data(self): def get_data(self):
for dp in self.ds.get_data(): for dp in self.ds.get_data():
yield self.func(dp) yield self.func(dp)
class MapDataComponent(DataFlow): class MapDataComponent(ProxyDataFlow):
""" Apply a function to the given index in the datapoint""" """ Apply a function to the given index in the datapoint"""
def __init__(self, ds, func, index=0): def __init__(self, ds, func, index=0):
self.ds = ds super(MapDataComponent, self).__init__(ds)
self.func = func self.func = func
self.index = index self.index = index
def size(self):
return self.ds.size()
def get_data(self): def get_data(self):
for dp in self.ds.get_data(): for dp in self.ds.get_data():
dp = copy.deepcopy(dp) # avoid modifying the original dp dp = copy.deepcopy(dp) # avoid modifying the original dp
...@@ -169,6 +168,13 @@ class RandomChooseData(DataFlow): ...@@ -169,6 +168,13 @@ class RandomChooseData(DataFlow):
prob = 1.0 / len(df_lists) prob = 1.0 / len(df_lists)
self.df_lists = [(k, prob) for k in df_lists] self.df_lists = [(k, prob) for k in df_lists]
def reset_state(self):
for d in self.df_lists:
if isinstance(d, tuple):
d[0].reset_state()
else:
d.reset_state()
def get_data(self): def get_data(self):
itrs = [v[0].get_data() for v in self.df_lists] itrs = [v[0].get_data() for v in self.df_lists]
probs = np.array([v[1] for v in self.df_lists]) probs = np.array([v[1] for v in self.df_lists])
......
...@@ -23,6 +23,7 @@ class PrefetchProcess(multiprocessing.Process): ...@@ -23,6 +23,7 @@ class PrefetchProcess(multiprocessing.Process):
self.queue = queue self.queue = queue
def run(self): def run(self):
self.ds.reset_state()
try: try:
for dp in self.ds.get_data(): for dp in self.ds.get_data():
self.queue.put(dp) self.queue.put(dp)
......
...@@ -30,7 +30,7 @@ class TrainConfig(object): ...@@ -30,7 +30,7 @@ class TrainConfig(object):
initialize variables of a session. default to a new session. initialize variables of a session. default to a new session.
model: a ModelDesc instance model: a ModelDesc instance
step_per_epoch: the number of steps (parameter updates) to perform step_per_epoch: the number of steps (parameter updates) to perform
in each epoch. default to dataset.size() in each epoch.
max_epoch: maximum number of epoch to run training. default to 100 max_epoch: maximum number of epoch to run training. default to 100
nr_tower: int. number of towers. default to 1. nr_tower: int. number of towers. default to 1.
""" """
...@@ -49,7 +49,7 @@ class TrainConfig(object): ...@@ -49,7 +49,7 @@ class TrainConfig(object):
assert_type(self.session_config, tf.ConfigProto) assert_type(self.session_config, tf.ConfigProto)
self.session_init = kwargs.pop('session_init', NewSession()) self.session_init = kwargs.pop('session_init', NewSession())
assert_type(self.session_init, SessionInit) assert_type(self.session_init, SessionInit)
self.step_per_epoch = int(kwargs.pop('step_per_epoch', self.dataset.size())) self.step_per_epoch = int(kwargs.pop('step_per_epoch'))
self.max_epoch = int(kwargs.pop('max_epoch', 100)) self.max_epoch = int(kwargs.pop('max_epoch', 100))
assert self.step_per_epoch > 0 and self.max_epoch > 0 assert self.step_per_epoch > 0 and self.max_epoch > 0
self.nr_tower = int(kwargs.pop('nr_tower', 1)) self.nr_tower = int(kwargs.pop('nr_tower', 1))
......
...@@ -87,4 +87,5 @@ def get_global_step(): ...@@ -87,4 +87,5 @@ def get_global_step():
get_global_step_var()) get_global_step_var())
def get_rng(self): def get_rng(self):
return np.random.RandomState() seed = (id(self) + os.getpid()) % 4294967295
return np.random.RandomState(seed)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment