Commit fbc13fb4 authored by Yuxin Wu

notes, logs, online moments

parent 99a8ee54
...@@ -54,9 +54,10 @@ To Train: ...@@ -54,9 +54,10 @@ To Train:
ILSVRC2012_val_00000001.JPEG ILSVRC2012_val_00000001.JPEG
... ...
And better to have: And you'll need the following to be able to fetch data efficiently
Fast disk random access (Not necessarily SSD. I used a RAID of HDD, but not sure if plain HDD is enough) Fast disk random access (Not necessarily SSD. I used a RAID of HDD, but not sure if plain HDD is enough)
More than 12 CPU cores (for data processing) More than 12 CPU cores (for data processing)
More than 10G of free memory
To Run Pretrained Model: To Run Pretrained Model:
./alexnet-dorefa.py --load alexnet-126.npy --run a.jpg --dorefa 1,2,6 ./alexnet-dorefa.py --load alexnet-126.npy --run a.jpg --dorefa 1,2,6
...@@ -303,6 +304,7 @@ if __name__ == '__main__': ...@@ -303,6 +304,7 @@ if __name__ == '__main__':
assert args.gpu is not None, "Need to specify a list of gpu for training!" assert args.gpu is not None, "Need to specify a list of gpu for training!"
NR_GPU = len(args.gpu.split(',')) NR_GPU = len(args.gpu.split(','))
BATCH_SIZE = TOTAL_BATCH_SIZE // NR_GPU BATCH_SIZE = TOTAL_BATCH_SIZE // NR_GPU
logger.info("Batch per tower: {}".format(BATCH_SIZE))
config = get_config() config = get_config()
if args.load: if args.load:
......
...@@ -21,16 +21,15 @@ class TestDataSpeed(ProxyDataFlow): ...@@ -21,16 +21,15 @@ class TestDataSpeed(ProxyDataFlow):
self.test_size = size self.test_size = size
def get_data(self): def get_data(self):
with get_tqdm(total=self.test_size) as pbar: self.start_test()
for dp in self.ds.get_data():
pbar.update()
for dp in self.ds.get_data(): for dp in self.ds.get_data():
yield dp yield dp
def start_test(self): def start_test(self):
self.ds.reset_state() self.ds.reset_state()
for k in self.get_data(): with get_tqdm(total=self.test_size) as pbar:
pass for dp in self.ds.get_data():
pbar.update()
class BatchData(ProxyDataFlow): class BatchData(ProxyDataFlow):
def __init__(self, ds, batch_size, remainder=False): def __init__(self, ds, batch_size, remainder=False):
......
...@@ -72,13 +72,18 @@ class HDF5Data(RNGDataFlow): ...@@ -72,13 +72,18 @@ class HDF5Data(RNGDataFlow):
class LMDBData(RNGDataFlow): class LMDBData(RNGDataFlow):
""" Read a lmdb and produce k,v pair """ """ Read a lmdb and produce k,v pair """
def __init__(self, lmdb_path, shuffle=True): def __init__(self, lmdb_path, shuffle=True):
self._lmdb = lmdb.open(lmdb_path, subdir=os.path.isdir(lmdb_path), self._lmdb_path = lmdb_path
readonly=True, lock=False, self._shuffle = shuffle
self.open_lmdb()
def open_lmdb(self):
self._lmdb = lmdb.open(self._lmdb_path,
subdir=os.path.isdir(self._lmdb_path),
readonly=True, lock=False, readahead=False,
map_size=1099511627776 * 2, max_readers=100) map_size=1099511627776 * 2, max_readers=100)
self._txn = self._lmdb.begin() self._txn = self._lmdb.begin()
self._shuffle = shuffle
self._size = self._txn.stat()['entries'] self._size = self._txn.stat()['entries']
if shuffle: if self._shuffle:
# get the list of keys either from __keys__ or by iterating # get the list of keys either from __keys__ or by iterating
self.keys = loads(self._txn.get('__keys__')) self.keys = loads(self._txn.get('__keys__'))
if not self.keys: if not self.keys:
...@@ -92,7 +97,7 @@ class LMDBData(RNGDataFlow): ...@@ -92,7 +97,7 @@ class LMDBData(RNGDataFlow):
def reset_state(self): def reset_state(self):
super(LMDBData, self).reset_state() super(LMDBData, self).reset_state()
self._txn = self._lmdb.begin() self.open_lmdb()
def size(self): def size(self):
return self._size return self._size
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import numpy as np import numpy as np
import copy
from six.moves import range from six.moves import range
from .base import DataFlow, RNGDataFlow from .base import DataFlow, RNGDataFlow
from ..utils.serialize import loads from ..utils.serialize import loads
...@@ -41,7 +42,7 @@ class FakeData(RNGDataFlow): ...@@ -41,7 +42,7 @@ class FakeData(RNGDataFlow):
else: else:
v = [self.rng.rand(*k).astype(self.dtype) for k in self.shapes] v = [self.rng.rand(*k).astype(self.dtype) for k in self.shapes]
for _ in range(self._size): for _ in range(self._size):
yield v yield copy.deepcopy(v)
class DataFromQueue(DataFlow): class DataFromQueue(DataFlow):
""" Produce data from a queue """ """ Produce data from a queue """
......
...@@ -74,7 +74,7 @@ class EnqueueThread(threading.Thread): ...@@ -74,7 +74,7 @@ class EnqueueThread(threading.Thread):
if self.coord.should_stop(): if self.coord.should_stop():
return return
feed = dict(zip(self.placehdrs, dp)) feed = dict(zip(self.placehdrs, dp))
#print 'TFQ:', self.sess.run([self.op, self.size_op], feed_dict=feed)[1] #print 'qsize:', self.sess.run([self.op, self.size_op], feed_dict=feed)[1]
self.op.run(feed_dict=feed) self.op.run(feed_dict=feed)
except tf.errors.CancelledError as e: except tf.errors.CancelledError as e:
pass pass
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com> # Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import operator
import inspect, six, functools import inspect, six, functools
import collections import collections
...@@ -34,16 +35,20 @@ class memoized(object): ...@@ -34,16 +35,20 @@ class memoized(object):
self.func = func self.func = func
self.cache = {} self.cache = {}
def __call__(self, *args): def __call__(self, *args, **kwargs):
if not isinstance(args, collections.Hashable): kwlist = tuple(sorted(list(kwargs), key=operator.itemgetter(0)))
if not isinstance(args, collections.Hashable) or \
not isinstance(kwlist, collections.Hashable):
logger.warn("Arguments to memoized call is unhashable!")
# uncacheable. a list, for instance. # uncacheable. a list, for instance.
# better to not cache than blow up. # better to not cache than blow up.
return self.func(*args) return self.func(*args, **kwargs)
if args in self.cache: key = (args, kwlist)
return self.cache[args] if key in self.cache:
return self.cache[key]
else: else:
value = self.func(*args) value = self.func(*args, **kwargs)
self.cache[args] = value self.cache[key] = value
return value return value
def __repr__(self): def __repr__(self):
...@@ -57,9 +62,9 @@ class memoized(object): ...@@ -57,9 +62,9 @@ class memoized(object):
_MEMOIZED_NOARGS = {} _MEMOIZED_NOARGS = {}
def memoized_ignoreargs(func): def memoized_ignoreargs(func):
h = hash(func) # make sure it is hashable. is it necessary? h = hash(func) # make sure it is hashable. is it necessary?
def wrapper(*args): def wrapper(*args, **kwargs):
if func not in _MEMOIZED_NOARGS: if func not in _MEMOIZED_NOARGS:
res = func(*args) res = func(*args, **kwargs)
_MEMOIZED_NOARGS[func] = res _MEMOIZED_NOARGS[func] = res
return res return res
return _MEMOIZED_NOARGS[func] return _MEMOIZED_NOARGS[func]
......
...@@ -119,21 +119,20 @@ class BinaryStatistics(object): ...@@ -119,21 +119,20 @@ class BinaryStatistics(object):
return 1 - self.recall return 1 - self.recall
class OnlineMoments(object): class OnlineMoments(object):
"""Compute 1st and 2nd moments online
See algorithm at: https://www.wikiwand.com/en/Algorithms_for_calculating_variance#/Online_algorithm
"""
def __init__(self): def __init__(self):
self._mean = None self._mean = 0
self._var = None self._M2 = 0
self._n = 0 self._n = 0
def feed(self, x): def feed(self, x):
self._n += 1 self._n += 1
if self._mean is None: delta = x - self._mean
self._mean = x self._mean += delta * (1.0 / self._n)
self._var = 0 delta2 = x - self._mean
else: self._M2 += delta * delta2
diff = (x - self._mean)
ninv = 1.0 / self._n
self._mean += diff * ninv
self._var = (self._n-2.0)/(self._n-1) * self._var + diff * diff * ninv
@property @property
def mean(self): def mean(self):
...@@ -141,8 +140,8 @@ class OnlineMoments(object): ...@@ -141,8 +140,8 @@ class OnlineMoments(object):
@property @property
def variance(self): def variance(self):
return self._var return self._M2 / (self._n-1)
@property @property
def std(self): def std(self):
return np.sqrt(self._var) return np.sqrt(self.variance)
Markdown is supported
0% Try again or attach a new file
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment