Commit 1095c8b8 authored by Yuxin Wu

Merge branch 'reset-state'

parents 6607d856 2f3b8502
@@ -9,6 +9,8 @@ class DisturbLabel(ProxyDataFlow):
     def __init__(self, ds, prob):
         super(DisturbLabel, self).__init__(ds)
         self.prob = prob
+
+    def reset_state(self):
         self.rng = get_rng(self)

     def get_data(self):
...
@@ -29,6 +29,7 @@ args = parser.parse_args()
 get_config_func = imp.load_source('config_script', args.config).get_config

 config = get_config_func()
+config.dataset.reset_state()

 if args.output:
     mkdir_p(args.output)
...
@@ -24,6 +24,9 @@ class ExpReplay(DataFlow, Callback):
     """
     Implement experience replay in the paper
    `Human-level control through deep reinforcement learning`.
+
+    This implementation provides the interface as a DataFlow.
+    This DataFlow is not fork-safe (doesn't support multiprocess prefetching).
     """
     def __init__(self,
                  predictor,
@@ -80,9 +83,6 @@ class ExpReplay(DataFlow, Callback):
                 pbar.update()
         self._init_memory_flag.set()

-    def reset_state(self):
-        raise RuntimeError("Don't run me in multiple processes")
-
     def _populate_exp(self):
         """ populate a transition by epsilon-greedy"""
         old_s = self.player.current_state()
...
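Note on the fork-safety warning added above: ExpReplay holds mutable state (the replay memory and the game player) inside the object, so handing it to a multiprocess prefetcher would fork that state. A minimal illustration of the failure mode, not tensorpack code, with hypothetical class and variable names:

```python
import multiprocessing as mp

class Replay:
    """Stand-in for a stateful producer such as a replay buffer."""
    def __init__(self):
        self.mem = []

def worker(replay):
    replay.mem.append('transition')   # mutates the worker's copy only

if __name__ == '__main__':
    r = Replay()
    p = mp.Process(target=worker, args=(r,))
    p.start()
    p.join()
    print(len(r.mem))   # 0 -- the child got its own copy of the buffer
```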
@@ -106,6 +106,7 @@ class InferenceRunner(Callback):
             vc.before_inference()
         sess = tf.get_default_session()
+        self.ds.reset_state()
         with tqdm(total=self.ds.size(), ascii=True) as pbar:
             for dp in self.ds.get_data():
                 #feed = dict(zip(self.input_vars, dp))  # TODO custom dp mapping?
...
@@ -29,7 +29,7 @@ class DataFlow(object):
     def reset_state(self):
         """
-        Reset state of the dataflow,
+        Reset state of the dataflow. Will always be called before consuming data points.
         For example, RNG **HAS** to be reset here if used in the DataFlow.
         Otherwise it may not work well with prefetching, because different
         processes will have the same RNG state.
@@ -39,9 +39,6 @@ class DataFlow(object):

 class RNGDataFlow(DataFlow):
     """ A dataflow with rng"""
-    def __init__(self):
-        self.rng = get_rng(self)
-
     def reset_state(self):
         self.rng = get_rng(self)
...
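For reference, a minimal sketch of the contract these two hunks establish (the class name and import path below are illustrative, not from the diff): a dataflow no longer seeds its RNG in `__init__`; the RNG exists only after `reset_state()` runs in the consuming process, and every consumer must call it before `get_data()`.

```python
from tensorpack.dataflow.base import RNGDataFlow   # import path assumed for this era


class RandomDigits(RNGDataFlow):
    """Hypothetical example: yield random digits, one per data point."""
    def __init__(self, size):
        self._size = size          # note: no self.rng here any more

    def size(self):
        return self._size

    def get_data(self):
        for _ in range(self._size):
            yield [self.rng.randint(10)]   # self.rng is set by reset_state()


df = RandomDigits(5)
df.reset_state()                   # mandatory before consuming data points
for dp in df.get_data():
    print(dp)
```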
@@ -306,11 +306,7 @@ class JoinData(DataFlow):
 class LocallyShuffleData(ProxyDataFlow, RNGDataFlow):
     def __init__(self, ds, cache_size):
         ProxyDataFlow.__init__(self, ds)
-        RNGDataFlow.__init__(self)
         self.q = deque(maxlen=cache_size)
-        self.ds_wrap = RepeatedData(ds, -1)
-        self.ds_itr = self.ds_wrap.get_data()
-        self.current_cnt = 0

     def reset_state(self):
         ProxyDataFlow.reset_state(self)
...
@@ -9,7 +9,7 @@ import numpy as np
 from ...utils import logger, get_rng, get_dataset_dir
 from ...utils.fs import download
-from ..base import DataFlow
+from ..base import RNGDataFlow

 try:
     from scipy.io import loadmat
@@ -21,7 +21,7 @@ except ImportError:
 DATA_URL = "http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/BSR/BSR_bsds500.tgz"
 IMG_W, IMG_H = 481, 321

-class BSDS500(DataFlow):
+class BSDS500(RNGDataFlow):
     """
     `Berkeley Segmentation Data Set and Benchmarks 500
     <http://www.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/resources.html#bsds500>`_.
@@ -53,10 +53,6 @@ class BSDS500(DataFlow):
         self.shuffle = shuffle
         assert name in ['train', 'test', 'val']
         self._load(name)
-        self.rng = get_rng(self)
-
-    def reset_state(self):
-        self.rng = get_rng(self)

     def _load(self, name):
         image_glob = os.path.join(self.data_root, 'images', name, '*.jpg')
...
@@ -15,7 +15,7 @@ import logging
 from ...utils import logger, get_rng, get_dataset_dir
 from ...utils.fs import download
-from ..base import DataFlow
+from ..base import RNGDataFlow

 __all__ = ['Cifar10', 'Cifar100']
@@ -77,7 +77,7 @@ def get_filenames(dir, cifar_classnum):
                      os.path.join(dir, 'cifar-100-python', 'test')]
     return filenames

-class CifarBase(DataFlow):
+class CifarBase(RNGDataFlow):
     """
     Return [image, label],
         image is 32x32x3 in the range [0,255]
@@ -106,10 +106,6 @@ class CifarBase(DataFlow):
         self.data = read_cifar(self.fs, cifar_classnum)
         self.dir = dir
         self.shuffle = shuffle
-        self.rng = get_rng(self)
-
-    def reset_state(self):
-        self.rng = get_rng(self)

     def size(self):
         return 50000 if self.train_or_test == 'train' else 10000
...
@@ -11,7 +11,7 @@ from six.moves import range
 from ...utils import logger, get_rng, get_dataset_dir, memoized
 from ...utils.loadcaffe import get_caffe_pb
 from ...utils.fs import mkdir_p, download
-from ..base import DataFlow
+from ..base import RNGDataFlow

 __all__ = ['ILSVRCMeta', 'ILSVRC12']
@@ -79,7 +79,7 @@ class ILSVRCMeta(object):
         arr = cv2.resize(arr, size[::-1])
         return arr

-class ILSVRC12(DataFlow):
+class ILSVRC12(RNGDataFlow):
     def __init__(self, dir, name, meta_dir=None, shuffle=True):
         """
         :param dir: A directory containing a subdir named `name`, where the
@@ -119,17 +119,10 @@ class ILSVRC12(DataFlow):
         self.shuffle = shuffle
         self.meta = ILSVRCMeta(meta_dir)
         self.imglist = self.meta.get_image_list(name)
-        self.rng = get_rng(self)

     def size(self):
         return len(self.imglist)

-    def reset_state(self):
-        """
-        reset rng for shuffle
-        """
-        self.rng = get_rng(self)
-
     def get_data(self):
         """
         Produce original images of shape [h, w, 3], and label
...
@@ -11,7 +11,7 @@ from six.moves import urllib, range
 from ...utils import logger, get_dataset_dir
 from ...utils.fs import download
-from ..base import DataFlow
+from ..base import RNGDataFlow

 __all__ = ['Mnist']
@@ -92,7 +92,7 @@ class DataSet(object):
     def num_examples(self):
         return self._num_examples

-class Mnist(DataFlow):
+class Mnist(RNGDataFlow):
     """
     Return [image, label],
         image is 28x28 in the range [0,1]
@@ -136,7 +136,7 @@ class Mnist(DataFlow):
         ds = self.train if self.train_or_test == 'train' else self.test
         idxs = list(range(ds.num_examples))
         if self.shuffle:
-            random.shuffle(idxs)
+            self.rng.shuffle(idxs)
         for k in idxs:
             img = ds.images[k].reshape((28, 28))
             label = ds.labels[k]
...
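The `random.shuffle` to `self.rng.shuffle` change here is not cosmetic: the global `random` module's state is copied verbatim on fork, so every prefetch worker would produce the identical shuffle, whereas the per-instance rng is reseeded after the fork by `reset_state()`. A standalone illustration (not tensorpack code; Unix-only, since it uses `os.fork`):

```python
import os
import random

random.seed(42)                     # global state, copied verbatim on fork
pid = os.fork()
sample = random.sample(range(100), 5)
if pid == 0:
    print('child :', sample)        # identical to the parent's sample
    os._exit(0)
os.wait()
print('parent:', sample)            # same 5 numbers -- no diversity across processes
```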
@@ -9,7 +9,7 @@ import numpy as np
 from six.moves import range

 from ...utils import logger, get_rng, get_dataset_dir
-from ..base import DataFlow
+from ..base import RNGDataFlow

 try:
     import scipy.io
@@ -20,7 +20,7 @@ except ImportError:
 SVHN_URL = "http://ufldl.stanford.edu/housenumbers/"

-class SVHNDigit(DataFlow):
+class SVHNDigit(RNGDataFlow):
     """
     SVHN Cropped Digit Dataset
     return img of 32x32x3, label of 0-9
@@ -33,7 +33,6 @@ class SVHNDigit(DataFlow):
         :param data_dir: a directory containing the original {train,test,extra}_32x32.mat
         """
         self.shuffle = shuffle
-        self.rng = get_rng(self)

         if name in SVHNDigit.Cache:
             self.X, self.Y = SVHNDigit.Cache[name]
@@ -54,9 +53,6 @@ class SVHNDigit(DataFlow):
     def size(self):
         return self.X.shape[0]

-    def reset_state(self):
-        self.rng = get_rng(self)
-
     def get_data(self):
         n = self.X.shape[0]
         idxs = np.arange(n)
...
@@ -23,6 +23,7 @@ def dump_dataset_images(ds, dirname, max_count=None, index=0):
     mkdir_p(dirname)
     if max_count is None:
         max_count = sys.maxint
+    ds.reset_state()
     for i, dp in enumerate(ds.get_data()):
         if i % 100 == 0:
             print(i)
@@ -34,6 +35,7 @@ def dump_dataset_images(ds, dirname, max_count=None, index=0):
 def dataflow_to_process_queue(ds, size, nr_consumer):
     """
     Convert a `DataFlow` to a multiprocessing.Queue.
+    The dataflow will only be reset in the spawned process.

     :param ds: a `DataFlow`
     :param size: size of the queue
@@ -50,6 +52,7 @@ def dataflow_to_process_queue(ds, size, nr_consumer):
             self.q = q

         def run(self):
+            self.ds.reset_state()
             try:
                 for idx, dp in enumerate(self.ds.get_data()):
                     self.q.put((idx, dp))
...
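A hypothetical usage sketch of `dataflow_to_process_queue` as amended above. The `(queue, process)` return value is an assumption about code outside this hunk; only the `(idx, dp)` queue items and the reset-in-child behavior are visible in the diff:

```python
# `df` is any DataFlow; the producer process calls reset_state() itself,
# so the parent must not (and need not) reset it.
q, proc = dataflow_to_process_queue(df, size=100, nr_consumer=1)
proc.start()
results = {}
for _ in range(df.size()):
    idx, dp = q.get()       # idx lets consumers reassemble the original order
    results[idx] = dp
proc.join()
```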
@@ -5,7 +5,7 @@
 from ..utils import logger, get_rng
 from ..utils.timer import timed_operation
 from ..utils.loadcaffe import get_caffe_pb
-from .base import DataFlow
+from .base import RNGDataFlow
 import random
 from tqdm import tqdm
@@ -31,7 +31,7 @@ else:
 Adapters for different data format.
 """

-class HDF5Data(DataFlow):
+class HDF5Data(RNGDataFlow):
     """
     Zip data from different paths in an HDF5 file. Will load all data into memory.
     """
@@ -55,19 +55,18 @@ class HDF5Data(DataFlow):
     def get_data(self):
         idxs = list(range(self._size))
         if self.shuffle:
-            random.shuffle(idxs)
+            self.rng.shuffle(idxs)
         for k in idxs:
             yield [dp[k] for dp in self.dps]

-class LMDBData(DataFlow):
+class LMDBData(RNGDataFlow):
     """ Read a lmdb and produce k,v pair """
     def __init__(self, lmdb_dir, shuffle=True):
         self._lmdb = lmdb.open(lmdb_dir, readonly=True, lock=False,
                                map_size=1099511627776 * 2, max_readers=100)
         self._txn = self._lmdb.begin()
         self._shuffle = shuffle
-        self.rng = get_rng(self)
         self._size = self._txn.stat()['entries']
         if shuffle:
             self.keys = self._txn.get('__keys__')
@@ -81,8 +80,8 @@ class LMDBData(DataFlow):
                         pbar.update()

     def reset_state(self):
+        super(LMDBData, self).reset_state()
         self._txn = self._lmdb.begin()
-        self.rng = get_rng(self)

     def size(self):
         return self._size
@@ -96,8 +95,8 @@ class LMDBData(DataFlow):
                 yield [k, v]
         else:
             s = self.size()
-            for i in range(s):
-                k = self.rng.choice(self.keys)
+            self.rng.shuffle(self.keys)
+            for k in self.keys:
                 v = self._txn.get(k)
                 yield [k, v]
...
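The shuffle rewrite in the last hunk also changes sampling semantics: `rng.choice` draws keys with replacement, so a pass over the dataset could repeat some records and never emit others, whereas shuffling once and iterating yields every key exactly once per epoch. A quick standalone demonstration with numpy:

```python
import numpy as np

rng = np.random.RandomState(0)
keys = list(range(10))

with_replacement = [rng.choice(keys) for _ in range(len(keys))]
print(sorted(with_replacement))   # duplicates and missing keys are likely

rng.shuffle(keys)
print(sorted(keys))               # always [0..9]: a true permutation
```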
@@ -35,7 +35,7 @@ class PrefetchProcess(multiprocessing.Process):
         self.queue = queue

     def run(self):
-        # reset RNG of ds so each process will produce different data
+        # reset all ds so each process will produce different data
         self.ds.reset_state()
         while True:
             for dp in self.ds.get_data():
@@ -73,6 +73,10 @@ class PrefetchData(ProxyDataFlow):
             dp = self.queue.get()
             yield dp

+    def reset_state(self):
+        # do nothing. all ds are reset once and only once in spawned processes
+        pass
+
 class PrefetchProcessZMQ(multiprocessing.Process):
     def __init__(self, ds, conn_name):
         """
@@ -134,6 +138,10 @@ class PrefetchDataZMQ(ProxyDataFlow):
             dp = loads(self.socket.recv(copy=False))
             yield dp

+    def reset_state(self):
+        # do nothing. all ds are reset once and only once in spawned processes
+        pass
+
     def __del__(self):
         # on exit, logger may not be functional anymore
         try:
...
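The two no-op `reset_state` overrides encode a single ownership rule: once a dataflow is handed to prefetch workers, only those workers may reset it. A reset in the parent would be useless (the fork already happened) or harmful (duplicated RNG state). A tensorpack-free sketch of the intended seeding pattern:

```python
import multiprocessing as mp
import numpy as np

def worker(_):
    # the moral equivalent of reset_state(): reseed inside the child
    rng = np.random.RandomState()
    return int(rng.randint(1000))

if __name__ == '__main__':
    with mp.Pool(4) as pool:
        print(pool.map(worker, range(4)))   # four independent draws
```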
@@ -23,6 +23,7 @@ def serve_data(ds, addr):
     socket.bind(addr)
     ds = RepeatedData(ds, -1)
     try:
+        ds.reset_state()
         logger.info("Serving data at {}".format(addr))
         while True:
             for dp in ds.get_data():
...
@@ -19,7 +19,7 @@ def describe_model():
             v.name, shape.as_list(), ele))
     size_mb = total * 4 / 1024.0**2
     msg.append("Total param={} ({:01f} MB assuming all float32)".format(total, size_mb))
-    logger.info("Model Params: {}".format('\n'.join(msg)))
+    logger.info("Model Parameters: {}".format('\n'.join(msg)))

 def get_shape_str(tensors):
...
@@ -40,6 +40,7 @@ class SimpleTrainer(Trainer):
         self.init_session_and_coord()
         describe_model()
         # create an infinite data producer
+        self.config.dataset.reset_state()
         self.data_producer = RepeatedData(self.config.dataset, -1).get_data()
         self.main_loop()
@@ -62,21 +63,22 @@ class SimpleTrainer(Trainer):
         return func

 class EnqueueThread(threading.Thread):
-    def __init__(self, trainer, queue, enqueue_op, raw_input_var):
+    def __init__(self, trainer):
         super(EnqueueThread, self).__init__()
         self.sess = trainer.sess
         self.coord = trainer.coord
         self.dataflow = RepeatedData(trainer.config.dataset, -1)

-        self.input_vars = raw_input_var
-        self.op = enqueue_op
-        self.queue = queue
+        self.input_vars = trainer.input_vars
+        self.queue = trainer.input_queue
+        self.op = self.queue.enqueue(self.input_vars)
         self.close_op = self.queue.close(cancel_pending_enqueues=True)
         self.size_op = self.queue.size()
         self.daemon = True

     def run(self):
+        self.dataflow.reset_state()
         with self.sess.as_default():
             try:
                 while True:
@@ -155,8 +157,7 @@ class QueueInputTrainer(Trainer):
     def _build_enque_thread(self):
         """ create a thread that keeps filling the queue """
-        enqueue_op = self.input_queue.enqueue(self.input_vars)
-        self.input_th = EnqueueThread(self, self.input_queue, enqueue_op, self.input_vars)
+        self.input_th = EnqueueThread(self)
         self.extra_threads_procs.append(self.input_th)

     def train(self):
...
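The EnqueueThread refactor moves the queue wiring into the thread itself (it now derives `input_vars`, the queue, and the enqueue op from the trainer) and, consistent with the rest of this merge, resets the dataflow inside `run()`, i.e. in the thread that consumes it. A sketch of the resulting loop body; only the `reset_state()` and `as_default()` lines appear in the hunk, the rest (including the exception type) is assumed from context:

```python
def run(self):
    self.dataflow.reset_state()            # reset by the consumer, nowhere else
    with self.sess.as_default():
        try:
            while True:
                for dp in self.dataflow.get_data():
                    feed = dict(zip(self.input_vars, dp))
                    self.op.run(feed_dict=feed)    # blocks when the queue is full
        except tf.errors.CancelledError:
            pass
        finally:
            self.sess.run(self.close_op)   # unblock any pending dequeue on exit
```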