Commit de9025d6 authored by Yuxin Wu's avatar Yuxin Wu

update docs about MapAndBatch

parent 22582cc7
......@@ -354,13 +354,15 @@ class MultiProcessMapAndBatchDataZMQ(_MultiProcessZMQDataFlow):
enable_death_signal(_warn=self.identity == b'0')
ctx = zmq.Context()
# recv jobs
socket = ctx.socket(zmq.PULL)
socket.setsockopt(zmq.IDENTITY, self.identity)
socket.set_hwm(self.hwm)
socket.set_hwm(self.hwm * self.batch_size)
socket.connect(self.input_pipe)
# send results
out_socket = ctx.socket(zmq.PUSH)
out_socket.set_hwm(max(self.hwm // self.batch_size, 5))
out_socket.set_hwm(max(self.hwm, 5))
out_socket.connect(self.result_pipe)
batch = []
......@@ -374,7 +376,7 @@ class MultiProcessMapAndBatchDataZMQ(_MultiProcessZMQDataFlow):
out_socket.send(dumps(dp), copy=False)
del batch[:]
def __init__(self, ds, num_proc, map_func, batch_size, buffer_size=1024):
def __init__(self, ds, num_proc, map_func, batch_size, buffer_size=None):
"""
Args:
ds (DataFlow): the dataflow to map
......@@ -382,15 +384,18 @@ class MultiProcessMapAndBatchDataZMQ(_MultiProcessZMQDataFlow):
map_func (callable): datapoint -> datapoint | None. Return None to
discard/skip the datapoint.
batch_size (int): batch size
buffer_size (int): number of datapoints in the buffer
buffer_size (int): number of datapoints (not batched) in the buffer.
Defaults to batch_size * 10
"""
super(MultiProcessMapAndBatchDataZMQ, self).__init__()
assert batch_size < buffer_size
self.ds = ds
self.num_proc = num_proc
self.map_func = map_func
self.buffer_size = buffer_size
self.batch_size = batch_size
assert self.batch_size < buffer_size
if buffer_size is None:
buffer_size = batch_size * 10
self.buffer_size = buffer_size
def reset_state(self):
_MultiProcessZMQDataFlow.reset_state(self)
......@@ -401,13 +406,13 @@ class MultiProcessMapAndBatchDataZMQ(_MultiProcessZMQDataFlow):
self.context = zmq.Context()
self.socket = self.context.socket(zmq.PULL)
self.socket.set_hwm(max(5, self.buffer_size * 2 // self.batch_size))
self.socket.set_hwm(max(5, self.buffer_size // self.batch_size))
_bind_guard(self.socket, result_pipe)
dispatcher = MultiProcessMapAndBatchDataZMQ._Dispatcher(self.ds, job_pipe, self.buffer_size)
self._proc_ids = [u'{}'.format(k).encode('utf-8') for k in range(self.num_proc)]
worker_hwm = max(3, self.buffer_size * 2 // self.num_proc // self.batch_size)
worker_hwm = max(3, self.buffer_size // self.num_proc // self.batch_size)
self._procs = [MultiProcessMapAndBatchDataZMQ._Worker(
self._proc_ids[k], self.map_func, job_pipe, result_pipe, worker_hwm, self.batch_size)
for k in range(self.num_proc)]
......
......@@ -287,8 +287,7 @@ def BatchNorm(inputs, axis=None, *, training=None, momentum=0.9, epsilon=1e-5,
center=center, scale=scale,
beta_initializer=beta_initializer,
gamma_initializer=gamma_initializer,
# https://github.com/tensorflow/tensorflow/issues/10857#issuecomment-410185429
fused=(ndims == 4 and axis in [1, 3] and not freeze_bn_backward),
fused=(ndims == 4 and axis in [1, 3]),
_reuse=tf.get_variable_scope().reuse)
use_fp16 = inputs.dtype == tf.float16
if use_fp16:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment