Commit fbc13fb4 authored by Yuxin Wu's avatar Yuxin Wu

notes, logs, online moments

parent 99a8ee54
......@@ -54,9 +54,10 @@ To Train:
ILSVRC2012_val_00000001.JPEG
...
And better to have:
And you'll need the following to be able to fetch data efficiently
Fast disk random access (Not necessarily SSD. I used a RAID of HDD, but not sure if plain HDD is enough)
More than 12 CPU cores (for data processing)
More than 10G of free memory
To Run Pretrained Model:
./alexnet-dorefa.py --load alexnet-126.npy --run a.jpg --dorefa 1,2,6
......@@ -303,6 +304,7 @@ if __name__ == '__main__':
assert args.gpu is not None, "Need to specify a list of gpu for training!"
NR_GPU = len(args.gpu.split(','))
BATCH_SIZE = TOTAL_BATCH_SIZE // NR_GPU
logger.info("Batch per tower: {}".format(BATCH_SIZE))
config = get_config()
if args.load:
......
......@@ -21,16 +21,15 @@ class TestDataSpeed(ProxyDataFlow):
self.test_size = size
def get_data(self):
with get_tqdm(total=self.test_size) as pbar:
for dp in self.ds.get_data():
pbar.update()
self.start_test()
for dp in self.ds.get_data():
yield dp
def start_test(self):
self.ds.reset_state()
for k in self.get_data():
pass
with get_tqdm(total=self.test_size) as pbar:
for dp in self.ds.get_data():
pbar.update()
class BatchData(ProxyDataFlow):
def __init__(self, ds, batch_size, remainder=False):
......
......@@ -72,13 +72,18 @@ class HDF5Data(RNGDataFlow):
class LMDBData(RNGDataFlow):
""" Read a lmdb and produce k,v pair """
def __init__(self, lmdb_path, shuffle=True):
self._lmdb = lmdb.open(lmdb_path, subdir=os.path.isdir(lmdb_path),
readonly=True, lock=False,
self._lmdb_path = lmdb_path
self._shuffle = shuffle
self.open_lmdb()
def open_lmdb(self):
self._lmdb = lmdb.open(self._lmdb_path,
subdir=os.path.isdir(self._lmdb_path),
readonly=True, lock=False, readahead=False,
map_size=1099511627776 * 2, max_readers=100)
self._txn = self._lmdb.begin()
self._shuffle = shuffle
self._size = self._txn.stat()['entries']
if shuffle:
if self._shuffle:
# get the list of keys either from __keys__ or by iterating
self.keys = loads(self._txn.get('__keys__'))
if not self.keys:
......@@ -92,7 +97,7 @@ class LMDBData(RNGDataFlow):
def reset_state(self):
super(LMDBData, self).reset_state()
self._txn = self._lmdb.begin()
self.open_lmdb()
def size(self):
return self._size
......
......@@ -4,6 +4,7 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import numpy as np
import copy
from six.moves import range
from .base import DataFlow, RNGDataFlow
from ..utils.serialize import loads
......@@ -41,7 +42,7 @@ class FakeData(RNGDataFlow):
else:
v = [self.rng.rand(*k).astype(self.dtype) for k in self.shapes]
for _ in range(self._size):
yield v
yield copy.deepcopy(v)
class DataFromQueue(DataFlow):
""" Produce data from a queue """
......
......@@ -74,7 +74,7 @@ class EnqueueThread(threading.Thread):
if self.coord.should_stop():
return
feed = dict(zip(self.placehdrs, dp))
#print 'TFQ:', self.sess.run([self.op, self.size_op], feed_dict=feed)[1]
#print 'qsize:', self.sess.run([self.op, self.size_op], feed_dict=feed)[1]
self.op.run(feed_dict=feed)
except tf.errors.CancelledError as e:
pass
......
......@@ -4,6 +4,7 @@
# Author: Yuxin Wu <ppwwyyxxc@gmail.com>
import operator
import inspect, six, functools
import collections
......@@ -34,16 +35,20 @@ class memoized(object):
self.func = func
self.cache = {}
def __call__(self, *args):
if not isinstance(args, collections.Hashable):
def __call__(self, *args, **kwargs):
kwlist = tuple(sorted(list(kwargs), key=operator.itemgetter(0)))
if not isinstance(args, collections.Hashable) or \
not isinstance(kwlist, collections.Hashable):
logger.warn("Arguments to memoized call is unhashable!")
# uncacheable. a list, for instance.
# better to not cache than blow up.
return self.func(*args)
if args in self.cache:
return self.cache[args]
return self.func(*args, **kwargs)
key = (args, kwlist)
if key in self.cache:
return self.cache[key]
else:
value = self.func(*args)
self.cache[args] = value
value = self.func(*args, **kwargs)
self.cache[key] = value
return value
def __repr__(self):
......@@ -57,9 +62,9 @@ class memoized(object):
_MEMOIZED_NOARGS = {}
def memoized_ignoreargs(func):
h = hash(func) # make sure it is hashable. is it necessary?
def wrapper(*args):
def wrapper(*args, **kwargs):
if func not in _MEMOIZED_NOARGS:
res = func(*args)
res = func(*args, **kwargs)
_MEMOIZED_NOARGS[func] = res
return res
return _MEMOIZED_NOARGS[func]
......
......@@ -119,21 +119,20 @@ class BinaryStatistics(object):
return 1 - self.recall
class OnlineMoments(object):
"""Compute 1st and 2nd moments online
See algorithm at: https://www.wikiwand.com/en/Algorithms_for_calculating_variance#/Online_algorithm
"""
def __init__(self):
self._mean = None
self._var = None
self._mean = 0
self._M2 = 0
self._n = 0
def feed(self, x):
self._n += 1
if self._mean is None:
self._mean = x
self._var = 0
else:
diff = (x - self._mean)
ninv = 1.0 / self._n
self._mean += diff * ninv
self._var = (self._n-2.0)/(self._n-1) * self._var + diff * diff * ninv
delta = x - self._mean
self._mean += delta * (1.0 / self._n)
delta2 = x - self._mean
self._M2 += delta * delta2
@property
def mean(self):
......@@ -141,8 +140,8 @@ class OnlineMoments(object):
@property
def variance(self):
return self._var
return self._M2 / (self._n-1)
@property
def std(self):
return np.sqrt(self._var)
return np.sqrt(self.variance)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment