Commit 04c81965 authored by Yuxin Wu's avatar Yuxin Wu

update docs

parent 0226451c
...@@ -449,13 +449,16 @@ class TFDatasetInput(FeedfreeInput): ...@@ -449,13 +449,16 @@ class TFDatasetInput(FeedfreeInput):
Use a :class:`tf.data.Dataset` instance as input. Use a :class:`tf.data.Dataset` instance as input.
Note: Note:
In training, the dataset should be infinite (use :func:`repeat()`). 1. In training, the given dataset or dataflow has to be infinite
(you can use :func:`repeat()`, or :class:`RepeatedData` ).
2. TensorFlow may keep the dataflow alive even if the dataset is no
longer used.
""" """
def __init__(self, dataset): def __init__(self, dataset):
""" """
Args: Args:
dataset (tf.data.Dataset or DataFlow): if a DataFlow, the dataflow dataset (tf.data.Dataset or DataFlow):
has to be infinite.
""" """
if isinstance(dataset, tf.data.Dataset): if isinstance(dataset, tf.data.Dataset):
self._dataset = dataset self._dataset = dataset
...@@ -519,6 +522,10 @@ class TFDatasetInput(FeedfreeInput): ...@@ -519,6 +522,10 @@ class TFDatasetInput(FeedfreeInput):
Returns: Returns:
(tf.data.Dataset) (tf.data.Dataset)
Note:
TensorFlow may keep the dataflow alive even if the dataset is no
longer used.
""" """
# TODO theoretically it can support dict # TODO theoretically it can support dict
assert isinstance(df, DataFlow), df assert isinstance(df, DataFlow), df
......
...@@ -322,7 +322,7 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5, ...@@ -322,7 +322,7 @@ def BatchNorm(inputs, axis=None, training=None, momentum=0.9, epsilon=1e-5,
logger.warn("BatchNorm(sync_statistics='horovod') is used with only one process!") logger.warn("BatchNorm(sync_statistics='horovod') is used with only one process!")
else: else:
import horovod import horovod
hvd_version = tuple(map(int, horovod.__version__.split('.'))) hvd_version = tuple(map(int, horovod.__version__.split('.')[:3]))
assert hvd_version >= (0, 13, 6), "sync_statistics=horovod needs horovod>=0.13.6 !" assert hvd_version >= (0, 13, 6), "sync_statistics=horovod needs horovod>=0.13.6 !"
batch_mean = hvd.allreduce(batch_mean, average=True) batch_mean = hvd.allreduce(batch_mean, average=True)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment