Commit 2f3b8502 authored by Yuxin Wu

always use reset_state

parent 6607d856
@@ -9,6 +9,8 @@ class DisturbLabel(ProxyDataFlow):
def __init__(self, ds, prob):
super(DisturbLabel, self).__init__(ds)
self.prob = prob
def reset_state(self):
self.rng = get_rng(self)
def get_data(self):
......
@@ -29,6 +29,7 @@ args = parser.parse_args()
get_config_func = imp.load_source('config_script', args.config).get_config
config = get_config_func()
config.dataset.reset_state()
if args.output:
mkdir_p(args.output)
......
@@ -24,6 +24,9 @@ class ExpReplay(DataFlow, Callback):
"""
Implement experience replay in the paper
`Human-level control through deep reinforcement learning`.
This implementation provides the interface of a DataFlow.
This DataFlow is not fork-safe (it doesn't support multiprocess prefetching).
"""
def __init__(self,
predictor,
@@ -80,9 +83,6 @@ class ExpReplay(DataFlow, Callback):
pbar.update()
self._init_memory_flag.set()
def reset_state(self):
raise RuntimeError("Don't run me in multiple processes")
def _populate_exp(self):
""" populate a transition by epsilon-greedy"""
old_s = self.player.current_state()
......
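For context, the epsilon-greedy rule that `_populate_exp` refers to looks roughly as follows. This is a generic sketch with hypothetical names (`predictor`, `state`, `num_actions`), not the commit's code:

```python
import numpy as np

def epsilon_greedy_action(predictor, state, epsilon, num_actions, rng):
    # Explore with probability epsilon, otherwise act greedily with
    # respect to the predictor's action values.
    if rng.rand() < epsilon:
        return rng.randint(num_actions)        # random action
    return int(np.argmax(predictor(state)))    # best predicted action
```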
@@ -106,6 +106,7 @@ class InferenceRunner(Callback):
vc.before_inference()
sess = tf.get_default_session()
self.ds.reset_state()
with tqdm(total=self.ds.size(), ascii=True) as pbar:
for dp in self.ds.get_data():
#feed = dict(zip(self.input_vars, dp)) # TODO custom dp mapping?
......
@@ -29,7 +29,7 @@ class DataFlow(object):
def reset_state(self):
"""
Reset state of the dataflow,
Reset state of the dataflow. Will always be called before consuming data points.
For example, the RNG **HAS** to be reset here if it is used in the DataFlow.
Otherwise it may not work well with prefetching, because different
processes will have the same RNG state.
@@ -39,9 +39,6 @@ class DataFlow(object):
class RNGDataFlow(DataFlow):
""" A dataflow with rng"""
def __init__(self):
self.rng = get_rng(self)
def reset_state(self):
self.rng = get_rng(self)
......
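The two hunks above are the heart of the commit: a dataflow must no longer create its RNG in `__init__`; instead `RNGDataFlow.reset_state()` creates it, and `reset_state()` is now guaranteed to be called before any data points are consumed. A minimal sketch of a subclass following this convention (`RandomPairs` is hypothetical, and the import path assumes the `dataflow/base.py` module shown above):

```python
from tensorpack.dataflow.base import RNGDataFlow

class RandomPairs(RNGDataFlow):
    """Hypothetical dataflow yielding index pairs in a random order."""
    def __init__(self, n):
        self.n = n                      # note: no self.rng created here

    def size(self):
        return self.n

    def get_data(self):
        idxs = list(range(self.n))
        self.rng.shuffle(idxs)          # self.rng exists after reset_state()
        for k in idxs:
            yield [k, (k + 1) % self.n]

ds = RandomPairs(10)
ds.reset_state()                        # always call this before get_data()
for dp in ds.get_data():
    print(dp)
```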
@@ -306,11 +306,7 @@ class JoinData(DataFlow):
class LocallyShuffleData(ProxyDataFlow, RNGDataFlow):
def __init__(self, ds, cache_size):
ProxyDataFlow.__init__(self, ds)
RNGDataFlow.__init__(self)
self.q = deque(maxlen=cache_size)
self.ds_wrap = RepeatedData(ds, -1)
self.ds_itr = self.ds_wrap.get_data()
self.current_cnt = 0
def reset_state(self):
ProxyDataFlow.reset_state(self)
......
@@ -9,7 +9,7 @@ import numpy as np
from ...utils import logger, get_rng, get_dataset_dir
from ...utils.fs import download
from ..base import DataFlow
from ..base import RNGDataFlow
try:
from scipy.io import loadmat
@@ -21,7 +21,7 @@ except ImportError:
DATA_URL = "http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/BSR_bsds500.tgz"
IMG_W, IMG_H = 481, 321
class BSDS500(DataFlow):
class BSDS500(RNGDataFlow):
"""
`Berkeley Segmentation Data Set and Benchmarks 500
<http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/resources.html#bsds500>`_.
@@ -53,10 +53,6 @@ class BSDS500(DataFlow):
self.shuffle = shuffle
assert name in ['train', 'test', 'val']
self._load(name)
self.rng = get_rng(self)
def reset_state(self):
self.rng = get_rng(self)
def _load(self, name):
image_glob = os.path.join(self.data_root, 'images', name, '*.jpg')
......
@@ -15,7 +15,7 @@ import logging
from ...utils import logger, get_rng, get_dataset_dir
from ...utils.fs import download
from ..base import DataFlow
from ..base import RNGDataFlow
__all__ = ['Cifar10', 'Cifar100']
@@ -77,7 +77,7 @@ def get_filenames(dir, cifar_classnum):
os.path.join(dir, 'cifar-100-python', 'test')]
return filenames
class CifarBase(DataFlow):
class CifarBase(RNGDataFlow):
"""
Return [image, label],
image is 32x32x3 in the range [0,255]
@@ -106,10 +106,6 @@ class CifarBase(DataFlow):
self.data = read_cifar(self.fs, cifar_classnum)
self.dir = dir
self.shuffle = shuffle
self.rng = get_rng(self)
def reset_state(self):
self.rng = get_rng(self)
def size(self):
return 50000 if self.train_or_test == 'train' else 10000
......
@@ -11,7 +11,7 @@ from six.moves import range
from ...utils import logger, get_rng, get_dataset_dir, memoized
from ...utils.loadcaffe import get_caffe_pb
from ...utils.fs import mkdir_p, download
from ..base import DataFlow
from ..base import RNGDataFlow
__all__ = ['ILSVRCMeta', 'ILSVRC12']
@@ -79,7 +79,7 @@ class ILSVRCMeta(object):
arr = cv2.resize(arr, size[::-1])
return arr
class ILSVRC12(DataFlow):
class ILSVRC12(RNGDataFlow):
def __init__(self, dir, name, meta_dir=None, shuffle=True):
"""
:param dir: A directory containing a subdir named `name`, where the
@@ -119,17 +119,10 @@ class ILSVRC12(DataFlow):
self.shuffle = shuffle
self.meta = ILSVRCMeta(meta_dir)
self.imglist = self.meta.get_image_list(name)
self.rng = get_rng(self)
def size(self):
return len(self.imglist)
def reset_state(self):
"""
reset rng for shuffle
"""
self.rng = get_rng(self)
def get_data(self):
"""
Produce original images of shape [h, w, 3], and label
......
@@ -11,7 +11,7 @@ from six.moves import urllib, range
from ...utils import logger, get_dataset_dir
from ...utils.fs import download
from ..base import DataFlow
from ..base import RNGDataFlow
__all__ = ['Mnist']
@@ -92,7 +92,7 @@ class DataSet(object):
def num_examples(self):
return self._num_examples
class Mnist(DataFlow):
class Mnist(RNGDataFlow):
"""
Return [image, label],
image is 28x28 in the range [0,1]
@@ -136,7 +136,7 @@ class Mnist(DataFlow):
ds = self.train if self.train_or_test == 'train' else self.test
idxs = list(range(ds.num_examples))
if self.shuffle:
random.shuffle(idxs)
self.rng.shuffle(idxs)
for k in idxs:
img = ds.images[k].reshape((28, 28))
label = ds.labels[k]
......
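The switch from the module-level `random.shuffle` to `self.rng.shuffle` is what makes the reset contract matter: a forked prefetch worker inherits the parent's RNG state, so every worker would shuffle identically unless the rng is re-created after the fork. A minimal standalone demonstration of the problem (Unix-only, not from the commit):

```python
import os
import numpy as np

rng = np.random.RandomState(0)
if os.fork() == 0:
    # The child inherits the parent's RNG state, so without a reset it
    # draws exactly the same "random" numbers as the parent.
    print("child :", rng.randint(100, size=3))
    os._exit(0)
print("parent:", rng.randint(100, size=3))
os.wait()
```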
@@ -9,7 +9,7 @@ import numpy as np
from six.moves import range
from ...utils import logger, get_rng, get_dataset_dir
from ..base import DataFlow
from ..base import RNGDataFlow
try:
import scipy.io
@@ -20,7 +20,7 @@ except ImportError:
SVHN_URL = "http://ufldl.stanford.edu/housenumbers/"
class SVHNDigit(DataFlow):
class SVHNDigit(RNGDataFlow):
"""
SVHN Cropped Digit Dataset
return img of 32x32x3, label of 0-9
@@ -33,7 +33,6 @@ class SVHNDigit(DataFlow):
:param data_dir: a directory containing the original {train,test,extra}_32x32.mat
"""
self.shuffle = shuffle
self.rng = get_rng(self)
if name in SVHNDigit.Cache:
self.X, self.Y = SVHNDigit.Cache[name]
@@ -54,9 +53,6 @@ class SVHNDigit(DataFlow):
def size(self):
return self.X.shape[0]
def reset_state(self):
self.rng = get_rng(self)
def get_data(self):
n = self.X.shape[0]
idxs = np.arange(n)
......
@@ -23,6 +23,7 @@ def dump_dataset_images(ds, dirname, max_count=None, index=0):
mkdir_p(dirname)
if max_count is None:
max_count = sys.maxint
ds.reset_state()
for i, dp in enumerate(ds.get_data()):
if i % 100 == 0:
print(i)
@@ -34,6 +35,7 @@ def dump_dataset_images(ds, dirname, max_count=None, index=0):
def dataflow_to_process_queue(ds, size, nr_consumer):
"""
Convert a `DataFlow` to a multiprocessing.Queue.
The dataflow will only be reset in the spawned process.
:param ds: a `DataFlow`
:param size: size of the queue
@@ -50,6 +52,7 @@ def dataflow_to_process_queue(ds, size, nr_consumer):
self.q = q
def run(self):
self.ds.reset_state()
try:
for idx, dp in enumerate(self.ds.get_data()):
self.q.put((idx, dp))
......
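The `run()` method above shows where the reset now happens: inside the spawned process, after the fork, so each consumer gets its own RNG state. The same pattern in isolation (a hypothetical `Worker`, mirroring the hunk):

```python
import multiprocessing

class Worker(multiprocessing.Process):
    def __init__(self, ds, q):
        super(Worker, self).__init__()
        self.ds = ds
        self.q = q

    def run(self):
        self.ds.reset_state()           # reset in the child, not the parent
        for idx, dp in enumerate(self.ds.get_data()):
            self.q.put((idx, dp))
```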
@@ -5,7 +5,7 @@
from ..utils import logger, get_rng
from ..utils.timer import timed_operation
from ..utils.loadcaffe import get_caffe_pb
from .base import DataFlow
from .base import RNGDataFlow
import random
from tqdm import tqdm
@@ -31,7 +31,7 @@ else:
Adapters for different data format.
"""
class HDF5Data(DataFlow):
class HDF5Data(RNGDataFlow):
"""
Zip data from different paths in an HDF5 file. Will load all data into memory.
"""
@@ -55,19 +55,18 @@ class HDF5Data(DataFlow):
def get_data(self):
idxs = list(range(self._size))
if self.shuffle:
random.shuffle(idxs)
self.rng.shuffle(idxs)
for k in idxs:
yield [dp[k] for dp in self.dps]
class LMDBData(DataFlow):
class LMDBData(RNGDataFlow):
""" Read a lmdb and produce k,v pair """
def __init__(self, lmdb_dir, shuffle=True):
self._lmdb = lmdb.open(lmdb_dir, readonly=True, lock=False,
map_size=1099511627776 * 2, max_readers=100)
self._txn = self._lmdb.begin()
self._shuffle = shuffle
self.rng = get_rng(self)
self._size = self._txn.stat()['entries']
if shuffle:
self.keys = self._txn.get('__keys__')
@@ -81,8 +80,8 @@ class LMDBData(DataFlow):
pbar.update()
def reset_state(self):
super(LMDBData, self).reset_state()
self._txn = self._lmdb.begin()
self.rng = get_rng(self)
def size(self):
return self._size
@@ -96,8 +95,8 @@ class LMDBData(DataFlow):
yield [k, v]
else:
s = self.size()
for i in range(s):
k = self.rng.choice(self.keys)
self.rng.shuffle(self.keys)
for k in self.keys:
v = self._txn.get(k)
yield [k, v]
......
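Besides moving the rng into `reset_state()`, this hunk also changes the shuffling semantics: `rng.choice` samples keys with replacement, so one pass could repeat some records and skip others, whereas `rng.shuffle` produces a permutation that visits every key exactly once. A quick illustration (output values are examples only):

```python
import numpy as np

rng = np.random.RandomState(0)
keys = list(range(5))

# Old behaviour: sampling with replacement; duplicates and omissions possible.
print([int(rng.choice(keys)) for _ in range(5)])   # e.g. [4, 0, 3, 3, 3]

# New behaviour: an in-place permutation; each key appears exactly once.
rng.shuffle(keys)
print(keys)                                        # e.g. [2, 0, 1, 3, 4]
```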
@@ -35,7 +35,7 @@ class PrefetchProcess(multiprocessing.Process):
self.queue = queue
def run(self):
# reset RNG of ds so each process will produce different data
# reset all ds so each process will produce different data
self.ds.reset_state()
while True:
for dp in self.ds.get_data():
@@ -73,6 +73,10 @@ class PrefetchData(ProxyDataFlow):
dp = self.queue.get()
yield dp
def reset_state(self):
# do nothing. all ds are reset once and only once in spawned processes
pass
class PrefetchProcessZMQ(multiprocessing.Process):
def __init__(self, ds, conn_name):
"""
@@ -134,6 +138,10 @@ class PrefetchDataZMQ(ProxyDataFlow):
dp = loads(self.socket.recv(copy=False))
yield dp
def reset_state(self):
# do nothing. all ds are reset once and only once in spawned processes
pass
def __del__(self):
# on exit, logger may not be functional anymore
try:
......
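With `reset_state()` a deliberate no-op on both prefetch wrappers, callers can now reset any dataflow unconditionally before consuming it. A usage sketch (the import paths and the `PrefetchData(ds, nr_prefetch, nr_proc)` signature are assumed from this codebase's layout):

```python
from tensorpack.dataflow import PrefetchData
from tensorpack.dataflow.dataset import Mnist

ds = Mnist('train')
ds = PrefetchData(ds, nr_prefetch=256, nr_proc=4)
ds.reset_state()            # safe: a no-op, the workers reset themselves
for dp in ds.get_data():
    pass                    # consume prefetched data points
```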
@@ -23,6 +23,7 @@ def serve_data(ds, addr):
socket.bind(addr)
ds = RepeatedData(ds, -1)
try:
ds.reset_state()
logger.info("Serving data at {}".format(addr))
while True:
for dp in ds.get_data():
......
@@ -19,7 +19,7 @@ def describe_model():
v.name, shape.as_list(), ele))
size_mb = total * 4 / 1024.0**2
msg.append("Total param={} ({:01f} MB assuming all float32)".format(total, size_mb))
logger.info("Model Params: {}".format('\n'.join(msg)))
logger.info("Model Parameters: {}".format('\n'.join(msg)))
def get_shape_str(tensors):
......
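(For the size arithmetic above: 1,000,000 float32 parameters occupy 1e6 × 4 bytes / 1024² ≈ 3.81 MB.)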
@@ -40,6 +40,7 @@ class SimpleTrainer(Trainer):
self.init_session_and_coord()
describe_model()
# create an infinite data producer
self.config.dataset.reset_state()
self.data_producer = RepeatedData(self.config.dataset, -1).get_data()
self.main_loop()
@@ -62,21 +63,22 @@ class SimpleTrainer(Trainer):
return func
class EnqueueThread(threading.Thread):
def __init__(self, trainer, queue, enqueue_op, raw_input_var):
def __init__(self, trainer):
super(EnqueueThread, self).__init__()
self.sess = trainer.sess
self.coord = trainer.coord
self.dataflow = RepeatedData(trainer.config.dataset, -1)
self.input_vars = raw_input_var
self.op = enqueue_op
self.queue = queue
self.input_vars = trainer.input_vars
self.queue = trainer.input_queue
self.op = self.queue.enqueue(self.input_vars)
self.close_op = self.queue.close(cancel_pending_enqueues=True)
self.size_op = self.queue.size()
self.daemon = True
def run(self):
self.dataflow.reset_state()
with self.sess.as_default():
try:
while True:
@@ -155,8 +157,7 @@ class QueueInputTrainer(Trainer):
def _build_enque_thread(self):
""" create a thread that keeps filling the queue """
enqueue_op = self.input_queue.enqueue(self.input_vars)
self.input_th = EnqueueThread(self, self.input_queue, enqueue_op, self.input_vars)
self.input_th = EnqueueThread(self)
self.extra_threads_procs.append(self.input_th)
def train(self):
......