Commit c2d99a44 authored by Yuxin Wu's avatar Yuxin Wu

update

parent f128a5c6
...@@ -18,7 +18,8 @@ class Hue(PhotometricAugmentor): ...@@ -18,7 +18,8 @@ class Hue(PhotometricAugmentor):
def __init__(self, range=(0, 180), rgb=True): def __init__(self, range=(0, 180), rgb=True):
""" """
Args: Args:
range(list or tuple): range from which the applied hue offset is selected (maximum [-90,90] or [0,180]) range(list or tuple): range from which the applied hue offset is selected
(maximum range can be [-90,90] for both uint8 and float32)
rgb (bool): whether input is RGB or BGR. rgb (bool): whether input is RGB or BGR.
""" """
super(Hue, self).__init__() super(Hue, self).__init__()
......
...@@ -7,7 +7,7 @@ import six ...@@ -7,7 +7,7 @@ import six
import tensorflow as tf import tensorflow as tf
from ..input_source import PlaceholderInput from ..input_source import PlaceholderInput
from ..tfutils.common import get_tensors_by_names from ..tfutils.common import get_tensors_by_names, get_op_tensor_name
from ..tfutils.tower import PredictTowerContext from ..tfutils.tower import PredictTowerContext
__all__ = ['PredictorBase', __all__ = ['PredictorBase',
...@@ -34,6 +34,9 @@ class PredictorBase(object): ...@@ -34,6 +34,9 @@ class PredictorBase(object):
.. code-block:: python .. code-block:: python
predictor(e1, e2) predictor(e1, e2)
Returns:
list[array]: list of outputs
""" """
output = self._do_call(dp) output = self._do_call(dp)
if self.return_input: if self.return_input:
...@@ -98,9 +101,14 @@ class OnlinePredictor(PredictorBase): ...@@ -98,9 +101,14 @@ class OnlinePredictor(PredictorBase):
will use the default session at the first call. will use the default session at the first call.
Note that in TensorFlow, default session is thread-local. Note that in TensorFlow, default session is thread-local.
""" """
def normalize_name(t):
if isinstance(t, six.string_types):
return get_op_tensor_name(t)[1]
return t
self.return_input = return_input self.return_input = return_input
self.input_tensors = input_tensors self.input_tensors = [normalize_name(x) for x in input_tensors]
self.output_tensors = output_tensors self.output_tensors = [normalize_name(x) for x in output_tensors]
self.sess = sess self.sess = sess
if sess is not None: if sess is not None:
......
...@@ -396,8 +396,7 @@ class HorovodTrainer(SingleCostTrainer): ...@@ -396,8 +396,7 @@ class HorovodTrainer(SingleCostTrainer):
compression: `hvd.Compression.fp16` or `hvd.Compression.none` compression: `hvd.Compression.fp16` or `hvd.Compression.none`
""" """
if 'pyarrow' in sys.modules: if 'pyarrow' in sys.modules:
logger.warn("Horovod and pyarrow may conflict due to pyarrow bugs. " logger.warn("Horovod and pyarrow may conflict due to pyarrow bugs.")
"Uninstall pyarrow and use msgpack instead.")
# lazy import # lazy import
import horovod.tensorflow as hvd import horovod.tensorflow as hvd
import horovod import horovod
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment