Commit de9025d6 authored by Yuxin Wu's avatar Yuxin Wu

update docs about MapAndBatch

parent 22582cc7
...@@ -354,13 +354,15 @@ class MultiProcessMapAndBatchDataZMQ(_MultiProcessZMQDataFlow): ...@@ -354,13 +354,15 @@ class MultiProcessMapAndBatchDataZMQ(_MultiProcessZMQDataFlow):
enable_death_signal(_warn=self.identity == b'0') enable_death_signal(_warn=self.identity == b'0')
ctx = zmq.Context() ctx = zmq.Context()
# recv jobs
socket = ctx.socket(zmq.PULL) socket = ctx.socket(zmq.PULL)
socket.setsockopt(zmq.IDENTITY, self.identity) socket.setsockopt(zmq.IDENTITY, self.identity)
socket.set_hwm(self.hwm) socket.set_hwm(self.hwm * self.batch_size)
socket.connect(self.input_pipe) socket.connect(self.input_pipe)
# send results
out_socket = ctx.socket(zmq.PUSH) out_socket = ctx.socket(zmq.PUSH)
out_socket.set_hwm(max(self.hwm // self.batch_size, 5)) out_socket.set_hwm(max(self.hwm, 5))
out_socket.connect(self.result_pipe) out_socket.connect(self.result_pipe)
batch = [] batch = []
...@@ -374,7 +376,7 @@ class MultiProcessMapAndBatchDataZMQ(_MultiProcessZMQDataFlow): ...@@ -374,7 +376,7 @@ class MultiProcessMapAndBatchDataZMQ(_MultiProcessZMQDataFlow):
out_socket.send(dumps(dp), copy=False) out_socket.send(dumps(dp), copy=False)
del batch[:] del batch[:]
def __init__(self, ds, num_proc, map_func, batch_size, buffer_size=1024): def __init__(self, ds, num_proc, map_func, batch_size, buffer_size=None):
""" """
Args: Args:
ds (DataFlow): the dataflow to map ds (DataFlow): the dataflow to map
...@@ -382,15 +384,18 @@ class MultiProcessMapAndBatchDataZMQ(_MultiProcessZMQDataFlow): ...@@ -382,15 +384,18 @@ class MultiProcessMapAndBatchDataZMQ(_MultiProcessZMQDataFlow):
map_func (callable): datapoint -> datapoint | None. Return None to map_func (callable): datapoint -> datapoint | None. Return None to
discard/skip the datapoint. discard/skip the datapoint.
batch_size (int): batch size batch_size (int): batch size
buffer_size (int): number of datapoints in the buffer buffer_size (int): number of datapoints (not batched) in the buffer.
Defaults to batch_size * 10
""" """
super(MultiProcessMapAndBatchDataZMQ, self).__init__() super(MultiProcessMapAndBatchDataZMQ, self).__init__()
assert batch_size < buffer_size
self.ds = ds self.ds = ds
self.num_proc = num_proc self.num_proc = num_proc
self.map_func = map_func self.map_func = map_func
self.buffer_size = buffer_size
self.batch_size = batch_size self.batch_size = batch_size
assert self.batch_size < buffer_size if buffer_size is None:
buffer_size = batch_size * 10
self.buffer_size = buffer_size
def reset_state(self): def reset_state(self):
_MultiProcessZMQDataFlow.reset_state(self) _MultiProcessZMQDataFlow.reset_state(self)
...@@ -401,13 +406,13 @@ class MultiProcessMapAndBatchDataZMQ(_MultiProcessZMQDataFlow): ...@@ -401,13 +406,13 @@ class MultiProcessMapAndBatchDataZMQ(_MultiProcessZMQDataFlow):
self.context = zmq.Context() self.context = zmq.Context()
self.socket = self.context.socket(zmq.PULL) self.socket = self.context.socket(zmq.PULL)
self.socket.set_hwm(max(5, self.buffer_size * 2 // self.batch_size)) self.socket.set_hwm(max(5, self.buffer_size // self.batch_size))
_bind_guard(self.socket, result_pipe) _bind_guard(self.socket, result_pipe)
dispatcher = MultiProcessMapAndBatchDataZMQ._Dispatcher(self.ds, job_pipe, self.buffer_size) dispatcher = MultiProcessMapAndBatchDataZMQ._Dispatcher(self.ds, job_pipe, self.buffer_size)
self._proc_ids = [u'{}'.format(k).encode('utf-8') for k in range(self.num_proc)] self._proc_ids = [u'{}'.format(k).encode('utf-8') for k in range(self.num_proc)]
worker_hwm = max(3, self.buffer_size * 2 // self.num_proc // self.batch_size) worker_hwm = max(3, self.buffer_size // self.num_proc // self.batch_size)
self._procs = [MultiProcessMapAndBatchDataZMQ._Worker( self._procs = [MultiProcessMapAndBatchDataZMQ._Worker(
self._proc_ids[k], self.map_func, job_pipe, result_pipe, worker_hwm, self.batch_size) self._proc_ids[k], self.map_func, job_pipe, result_pipe, worker_hwm, self.batch_size)
for k in range(self.num_proc)] for k in range(self.num_proc)]
......
...@@ -287,8 +287,7 @@ def BatchNorm(inputs, axis=None, *, training=None, momentum=0.9, epsilon=1e-5, ...@@ -287,8 +287,7 @@ def BatchNorm(inputs, axis=None, *, training=None, momentum=0.9, epsilon=1e-5,
center=center, scale=scale, center=center, scale=scale,
beta_initializer=beta_initializer, beta_initializer=beta_initializer,
gamma_initializer=gamma_initializer, gamma_initializer=gamma_initializer,
# https://github.com/tensorflow/tensorflow/issues/10857#issuecomment-410185429 fused=(ndims == 4 and axis in [1, 3]),
fused=(ndims == 4 and axis in [1, 3] and not freeze_bn_backward),
_reuse=tf.get_variable_scope().reuse) _reuse=tf.get_variable_scope().reuse)
use_fp16 = inputs.dtype == tf.float16 use_fp16 = inputs.dtype == tf.float16
if use_fp16: if use_fp16:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment