Commit 8ab0d4a6 authored by Yuxin Wu's avatar Yuxin Wu

update docs

parent d7a13cb7
@@ -369,11 +369,12 @@ def get_train_dataflow():
         return ret
     if cfg.DATA.NUM_WORKERS > 0:
-        buffer_size = cfg.DATA.NUM_WORKERS * 20
         if cfg.TRAINER == 'horovod':
+            buffer_size = cfg.DATA.NUM_WORKERS * 10  # one dataflow for each process, therefore don't need large buffer
             ds = MultiThreadMapData(ds, cfg.DATA.NUM_WORKERS, preprocess, buffer_size=buffer_size)
             # MPI does not like fork()
         else:
+            buffer_size = cfg.DATA.NUM_WORKERS * 20
             ds = MultiProcessMapDataZMQ(ds, cfg.DATA.NUM_WORKERS, preprocess, buffer_size=buffer_size)
     else:
         ds = MapData(ds, preprocess)
...
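The hunk above moves `buffer_size` into each branch: under horovod every process runs its own dataflow, so a smaller per-process buffer suffices. To illustrate what a `buffer_size` means for a threaded parallel map, here is a minimal sketch (not tensorpack's implementation; `threaded_map` is a hypothetical helper): the bounded queues cap how many in-flight and mapped items are held in memory at once, and output order is not preserved.

```python
import queue
import threading

def threaded_map(iterable, map_func, num_workers, buffer_size):
    """Map `map_func` over `iterable` with `num_workers` threads.

    At most `buffer_size` items wait in each queue, so memory use is
    bounded.  Results arrive in arbitrary order, as with the parallel
    map dataflows discussed above.
    """
    in_q = queue.Queue(maxsize=buffer_size)
    out_q = queue.Queue(maxsize=buffer_size)
    _SENTINEL = object()

    def worker():
        while True:
            item = in_q.get()
            if item is _SENTINEL:
                out_q.put(_SENTINEL)  # tell the consumer this worker is done
                return
            out_q.put(map_func(item))

    threads = [threading.Thread(target=worker, daemon=True) for _ in range(num_workers)]
    for t in threads:
        t.start()

    def feeder():
        for item in iterable:
            in_q.put(item)
        for _ in threads:           # one sentinel per worker
            in_q.put(_SENTINEL)

    threading.Thread(target=feeder, daemon=True).start()

    finished = 0
    while finished < num_workers:
        item = out_q.get()
        if item is _SENTINEL:
            finished += 1
        else:
            yield item

results = sorted(threaded_map(range(10), lambda x: x * x, num_workers=4, buffer_size=8))
print(results)  # [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
```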
@@ -144,7 +144,8 @@ class MultiProcessPrefetchData(ProxyDataFlow):
            `Birthday Paradox <https://en.wikipedia.org/wiki/Birthday_problem>`_
            and know that you'll likely see duplicates.
-           To utilize parallelism with stricter data integrity, you can use the parallel versions of `MapData`.
+           To utilize parallelism with more strict data integrity, you can use
+           the parallel versions of :class:`MapData`: :class:`MultiThreadMapData`, :class:`MultiProcessMapData`.
        2. This has more serialization overhead than :class:`PrefetchDataZMQ` when data is large.
        3. You can nest like this: ``PrefetchDataZMQ(PrefetchData(df, nr_proc=a), nr_proc=b)``.
           A total of ``a`` instances of ``df`` worker processes will be created.
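The docstring's warning about duplicates comes from each worker running its own complete copy of the dataflow. A toy model of that behavior (not tensorpack's implementation; `parallel_prefetch` and `dataflow` are hypothetical names) makes the effect visible: with two workers, every datapoint shows up twice in the merged stream.

```python
import itertools

def dataflow():
    """A tiny dataflow producing 5 datapoints."""
    yield from range(5)

def parallel_prefetch(df_factory, nr_proc):
    """Each worker iterates its own full copy of the dataflow; the parent
    simply interleaves whatever arrives (here: round-robin, for determinism)."""
    iterators = [df_factory() for _ in range(nr_proc)]
    return itertools.chain.from_iterable(zip(*iterators))

stream = list(itertools.islice(parallel_prefetch(dataflow, nr_proc=2), 10))
print(stream)  # [0, 0, 1, 1, 2, 2, 3, 3, 4, 4] -- every datapoint appears twice
```

The parallel `MapData` variants avoid this because the *source* dataflow is iterated once in the parent, and only the mapping function runs in parallel.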
@@ -241,7 +242,8 @@ class PrefetchDataZMQ(_MultiProcessZMQDataFlow):
            `Birthday Paradox <https://en.wikipedia.org/wiki/Birthday_problem>`_
            and know that you'll likely see duplicates.
-           To utilize parallelism with stricter data integrity, you can use the parallel versions of `MapData`.
+           To utilize parallelism with more strict data integrity, you can use
+           the parallel versions of :class:`MapData`: :class:`MultiThreadMapData`, :class:`MultiProcessMapData`.
        2. `reset_state()` of the given dataflow will be called **once and only once** in the worker processes.
        3. The fork of processes happened in this dataflow's `reset_state()` method.
           Please note that forking a TensorFlow GPU session may be unsafe.
@@ -365,7 +367,8 @@ class MultiThreadPrefetchData(DataFlow):
            `Birthday Paradox <https://en.wikipedia.org/wiki/Birthday_problem>`_
            and know that you'll likely see duplicates.
-           To utilize parallelism with stricter data integrity, you can use the parallel versions of `MapData`.
+           To utilize parallelism with more strict data integrity, you can use
+           the parallel versions of :class:`MapData`: :class:`MultiThreadMapData`, :class:`MultiProcessMapData`.
    """
    class _Worker(StoppableThread):
...
@@ -94,6 +94,9 @@ class TrainConfig(object):
        starting_epoch (int): The index of the first epoch.
        steps_per_epoch (int): the number of steps (defined by :meth:`Trainer.run_step`) to run in each epoch.
            Defaults to the input data size. You may want to divide it by the #GPUs in multi-GPU training.
+            Number of steps per epoch only affects the schedule of callbacks.
+            It does not affect the sequence of input data seen by the model.
        max_epoch (int): maximum number of epoch to run training.
    """
...
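The added docstring note says that `steps_per_epoch` only decides when per-epoch callbacks fire, not what data the model sees. A toy training loop (a sketch, not tensorpack's `Trainer`; `train` and `on_epoch_end` are hypothetical names) shows why: the data iterator runs continuously across epoch boundaries and is never reset.

```python
import itertools

def train(data_iter, steps_per_epoch, max_epoch, on_epoch_end):
    """Consume `steps_per_epoch` batches per epoch from one continuous
    iterator; the epoch boundary is only a callback trigger point."""
    seen = []
    for epoch in range(1, max_epoch + 1):
        for _ in range(steps_per_epoch):
            seen.append(next(data_iter))  # one "run_step": consume one batch
        on_epoch_end(epoch)               # callbacks are scheduled per epoch
    return seen

data = itertools.count()                  # an endless stream of "batches"
epochs_done = []
seen = train(data, steps_per_epoch=3, max_epoch=2, on_epoch_end=epochs_done.append)
print(seen)         # [0, 1, 2, 3, 4, 5] -- data order is independent of steps_per_epoch
print(epochs_done)  # [1, 2]
```

Changing `steps_per_epoch` to 2 with `max_epoch=3` would yield the same `seen` sequence; only the callback schedule shifts, which is exactly the point of the doc change.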