Commit 23239bd7 authored by Yuxin Wu's avatar Yuxin Wu

update docs

parent d8e2929b
...@@ -462,9 +462,19 @@ class HorovodTrainer(SingleCostTrainer): ...@@ -462,9 +462,19 @@ class HorovodTrainer(SingleCostTrainer):
class BytePSTrainer(HorovodTrainer): class BytePSTrainer(HorovodTrainer):
""" """
BytePS trainer. Supports both multi-GPU and distributed training. BytePS trainer. Supports both multi-GPU and distributed training.
It achieves better scalability than horovod in distributed training, if the model is communication-intensive
and you have properly set up the machines following its
`best practices <https://github.com/bytedance/byteps/blob/master/docs/best-practice.md>`_,
which require a few more bandwidth servers than horovod.
To use it, switch the trainer, and fefer to BytePS documentation on how to To use it, switch the trainer, and refer to BytePS documentation on how to
launch server/scheduler/workers. launch server/scheduler/workers.
Attributes:
hvd (module): the byteps module that contains horovod-compatible APIs
like `rank()` and `size()`.
This attribute exists so that downstream code that uses these APIs
does not need to worry about which library is being used under the hood.
""" """
def __init__(self, average=True): def __init__(self, average=True):
""" """
...@@ -474,7 +484,8 @@ class BytePSTrainer(HorovodTrainer): ...@@ -474,7 +484,8 @@ class BytePSTrainer(HorovodTrainer):
import byteps.tensorflow as bps import byteps.tensorflow as bps
self.hvd = bps # BytePS has the same interface as Horovod self.hvd = bps # BytePS has the same interface as Horovod
self.hvd.allreduce = bps.push_pull # https://github.com/bytedance/byteps/issues/8 self.hvd.allreduce = bps.push_pull # https://github.com/bytedance/byteps/issues/8
# TODO bootstrap env vars assert os.environ.get("DMLC_ROLE", None) == "worker"
assert "DMLC_WORKER_ID" in os.environ and "DMLC_NUM_WORKER" in os.environ
bps.init() bps.init()
self.is_chief = bps.rank() == 0 self.is_chief = bps.rank() == 0
......
...@@ -56,7 +56,7 @@ def _getlogger(): ...@@ -56,7 +56,7 @@ def _getlogger():
_logger = _getlogger() _logger = _getlogger()
_LOGGING_METHOD = ['info', 'warning', 'error', 'critical', 'exception', 'debug', 'setLevel'] _LOGGING_METHOD = ['info', 'warning', 'error', 'critical', 'exception', 'debug', 'setLevel', 'addFilter']
# export logger functions # export logger functions
for func in _LOGGING_METHOD: for func in _LOGGING_METHOD:
locals()[func] = getattr(_logger, func) locals()[func] = getattr(_logger, func)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment