Commit a9864bf0 authored by Yuxin Wu's avatar Yuxin Wu

small changes; fix travis

parent c4897600
......@@ -26,10 +26,10 @@ matrix:
env: TF_VERSION=1.3.0 TF_TYPE=release
- os: linux
python: 2.7
env: TF_VERSION=1.10.0 TF_TYPE=release
env: TF_VERSION=1.12.0 TF_TYPE=release
- os: linux
python: 3.6
env: TF_VERSION=1.10.0 TF_TYPE=release PYPI=true
env: TF_VERSION=1.12.0 TF_TYPE=release PYPI=true
- os: linux
python: 2.7
env: TF_TYPE=nightly
......@@ -43,7 +43,7 @@ install:
- pip install -U pip # the pip version on travis is too old
- pip install .
- pip install flake8 scikit-image opencv-python lmdb h5py msgpack
# check that dataflow can be imported alone
# check that dataflow can be imported alone without tensorflow
- python -c "import tensorpack.dataflow"
- ./tests/install-tensorflow.sh
......
......@@ -222,9 +222,9 @@ def fastrcnn_predictions(boxes, scores):
return mask
# TF bug in version 1.11, 1.12: https://github.com/tensorflow/tensorflow/issues/22750
parallel = 1 if (get_tf_version_tuple() in [(1, 11), (1, 12)]) else 10
buggy_tf = get_tf_version_tuple() in [(1, 11), (1, 12)]
masks = tf.map_fn(f, (scores, boxes), dtype=tf.bool,
parallel_iterations=parallel) # #cat x N
parallel_iterations=1 if buggy_tf else 10) # #cat x N
selected_indices = tf.where(masks) # #selection x 2, each is (cat_id, box_id)
scores = tf.boolean_mask(scores, masks)
......
......@@ -16,6 +16,7 @@ if STATICA_HACK:
from .noise import *
from .paste import *
from .transform import *
from .external import *
import os
......
......@@ -33,12 +33,12 @@ class IAAugmentor(ImageAugmentor):
def _get_augment_params(self, img):
return (self._aug.to_deterministic(), img.shape)
def _augment(self, img, p):
aug, _ = p
def _augment(self, img, param):
aug, _ = param
return aug.augment_image(img)
def _augment_coords(self, coords, p):
aug, shape = p
def _augment_coords(self, coords, param):
aug, shape = param
points = [IA.Keypoint(x=x, y=y) for x, y in coords]
points = IA.KeypointsOnImage(points, shape=shape)
augmented = aug.augment_keypoints([points])[0].keypoints
......
......@@ -6,6 +6,7 @@ os.environ['OPENCV_OPENCL_RUNTIME'] = 'disabled' # https://github.com/opencv
try:
# issue#1924 may happen on old systems
import cv2 # noqa
# cv2.setNumThreads(0)
if int(cv2.__version__.split('.')[0]) == 3:
cv2.ocl.setUseOpenCL(False)
# check if cv is built with cuda or openmp
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment