Commit aec8cd3a authored by Yuxin Wu's avatar Yuxin Wu

discretize ND

parent 562ab0b9
......@@ -29,11 +29,16 @@ Multi-GPU training is ready to use by simply changing the trainer.
## Dependencies:
+ Python 2 or 3
+ TensorFlow
+ TensorFlow >= 0.8
+ Python bindings for OpenCV
+ other requirements:
```
pip install --user -r requirements.txt
pip install --user -r opt-requirements.txt (some optional dependencies)
```
+ allow `import tensorpack` everywhere:
```
export PYTHONPATH=$PYTHONPATH:`readlink -f path/to/tensorpack`
```
......@@ -62,7 +62,7 @@ def BatchNorm(x, use_local_stat=True, decay=0.9, epsilon=1e-5):
else:
# use training-statistics in prediction
assert not use_local_stat
with tf.name_scope(None): # https://github.com/tensorflow/tensorflow/issues/2740
with tf.name_scope(None):
ema = tf.train.ExponentialMovingAverage(decay=decay, name=emaname)
ema_apply_op = ema.apply([batch_mean, batch_var])
ema_mean, ema_var = ema.average(batch_mean), ema.average(batch_var)
......
......@@ -8,7 +8,7 @@ from abc import abstractmethod, ABCMeta
import numpy as np
from six.moves import range
__all__ = ['UniformDiscretizer1D']
__all__ = ['UniformDiscretizer1D', 'UniformDiscretizerND']
@memoized
def log_once(s):
......@@ -74,8 +74,31 @@ class UniformDiscretizer1D(Discretizer1D):
return ret
class UniformDiscretizerND(Discretizer):
def __init__(self, *min_max_spacing):
"""
:params min_max_spacing: (minv, maxv, spacing) for each dimension
"""
self.n = len(min_max_spacing)
self.discretizers = [UniformDiscretizer1D(*k) for k in min_max_spacing]
self.nr_bins = [k.get_nr_bin() for k in self.discretizers]
def get_nr_bin(self):
return np.prod(self.nr_bins)
def get_bin(self, v):
assert len(v) == self.n
bin_id = [self.discretizers[k].get_bin(v[k]) for k in range(self.n)]
acc, res = 1, 0
for k in reversed(list(range(self.n))):
res += bin_id[k] * acc
acc *= self.nr_bins[k]
return res
if __name__ == '__main__':
u = UniformDiscretizer1D(-10, 10, 0.12)
#u = UniformDiscretizer1D(-10, 10, 0.12)
u = UniformDiscretizerND((0, 100, 1), (0, 100, 1), (0, 100, 1))
import IPython as IP;
IP.embed(config=IP.terminal.ipapp.load_default_config())
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment