Commit 7a19c73f authored by Yuxin Wu

fix FasterRCNN evaluation

parent eb97cf1e
@@ -54,16 +54,15 @@ Speed:
 1. If CuDNN warmup is on, the training will start very slowly, until about
 10k steps (or more if scale augmentation is used) to reach a maximum speed.
 As a result, the ETA is also inaccurate at the beginning.
-Warmup is by default on when no scale augmentation is used.
+CuDNN warmup is by default on when no scale augmentation is used.
 1. After warmup, the training speed will slowly decrease due to more accurate proposals.
 1. The code should have around 70% GPU utilization on V100s, and 85%~90% scaling
 efficiency from 1 V100 to 8 V100s.
-1. This implementation does not contain specialized CUDA ops (e.g. AffineChannel, ROIAlign),
-so it can be slightly (~10%) slower than Detectron (Caffe2) and
-maskrcnn-benchmark (PyTorch).
+1. This implementation does not use specialized CUDA ops (e.g. AffineChannel, ROIAlign).
+Therefore it might be slower than other highly-optimized implementations.
 Possible Future Enhancements:
@@ -83,11 +82,6 @@ However, each version of TensorFlow has bugs that I either reported or fixed,
 and this implementation touches many of those bugs.
 Therefore, not every version of TF ≥ 1.6 supports every feature in this implementation.
-This implementation contains workaround for some of those TF bugs.
-However, note that the workaround needs to check your TF version by `tf.VERSION`
-and may not detect bugs properly if your TF version is not an official release
-(e.g., if you use a nighly build).
 1. TF < 1.6: Nothing works due to lack of support for empty tensors
 ([PR](https://github.com/tensorflow/tensorflow/pull/15264))
 and `FrozenBN` training
@@ -98,3 +92,8 @@ and may not detect bugs properly if your TF version is not an official release
 1. TF > 1.12: MKL inference will fail ([issue](https://github.com/tensorflow/tensorflow/issues/24650)).
 1. TF > 1.12: Horovod training will fail ([issue](https://github.com/tensorflow/tensorflow/issues/25946)).
 Latest tensorpack will apply a workaround.
+This implementation contains workaround for some of these TF bugs.
+However, note that the workaround needs to check your TF version by `tf.VERSION`,
+and may not detect bugs properly if your TF version is not an official release
+(e.g., if you use a nightly build).
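To illustrate why a check based on `tf.VERSION` can miss bugs on non-release builds, here is a minimal sketch of version-string parsing. The helper `parse_release_tuple` and the sample strings are hypothetical and are not part of tensorpack or TensorFlow. The point is only that a build with a development suffix reports a version number that does not say which fixes it actually contains.

```python
import re


def parse_release_tuple(version):
    """Split a version string into a comparable (major, minor) tuple.

    Also returns a flag that is True only when nothing follows the
    numeric part, i.e. the string looks like an official release.
    Hypothetical helper for illustration only.
    """
    match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?(.*)$", version)
    if match is None:
        raise ValueError("unrecognized version string: %r" % version)
    release = (int(match.group(1)), int(match.group(2)))
    looks_official = match.group(4) == ""
    return release, looks_official


# Illustrative strings; the second mimics a development build.
for version in ["1.12.0", "1.13.0-dev20190101"]:
    release, looks_official = parse_release_tuple(version)
    # A version-gated workaround can only compare the numeric part.
    gated_on = release > (1, 12)
    print(version, "workaround enabled:", gated_on,
          "| official release:", looks_official)
```

A workaround gated on the numeric tuple treats both strings by number alone, so a development build that predates or postdates a fix may be classified incorrectly.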
@@ -101,7 +101,7 @@ def predict_image(img, model_func):
         # fill with none
         masks = [None] * len(boxes)
-    results = [DetectionResult(*args) for args in zip(boxes, probs, labels, masks)]
+    results = [DetectionResult(*args) for args in zip(boxes, probs, labels.tolist(), masks)]
     return results
@@ -129,7 +129,7 @@ def predict_dataflow(df, model_func, tqdm_bar=None):
             for r in results:
                 res = {
                     'image_id': img_id,
-                    'category_id': r.class_id,
+                    'category_id': int(r.class_id),  # int() to make it json-serializable
                     'bbox': list(r.box),
                     'score': round(float(r.score), 4),
                 }
......
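The two changes above both convert NumPy integer scalars into built-in `int` values before the results are written out. A minimal sketch of the underlying issue, assuming `labels` is a NumPy integer array (the sample values here are made up):

```python
import json

import numpy as np

# Made-up class labels standing in for the model's output array.
labels = np.array([3, 17], dtype=np.int64)

# Indexing a NumPy array yields np.int64, which json.dumps rejects.
try:
    json.dumps({'category_id': labels[0]})
    numpy_scalar_ok = True
except TypeError:
    numpy_scalar_ok = False

# Both conversions used in the diff produce built-in Python ints.
from_tolist = labels.tolist()   # whole array at once
from_int = int(labels[0])       # one element at a time
encoded = json.dumps({'category_id': from_int})
print(numpy_scalar_ok, from_tolist, encoded)
```

Calling `.tolist()` once on the array converts every element, while the `int()` call at serialization time guards against any NumPy scalar that still reaches that point.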
@@ -384,13 +384,16 @@ class ImageNetModel(ModelDesc):
     @staticmethod
     def compute_loss_and_error(logits, label, label_smoothing=0.):
-        if label_smoothing == 0.:
+        if label_smoothing != 0.:
+            nclass = logits.shape[-1]
+            label = tf.one_hot(label, nclass) if label.shape.ndims == 1 else label
+        if label.shape.ndims == 1:
             loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=label)
         else:
-            nclass = logits.shape[-1]
             loss = tf.losses.softmax_cross_entropy(
-                tf.one_hot(label, nclass),
-                logits, label_smoothing=label_smoothing, reduction=tf.losses.Reduction.NONE)
+                label, logits, label_smoothing=label_smoothing,
+                reduction=tf.losses.Reduction.NONE)
         loss = tf.reduce_mean(loss, name='xentropy-loss')
         def prediction_incorrect(logits, label, topk=1, name='incorrect_vector'):
......
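The smoothed branch of `compute_loss_and_error` can be illustrated without TensorFlow. The sketch below computes softmax cross-entropy against a one-hot target smoothed as `onehot * (1 - label_smoothing) + label_smoothing / nclass`, which is the formula documented for the `label_smoothing` argument of `tf.losses.softmax_cross_entropy`. The helper names and sample logits are hypothetical, and this is a plain-Python approximation for a single example, not the code in this commit.

```python
import math


def smooth_one_hot(label, nclass, smoothing):
    """One-hot vector for `label`, smoothed toward the uniform distribution."""
    return [(1.0 - smoothing) * (1.0 if i == label else 0.0) + smoothing / nclass
            for i in range(nclass)]


def softmax_cross_entropy(logits, target):
    """Cross-entropy between softmax(logits) and a dense target distribution."""
    peak = max(logits)  # subtract the max for numerical stability
    log_z = peak + math.log(sum(math.exp(x - peak) for x in logits))
    return -sum(t * (x - log_z) for t, x in zip(target, logits))


logits = [2.0, 0.5, -1.0, 0.0]  # made-up scores for 4 classes
hard = softmax_cross_entropy(logits, smooth_one_hot(0, 4, 0.0))
soft = softmax_cross_entropy(logits, smooth_one_hot(0, 4, 0.1))
print("no smoothing:", hard, "| smoothing 0.1:", soft)
```

With `smoothing=0.0` the target is an ordinary one-hot vector, so the result matches the sparse cross-entropy for that label. This is why the sparse op can be used whenever the label stays one-dimensional.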