Commit 869bc638 authored by Yuxin Wu's avatar Yuxin Wu

bind option in zmq tools. more notes about horovod.

parent 28f36c44
...@@ -101,7 +101,6 @@ def get_config(model, fake=False): ...@@ -101,7 +101,6 @@ def get_config(model, fake=False):
callbacks=callbacks, callbacks=callbacks,
steps_per_epoch=100 if args.fake else 1280000 // args.batch, steps_per_epoch=100 if args.fake else 1280000 // args.batch,
max_epoch=110, max_epoch=110,
nr_tower=nr_tower
) )
......
...@@ -83,7 +83,7 @@ def fbresnet_augmentor(isTrain): ...@@ -83,7 +83,7 @@ def fbresnet_augmentor(isTrain):
def get_imagenet_dataflow( def get_imagenet_dataflow(
datadir, name, batch_size, datadir, name, batch_size,
augmentors): augmentors, parallel=None):
""" """
See explanations in the tutorial: See explanations in the tutorial:
http://tensorpack.readthedocs.io/en/latest/tutorial/efficient-dataflow.html http://tensorpack.readthedocs.io/en/latest/tutorial/efficient-dataflow.html
...@@ -92,11 +92,12 @@ def get_imagenet_dataflow( ...@@ -92,11 +92,12 @@ def get_imagenet_dataflow(
assert datadir is not None assert datadir is not None
assert isinstance(augmentors, list) assert isinstance(augmentors, list)
isTrain = name == 'train' isTrain = name == 'train'
cpu = min(40, multiprocessing.cpu_count()) if parallel is None:
parallel = min(40, multiprocessing.cpu_count())
if isTrain: if isTrain:
ds = dataset.ILSVRC12(datadir, name, shuffle=True) ds = dataset.ILSVRC12(datadir, name, shuffle=True)
ds = AugmentImageComponent(ds, augmentors, copy=False) ds = AugmentImageComponent(ds, augmentors, copy=False)
ds = PrefetchDataZMQ(ds, cpu) ds = PrefetchDataZMQ(ds, parallel)
ds = BatchData(ds, batch_size, remainder=False) ds = BatchData(ds, batch_size, remainder=False)
else: else:
ds = dataset.ILSVRC12Files(datadir, name, shuffle=False) ds = dataset.ILSVRC12Files(datadir, name, shuffle=False)
...@@ -107,7 +108,7 @@ def get_imagenet_dataflow( ...@@ -107,7 +108,7 @@ def get_imagenet_dataflow(
im = cv2.imread(fname, cv2.IMREAD_COLOR) im = cv2.imread(fname, cv2.IMREAD_COLOR)
im = aug.augment(im) im = aug.augment(im)
return im, cls return im, cls
ds = MultiThreadMapData(ds, cpu, mapf, buffer_size=2000, strict=True) ds = MultiThreadMapData(ds, parallel, mapf, buffer_size=2000, strict=True)
ds = BatchData(ds, batch_size, remainder=True) ds = BatchData(ds, batch_size, remainder=True)
ds = PrefetchDataZMQ(ds, 1) ds = PrefetchDataZMQ(ds, 1)
return ds return ds
......
...@@ -9,10 +9,11 @@ import pprint ...@@ -9,10 +9,11 @@ import pprint
from termcolor import colored from termcolor import colored
from collections import deque, defaultdict from collections import deque, defaultdict
from six.moves import range, map from six.moves import range, map
import tqdm
from .base import DataFlow, ProxyDataFlow, RNGDataFlow, DataFlowReentrantGuard from .base import DataFlow, ProxyDataFlow, RNGDataFlow, DataFlowReentrantGuard
from ..utils import logger from ..utils import logger
from ..utils.utils import get_tqdm, get_rng from ..utils.utils import get_tqdm, get_rng, get_tqdm_kwargs
from ..utils.develop import log_deprecated from ..utils.develop import log_deprecated
__all__ = ['TestDataSpeed', 'PrintData', 'BatchData', 'BatchDataByShape', 'FixedSizeData', 'MapData', __all__ = ['TestDataSpeed', 'PrintData', 'BatchData', 'BatchDataByShape', 'FixedSizeData', 'MapData',
...@@ -23,14 +24,16 @@ __all__ = ['TestDataSpeed', 'PrintData', 'BatchData', 'BatchDataByShape', 'Fixed ...@@ -23,14 +24,16 @@ __all__ = ['TestDataSpeed', 'PrintData', 'BatchData', 'BatchDataByShape', 'Fixed
class TestDataSpeed(ProxyDataFlow): class TestDataSpeed(ProxyDataFlow):
""" Test the speed of some DataFlow """ """ Test the speed of some DataFlow """
def __init__(self, ds, size=5000): def __init__(self, ds, size=5000, warmup=0):
""" """
Args: Args:
ds (DataFlow): the DataFlow to test. ds (DataFlow): the DataFlow to test.
size (int): number of datapoints to fetch. size (int): number of datapoints to fetch.
warmup (int): warmup iterations
""" """
super(TestDataSpeed, self).__init__(ds) super(TestDataSpeed, self).__init__(ds)
self.test_size = size self.test_size = int(size)
self.warmup = int(warmup)
def get_data(self): def get_data(self):
""" Will run testing at the beginning, then produce data normally. """ """ Will run testing at the beginning, then produce data normally. """
...@@ -43,10 +46,14 @@ class TestDataSpeed(ProxyDataFlow): ...@@ -43,10 +46,14 @@ class TestDataSpeed(ProxyDataFlow):
Start testing with a progress bar. Start testing with a progress bar.
""" """
self.ds.reset_state() self.ds.reset_state()
itr = self.ds.get_data()
if self.warmup:
for d in tqdm.trange(self.warmup, **get_tqdm_kwargs()):
next(itr)
# add smoothing for speed benchmark # add smoothing for speed benchmark
with get_tqdm(total=self.test_size, with get_tqdm(total=self.test_size,
leave=True, smoothing=0.2) as pbar: leave=True, smoothing=0.2) as pbar:
for idx, dp in enumerate(self.ds.get_data()): for idx, dp in enumerate(itr):
pbar.update() pbar.update()
if idx == self.test_size - 1: if idx == self.test_size - 1:
break break
......
...@@ -33,10 +33,10 @@ def send_dataflow_zmq(df, addr, hwm=50, format=None, bind=False): ...@@ -33,10 +33,10 @@ def send_dataflow_zmq(df, addr, hwm=50, format=None, bind=False):
hwm (int): ZMQ high-water mark (buffer size) hwm (int): ZMQ high-water mark (buffer size)
format (str): The serialization format. format (str): The serialization format.
Default format would use :mod:`tensorpack.utils.serialize`. Default format would use :mod:`tensorpack.utils.serialize`.
An alternate format is 'zmq_op', used by https://github.com/tensorpack/zmq_ops. An alternate format is 'zmq_ops', used by https://github.com/tensorpack/zmq_ops.
bind (bool): whether to bind or connect to the endpoint. bind (bool): whether to bind or connect to the endpoint.
""" """
assert format in [None, 'zmq_op'] assert format in [None, 'zmq_op', 'zmq_ops']
if format is None: if format is None:
dump_fn = dumps dump_fn = dumps
else: else:
...@@ -52,7 +52,8 @@ def send_dataflow_zmq(df, addr, hwm=50, format=None, bind=False): ...@@ -52,7 +52,8 @@ def send_dataflow_zmq(df, addr, hwm=50, format=None, bind=False):
socket.connect(addr) socket.connect(addr)
try: try:
df.reset_state() df.reset_state()
logger.info("Serving data to {} ...".format(addr)) logger.info("Serving data to {} with {} format ...".format(
addr, 'default' if format is None else 'zmq_ops'))
INTERVAL = 200 INTERVAL = 200
q = deque(maxlen=INTERVAL) q = deque(maxlen=INTERVAL)
...@@ -60,7 +61,7 @@ def send_dataflow_zmq(df, addr, hwm=50, format=None, bind=False): ...@@ -60,7 +61,7 @@ def send_dataflow_zmq(df, addr, hwm=50, format=None, bind=False):
total = df.size() total = df.size()
except NotImplementedError: except NotImplementedError:
total = 0 total = 0
tqdm_args = get_tqdm_kwargs(leave=True) tqdm_args = get_tqdm_kwargs(leave=True, smoothing=0.8)
tqdm_args['bar_format'] = tqdm_args['bar_format'] + "{postfix}" tqdm_args['bar_format'] = tqdm_args['bar_format'] + "{postfix}"
while True: while True:
with tqdm.trange(total, **tqdm_args) as pbar: with tqdm.trange(total, **tqdm_args) as pbar:
...@@ -87,24 +88,32 @@ class RemoteDataZMQ(DataFlow): ...@@ -87,24 +88,32 @@ class RemoteDataZMQ(DataFlow):
Attributes: Attributes:
cnt1, cnt2 (int): number of data points received from addr1 and addr2 cnt1, cnt2 (int): number of data points received from addr1 and addr2
""" """
def __init__(self, addr1, addr2=None, hwm=50): def __init__(self, addr1, addr2=None, hwm=50, bind=True):
""" """
Args: Args:
addr1,addr2 (str): addr of the socket to connect to. addr1,addr2 (str): addr of the zmq endpoint to connect to.
Use both if you need two protocols (e.g. both IPC and TCP). Use both if you need two protocols (e.g. both IPC and TCP).
I don't think you'll ever need 3. I don't think you'll ever need 3.
hwm (int): ZMQ high-water mark (buffer size) hwm (int): ZMQ high-water mark (buffer size)
bind (bool): whether to connect or bind the endpoint
""" """
assert addr1 assert addr1
self._addr1 = addr1 self._addr1 = addr1
self._addr2 = addr2 self._addr2 = addr2
self._hwm = int(hwm) self._hwm = int(hwm)
self._guard = DataFlowReentrantGuard() self._guard = DataFlowReentrantGuard()
self._bind = bind
def reset_state(self): def reset_state(self):
self.cnt1 = 0 self.cnt1 = 0
self.cnt2 = 0 self.cnt2 = 0
def bind_or_connect(self, socket, addr):
if self._bind:
socket.bind(addr)
else:
socket.connect(addr)
def get_data(self): def get_data(self):
with self._guard: with self._guard:
try: try:
...@@ -112,7 +121,7 @@ class RemoteDataZMQ(DataFlow): ...@@ -112,7 +121,7 @@ class RemoteDataZMQ(DataFlow):
if self._addr2 is None: if self._addr2 is None:
socket = ctx.socket(zmq.PULL) socket = ctx.socket(zmq.PULL)
socket.set_hwm(self._hwm) socket.set_hwm(self._hwm)
socket.bind(self._addr1) self.bind_or_connect(socket, self._addr1)
while True: while True:
dp = loads(socket.recv(copy=False).bytes) dp = loads(socket.recv(copy=False).bytes)
...@@ -121,11 +130,11 @@ class RemoteDataZMQ(DataFlow): ...@@ -121,11 +130,11 @@ class RemoteDataZMQ(DataFlow):
else: else:
socket1 = ctx.socket(zmq.PULL) socket1 = ctx.socket(zmq.PULL)
socket1.set_hwm(self._hwm) socket1.set_hwm(self._hwm)
socket1.bind(self._addr1) self.bind_or_connect(socket1, self._addr1)
socket2 = ctx.socket(zmq.PULL) socket2 = ctx.socket(zmq.PULL)
socket2.set_hwm(self._hwm) socket2.set_hwm(self._hwm)
socket2.bind(self._addr2) self.bind_or_connect(socket2, self._addr2)
poller = zmq.Poller() poller = zmq.Poller()
poller.register(socket1, zmq.POLLIN) poller.register(socket1, zmq.POLLIN)
......
...@@ -375,7 +375,7 @@ class ZMQInput(TensorInput): ...@@ -375,7 +375,7 @@ class ZMQInput(TensorInput):
Recv tensors from a ZMQ endpoint, with ops from https://github.com/tensorpack/zmq_ops. Recv tensors from a ZMQ endpoint, with ops from https://github.com/tensorpack/zmq_ops.
It works with :meth:`dataflow.remote.send_dataflow_zmq(format='zmq_op')`. It works with :meth:`dataflow.remote.send_dataflow_zmq(format='zmq_op')`.
""" """
def __init__(self, end_point, hwm): def __init__(self, end_point, hwm, bind=True):
""" """
Args: Args:
end_point (str): end_point (str):
...@@ -383,6 +383,7 @@ class ZMQInput(TensorInput): ...@@ -383,6 +383,7 @@ class ZMQInput(TensorInput):
""" """
self._end_point = end_point self._end_point = end_point
self._hwm = int(hwm) self._hwm = int(hwm)
self._bind = bind
def fn(): def fn():
ret = self._zmq_pull_socket.pull() ret = self._zmq_pull_socket.pull()
...@@ -401,7 +402,8 @@ class ZMQInput(TensorInput): ...@@ -401,7 +402,8 @@ class ZMQInput(TensorInput):
self._zmq_pull_socket = zmq_ops.ZMQPullSocket( self._zmq_pull_socket = zmq_ops.ZMQPullSocket(
self._end_point, self._end_point,
[x.type for x in inputs_desc], [x.type for x in inputs_desc],
self._hwm) hwm=self._hwm,
bind=self._bind)
class TFDatasetInput(FeedfreeInput): class TFDatasetInput(FeedfreeInput):
......
...@@ -280,10 +280,16 @@ class HorovodTrainer(SingleCostTrainer): ...@@ -280,10 +280,16 @@ class HorovodTrainer(SingleCostTrainer):
--output-filename mylog -x LD_LIBRARY_PATH -x CUDA_VISIBLE_DEVICES=0,1,2,3 \ --output-filename mylog -x LD_LIBRARY_PATH -x CUDA_VISIBLE_DEVICES=0,1,2,3 \
python train.py python train.py
(Add other environment variables you need by -x, e.g. PYTHONPATH, PATH)
Note: Note:
1. If using all GPUs, you can always skip the `CUDA_VISIBLE_DEVICES` option. 1. Gradients are averaged among all processes.
2. If using all GPUs, you can always skip the `CUDA_VISIBLE_DEVICES` option.
3. Due to the use of MPI, training is less informative (no progress bar).
2. Due to the use of MPI, training is less informative (no progress bar). 4. MPI often fails to kill all processes. Be sure to check it.
""" """
def __init__(self): def __init__(self):
hvd.init() hvd.init()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment