Commit 02f90879 authored by Alexander Korotkov's avatar Alexander Korotkov

Fix handling of NULL distances in KNN-GiST

In order to implement NULL LAST semantic GiST previously assumed distance to
the NULL value to be Inf.  However, our distance functions can return Inf and
NaN for non-null values.  In such cases, NULL LAST semantic appears to be
broken.  This commit fixes that by introducing separate array of null flags for
distances.

Backpatch to all supported versions.

Discussion: https://postgr.es/m/CAPpHfdsNvNdA0DBS%2BwMpFrgwT6C3-q50sFVGLSiuWnV3FqOJuQ%40mail.gmail.com
Author: Alexander Korotkov
Backpatch-through: 9.4
parent e5d8f359
...@@ -112,8 +112,9 @@ gistkillitems(IndexScanDesc scan) ...@@ -112,8 +112,9 @@ gistkillitems(IndexScanDesc scan)
* Similarly, *recheck_distances_p is set to indicate whether the distances * Similarly, *recheck_distances_p is set to indicate whether the distances
* need to be rechecked, and it is also ignored for non-leaf entries. * need to be rechecked, and it is also ignored for non-leaf entries.
* *
* If we are doing an ordered scan, so->distances[] is filled with distance * If we are doing an ordered scan, so->distancesValues[] and
* data from the distance() functions before returning success. * so->distancesNulls[] is filled with distance data from the distance()
* functions before returning success.
* *
* We must decompress the key in the IndexTuple before passing it to the * We must decompress the key in the IndexTuple before passing it to the
* sk_funcs (which actually are the opclass Consistent or Distance methods). * sk_funcs (which actually are the opclass Consistent or Distance methods).
...@@ -134,7 +135,8 @@ gistindex_keytest(IndexScanDesc scan, ...@@ -134,7 +135,8 @@ gistindex_keytest(IndexScanDesc scan,
GISTSTATE *giststate = so->giststate; GISTSTATE *giststate = so->giststate;
ScanKey key = scan->keyData; ScanKey key = scan->keyData;
int keySize = scan->numberOfKeys; int keySize = scan->numberOfKeys;
double *distance_p; double *distance_value_p;
bool *distance_null_p;
Relation r = scan->indexRelation; Relation r = scan->indexRelation;
*recheck_p = false; *recheck_p = false;
...@@ -152,7 +154,10 @@ gistindex_keytest(IndexScanDesc scan, ...@@ -152,7 +154,10 @@ gistindex_keytest(IndexScanDesc scan,
if (GistPageIsLeaf(page)) /* shouldn't happen */ if (GistPageIsLeaf(page)) /* shouldn't happen */
elog(ERROR, "invalid GiST tuple found on leaf page"); elog(ERROR, "invalid GiST tuple found on leaf page");
for (i = 0; i < scan->numberOfOrderBys; i++) for (i = 0; i < scan->numberOfOrderBys; i++)
so->distances[i] = -get_float8_infinity(); {
so->distanceValues[i] = -get_float8_infinity();
so->distanceNulls[i] = false;
}
return true; return true;
} }
...@@ -235,7 +240,8 @@ gistindex_keytest(IndexScanDesc scan, ...@@ -235,7 +240,8 @@ gistindex_keytest(IndexScanDesc scan,
/* OK, it passes --- now let's compute the distances */ /* OK, it passes --- now let's compute the distances */
key = scan->orderByData; key = scan->orderByData;
distance_p = so->distances; distance_value_p = so->distanceValues;
distance_null_p = so->distanceNulls;
keySize = scan->numberOfOrderBys; keySize = scan->numberOfOrderBys;
while (keySize > 0) while (keySize > 0)
{ {
...@@ -249,8 +255,9 @@ gistindex_keytest(IndexScanDesc scan, ...@@ -249,8 +255,9 @@ gistindex_keytest(IndexScanDesc scan,
if ((key->sk_flags & SK_ISNULL) || isNull) if ((key->sk_flags & SK_ISNULL) || isNull)
{ {
/* Assume distance computes as null and sorts to the end */ /* Assume distance computes as null */
*distance_p = get_float8_infinity(); *distance_value_p = 0.0;
*distance_null_p = true;
} }
else else
{ {
...@@ -287,11 +294,13 @@ gistindex_keytest(IndexScanDesc scan, ...@@ -287,11 +294,13 @@ gistindex_keytest(IndexScanDesc scan,
ObjectIdGetDatum(key->sk_subtype), ObjectIdGetDatum(key->sk_subtype),
PointerGetDatum(&recheck)); PointerGetDatum(&recheck));
*recheck_distances_p |= recheck; *recheck_distances_p |= recheck;
*distance_p = DatumGetFloat8(dist); *distance_value_p = DatumGetFloat8(dist);
*distance_null_p = false;
} }
key++; key++;
distance_p++; distance_value_p++;
distance_null_p++;
keySize--; keySize--;
} }
...@@ -304,7 +313,8 @@ gistindex_keytest(IndexScanDesc scan, ...@@ -304,7 +313,8 @@ gistindex_keytest(IndexScanDesc scan,
* *
* scan: index scan we are executing * scan: index scan we are executing
* pageItem: search queue item identifying an index page to scan * pageItem: search queue item identifying an index page to scan
* myDistances: distances array associated with pageItem, or NULL at the root * myDistanceValues: distances array associated with pageItem, or NULL at the root
* myDistanceNulls: null flags for myDistanceValues array, or NULL at the root
* tbm: if not NULL, gistgetbitmap's output bitmap * tbm: if not NULL, gistgetbitmap's output bitmap
* ntids: if not NULL, gistgetbitmap's output tuple counter * ntids: if not NULL, gistgetbitmap's output tuple counter
* *
...@@ -321,7 +331,8 @@ gistindex_keytest(IndexScanDesc scan, ...@@ -321,7 +331,8 @@ gistindex_keytest(IndexScanDesc scan,
* sibling will be processed next. * sibling will be processed next.
*/ */
static void static void
gistScanPage(IndexScanDesc scan, GISTSearchItem *pageItem, double *myDistances, gistScanPage(IndexScanDesc scan, GISTSearchItem *pageItem,
double *myDistanceValues, bool *myDistanceNulls,
TIDBitmap *tbm, int64 *ntids) TIDBitmap *tbm, int64 *ntids)
{ {
GISTScanOpaque so = (GISTScanOpaque) scan->opaque; GISTScanOpaque so = (GISTScanOpaque) scan->opaque;
...@@ -359,7 +370,7 @@ gistScanPage(IndexScanDesc scan, GISTSearchItem *pageItem, double *myDistances, ...@@ -359,7 +370,7 @@ gistScanPage(IndexScanDesc scan, GISTSearchItem *pageItem, double *myDistances,
GISTSearchItem *item; GISTSearchItem *item;
/* This can't happen when starting at the root */ /* This can't happen when starting at the root */
Assert(myDistances != NULL); Assert(myDistanceValues != NULL && myDistanceNulls != NULL);
oldcxt = MemoryContextSwitchTo(so->queueCxt); oldcxt = MemoryContextSwitchTo(so->queueCxt);
...@@ -369,8 +380,10 @@ gistScanPage(IndexScanDesc scan, GISTSearchItem *pageItem, double *myDistances, ...@@ -369,8 +380,10 @@ gistScanPage(IndexScanDesc scan, GISTSearchItem *pageItem, double *myDistances,
item->data.parentlsn = pageItem->data.parentlsn; item->data.parentlsn = pageItem->data.parentlsn;
/* Insert it into the queue using same distances as for this page */ /* Insert it into the queue using same distances as for this page */
memcpy(item->distances, myDistances, memcpy(GISTSearchItemDistanceValues(item, scan->numberOfOrderBys),
sizeof(double) * scan->numberOfOrderBys); myDistanceValues, sizeof(double) * scan->numberOfOrderBys);
memcpy(GISTSearchItemDistanceNulls(item, scan->numberOfOrderBys),
myDistanceNulls, sizeof(bool) * scan->numberOfOrderBys);
pairingheap_add(so->queue, &item->phNode); pairingheap_add(so->queue, &item->phNode);
...@@ -479,6 +492,7 @@ gistScanPage(IndexScanDesc scan, GISTSearchItem *pageItem, double *myDistances, ...@@ -479,6 +492,7 @@ gistScanPage(IndexScanDesc scan, GISTSearchItem *pageItem, double *myDistances,
* search. * search.
*/ */
GISTSearchItem *item; GISTSearchItem *item;
int nOrderBys = scan->numberOfOrderBys;
oldcxt = MemoryContextSwitchTo(so->queueCxt); oldcxt = MemoryContextSwitchTo(so->queueCxt);
...@@ -513,8 +527,10 @@ gistScanPage(IndexScanDesc scan, GISTSearchItem *pageItem, double *myDistances, ...@@ -513,8 +527,10 @@ gistScanPage(IndexScanDesc scan, GISTSearchItem *pageItem, double *myDistances,
} }
/* Insert it into the queue using new distance data */ /* Insert it into the queue using new distance data */
memcpy(item->distances, so->distances, memcpy(GISTSearchItemDistanceValues(item, nOrderBys),
sizeof(double) * scan->numberOfOrderBys); so->distanceValues, sizeof(double) * nOrderBys);
memcpy(GISTSearchItemDistanceNulls(item, nOrderBys),
so->distanceNulls, sizeof(bool) * nOrderBys);
pairingheap_add(so->queue, &item->phNode); pairingheap_add(so->queue, &item->phNode);
...@@ -579,7 +595,8 @@ getNextNearest(IndexScanDesc scan) ...@@ -579,7 +595,8 @@ getNextNearest(IndexScanDesc scan)
scan->xs_recheck = item->data.heap.recheck; scan->xs_recheck = item->data.heap.recheck;
index_store_float8_orderby_distances(scan, so->orderByTypes, index_store_float8_orderby_distances(scan, so->orderByTypes,
item->distances, GISTSearchItemDistanceValues(item, scan->numberOfOrderBys),
GISTSearchItemDistanceNulls(item, scan->numberOfOrderBys),
item->data.heap.recheckDistances); item->data.heap.recheckDistances);
/* in an index-only scan, also return the reconstructed tuple. */ /* in an index-only scan, also return the reconstructed tuple. */
...@@ -592,7 +609,10 @@ getNextNearest(IndexScanDesc scan) ...@@ -592,7 +609,10 @@ getNextNearest(IndexScanDesc scan)
/* visit an index page, extract its items into queue */ /* visit an index page, extract its items into queue */
CHECK_FOR_INTERRUPTS(); CHECK_FOR_INTERRUPTS();
gistScanPage(scan, item, item->distances, NULL, NULL); gistScanPage(scan, item,
GISTSearchItemDistanceValues(item, scan->numberOfOrderBys),
GISTSearchItemDistanceNulls(item, scan->numberOfOrderBys),
NULL, NULL);
} }
pfree(item); pfree(item);
...@@ -630,7 +650,7 @@ gistgettuple(IndexScanDesc scan, ScanDirection dir) ...@@ -630,7 +650,7 @@ gistgettuple(IndexScanDesc scan, ScanDirection dir)
fakeItem.blkno = GIST_ROOT_BLKNO; fakeItem.blkno = GIST_ROOT_BLKNO;
memset(&fakeItem.data.parentlsn, 0, sizeof(GistNSN)); memset(&fakeItem.data.parentlsn, 0, sizeof(GistNSN));
gistScanPage(scan, &fakeItem, NULL, NULL, NULL); gistScanPage(scan, &fakeItem, NULL, NULL, NULL, NULL);
} }
if (scan->numberOfOrderBys > 0) if (scan->numberOfOrderBys > 0)
...@@ -724,7 +744,10 @@ gistgettuple(IndexScanDesc scan, ScanDirection dir) ...@@ -724,7 +744,10 @@ gistgettuple(IndexScanDesc scan, ScanDirection dir)
* this page, we fall out of the inner "do" and loop around to * this page, we fall out of the inner "do" and loop around to
* return them. * return them.
*/ */
gistScanPage(scan, item, item->distances, NULL, NULL); gistScanPage(scan, item,
GISTSearchItemDistanceValues(item, scan->numberOfOrderBys),
GISTSearchItemDistanceNulls(item, scan->numberOfOrderBys),
NULL, NULL);
pfree(item); pfree(item);
} while (so->nPageData == 0); } while (so->nPageData == 0);
...@@ -755,7 +778,7 @@ gistgetbitmap(IndexScanDesc scan, TIDBitmap *tbm) ...@@ -755,7 +778,7 @@ gistgetbitmap(IndexScanDesc scan, TIDBitmap *tbm)
fakeItem.blkno = GIST_ROOT_BLKNO; fakeItem.blkno = GIST_ROOT_BLKNO;
memset(&fakeItem.data.parentlsn, 0, sizeof(GistNSN)); memset(&fakeItem.data.parentlsn, 0, sizeof(GistNSN));
gistScanPage(scan, &fakeItem, NULL, tbm, &ntids); gistScanPage(scan, &fakeItem, NULL, NULL, tbm, &ntids);
/* /*
* While scanning a leaf page, ItemPointers of matching heap tuples will * While scanning a leaf page, ItemPointers of matching heap tuples will
...@@ -770,7 +793,10 @@ gistgetbitmap(IndexScanDesc scan, TIDBitmap *tbm) ...@@ -770,7 +793,10 @@ gistgetbitmap(IndexScanDesc scan, TIDBitmap *tbm)
CHECK_FOR_INTERRUPTS(); CHECK_FOR_INTERRUPTS();
gistScanPage(scan, item, item->distances, tbm, &ntids); gistScanPage(scan, item,
GISTSearchItemDistanceValues(item, scan->numberOfOrderBys),
GISTSearchItemDistanceNulls(item, scan->numberOfOrderBys),
tbm, &ntids);
pfree(item); pfree(item);
} }
......
...@@ -33,14 +33,30 @@ pairingheap_GISTSearchItem_cmp(const pairingheap_node *a, const pairingheap_node ...@@ -33,14 +33,30 @@ pairingheap_GISTSearchItem_cmp(const pairingheap_node *a, const pairingheap_node
const GISTSearchItem *sb = (const GISTSearchItem *) b; const GISTSearchItem *sb = (const GISTSearchItem *) b;
IndexScanDesc scan = (IndexScanDesc) arg; IndexScanDesc scan = (IndexScanDesc) arg;
int i; int i;
double *da = GISTSearchItemDistanceValues(sa, scan->numberOfOrderBys),
*db = GISTSearchItemDistanceValues(sb, scan->numberOfOrderBys);
bool *na = GISTSearchItemDistanceNulls(sa, scan->numberOfOrderBys),
*nb = GISTSearchItemDistanceNulls(sb, scan->numberOfOrderBys);
/* Order according to distance comparison */ /* Order according to distance comparison */
for (i = 0; i < scan->numberOfOrderBys; i++) for (i = 0; i < scan->numberOfOrderBys; i++)
{ {
int cmp = -float8_cmp_internal(sa->distances[i], sb->distances[i]); if (na[i])
{
if (!nb[i])
return -1;
}
else if (nb[i])
{
return 1;
}
else
{
int cmp = -float8_cmp_internal(da[i], db[i]);
if (cmp != 0) if (cmp != 0)
return cmp; return cmp;
}
} }
/* Heap items go before inner pages, to ensure a depth-first search */ /* Heap items go before inner pages, to ensure a depth-first search */
...@@ -84,7 +100,8 @@ gistbeginscan(Relation r, int nkeys, int norderbys) ...@@ -84,7 +100,8 @@ gistbeginscan(Relation r, int nkeys, int norderbys)
so->queueCxt = giststate->scanCxt; /* see gistrescan */ so->queueCxt = giststate->scanCxt; /* see gistrescan */
/* workspaces with size dependent on numberOfOrderBys: */ /* workspaces with size dependent on numberOfOrderBys: */
so->distances = palloc(sizeof(double) * scan->numberOfOrderBys); so->distanceValues = palloc(sizeof(double) * scan->numberOfOrderBys);
so->distanceNulls = palloc(sizeof(bool) * scan->numberOfOrderBys);
so->qual_ok = true; /* in case there are zero keys */ so->qual_ok = true; /* in case there are zero keys */
if (scan->numberOfOrderBys > 0) if (scan->numberOfOrderBys > 0)
{ {
......
...@@ -847,13 +847,14 @@ index_getprocinfo(Relation irel, ...@@ -847,13 +847,14 @@ index_getprocinfo(Relation irel,
*/ */
void void
index_store_float8_orderby_distances(IndexScanDesc scan, Oid *orderByTypes, index_store_float8_orderby_distances(IndexScanDesc scan, Oid *orderByTypes,
double *distances, bool recheckOrderBy) double *distanceValues,
bool *distanceNulls, bool recheckOrderBy)
{ {
int i; int i;
scan->xs_recheckorderby = recheckOrderBy; scan->xs_recheckorderby = recheckOrderBy;
if (!distances) if (!distanceValues)
{ {
Assert(!scan->xs_recheckorderby); Assert(!scan->xs_recheckorderby);
...@@ -868,6 +869,11 @@ index_store_float8_orderby_distances(IndexScanDesc scan, Oid *orderByTypes, ...@@ -868,6 +869,11 @@ index_store_float8_orderby_distances(IndexScanDesc scan, Oid *orderByTypes,
for (i = 0; i < scan->numberOfOrderBys; i++) for (i = 0; i < scan->numberOfOrderBys; i++)
{ {
if (distanceNulls && distanceNulls[i])
{
scan->xs_orderbyvals[i] = (Datum) 0;
scan->xs_orderbynulls[i] = true;
}
if (orderByTypes[i] == FLOAT8OID) if (orderByTypes[i] == FLOAT8OID)
{ {
#ifndef USE_FLOAT8_BYVAL #ifndef USE_FLOAT8_BYVAL
...@@ -875,7 +881,7 @@ index_store_float8_orderby_distances(IndexScanDesc scan, Oid *orderByTypes, ...@@ -875,7 +881,7 @@ index_store_float8_orderby_distances(IndexScanDesc scan, Oid *orderByTypes,
if (!scan->xs_orderbynulls[i]) if (!scan->xs_orderbynulls[i])
pfree(DatumGetPointer(scan->xs_orderbyvals[i])); pfree(DatumGetPointer(scan->xs_orderbyvals[i]));
#endif #endif
scan->xs_orderbyvals[i] = Float8GetDatum(distances[i]); scan->xs_orderbyvals[i] = Float8GetDatum(distanceValues[i]);
scan->xs_orderbynulls[i] = false; scan->xs_orderbynulls[i] = false;
} }
else if (orderByTypes[i] == FLOAT4OID) else if (orderByTypes[i] == FLOAT4OID)
...@@ -886,7 +892,7 @@ index_store_float8_orderby_distances(IndexScanDesc scan, Oid *orderByTypes, ...@@ -886,7 +892,7 @@ index_store_float8_orderby_distances(IndexScanDesc scan, Oid *orderByTypes,
if (!scan->xs_orderbynulls[i]) if (!scan->xs_orderbynulls[i])
pfree(DatumGetPointer(scan->xs_orderbyvals[i])); pfree(DatumGetPointer(scan->xs_orderbyvals[i]));
#endif #endif
scan->xs_orderbyvals[i] = Float4GetDatum((float4) distances[i]); scan->xs_orderbyvals[i] = Float4GetDatum((float4) distanceValues[i]);
scan->xs_orderbynulls[i] = false; scan->xs_orderbynulls[i] = false;
} }
else else
......
...@@ -929,6 +929,7 @@ spggettuple(IndexScanDesc scan, ScanDirection dir) ...@@ -929,6 +929,7 @@ spggettuple(IndexScanDesc scan, ScanDirection dir)
if (so->numberOfOrderBys > 0) if (so->numberOfOrderBys > 0)
index_store_float8_orderby_distances(scan, so->orderByTypes, index_store_float8_orderby_distances(scan, so->orderByTypes,
so->distances[so->iPtr], so->distances[so->iPtr],
NULL,
so->recheckDistances[so->iPtr]); so->recheckDistances[so->iPtr]);
so->iPtr++; so->iPtr++;
return true; return true;
......
...@@ -178,7 +178,9 @@ extern RegProcedure index_getprocid(Relation irel, AttrNumber attnum, ...@@ -178,7 +178,9 @@ extern RegProcedure index_getprocid(Relation irel, AttrNumber attnum,
extern FmgrInfo *index_getprocinfo(Relation irel, AttrNumber attnum, extern FmgrInfo *index_getprocinfo(Relation irel, AttrNumber attnum,
uint16 procnum); uint16 procnum);
extern void index_store_float8_orderby_distances(IndexScanDesc scan, extern void index_store_float8_orderby_distances(IndexScanDesc scan,
Oid *orderByTypes, double *distances, Oid *orderByTypes,
double *distanceValues,
bool *distanceNulls,
bool recheckOrderBy); bool recheckOrderBy);
/* /*
......
...@@ -137,13 +137,30 @@ typedef struct GISTSearchItem ...@@ -137,13 +137,30 @@ typedef struct GISTSearchItem
/* we must store parentlsn to detect whether a split occurred */ /* we must store parentlsn to detect whether a split occurred */
GISTSearchHeapItem heap; /* heap info, if heap tuple */ GISTSearchHeapItem heap; /* heap info, if heap tuple */
} data; } data;
double distances[FLEXIBLE_ARRAY_MEMBER]; /* numberOfOrderBys
* entries */ /*
* This data structure is followed by arrays of distance values and
* distance null flags. Size of both arrays is
* IndexScanDesc->numberOfOrderBys. See macros below for accessing those
* arrays.
*/
} GISTSearchItem; } GISTSearchItem;
#define GISTSearchItemIsHeap(item) ((item).blkno == InvalidBlockNumber) #define GISTSearchItemIsHeap(item) ((item).blkno == InvalidBlockNumber)
#define SizeOfGISTSearchItem(n_distances) (offsetof(GISTSearchItem, distances) + sizeof(double) * (n_distances)) #define SizeOfGISTSearchItem(n_distances) (DOUBLEALIGN(sizeof(GISTSearchItem)) + \
(sizeof(double) + sizeof(bool)) * (n_distances))
/*
* We actually don't need n_distances compute pointer to distance values.
* Nevertheless take n_distances as argument to have same arguments list for
* GISTSearchItemDistanceValues() and GISTSearchItemDistanceNulls().
*/
#define GISTSearchItemDistanceValues(item, n_distances) \
((double *) ((Pointer) (item) + DOUBLEALIGN(sizeof(GISTSearchItem))))
#define GISTSearchItemDistanceNulls(item, n_distances) \
((bool *) ((Pointer) (item) + DOUBLEALIGN(sizeof(GISTSearchItem)) + sizeof(double) * (n_distances)))
/* /*
* GISTScanOpaqueData: private state for a scan of a GiST index * GISTScanOpaqueData: private state for a scan of a GiST index
...@@ -159,7 +176,8 @@ typedef struct GISTScanOpaqueData ...@@ -159,7 +176,8 @@ typedef struct GISTScanOpaqueData
bool firstCall; /* true until first gistgettuple call */ bool firstCall; /* true until first gistgettuple call */
/* pre-allocated workspace arrays */ /* pre-allocated workspace arrays */
double *distances; /* output area for gistindex_keytest */ double *distanceValues; /* output area for gistindex_keytest */
bool *distanceNulls;
/* info about killed items if any (killedItems is NULL if never used) */ /* info about killed items if any (killedItems is NULL if never used) */
OffsetNumber *killedItems; /* offset numbers of killed items */ OffsetNumber *killedItems; /* offset numbers of killed items */
......
...@@ -531,8 +531,8 @@ SELECT * FROM point_tbl ORDER BY f1 <-> '0,1'; ...@@ -531,8 +531,8 @@ SELECT * FROM point_tbl ORDER BY f1 <-> '0,1';
(-5,-12) (-5,-12)
(5.1,34.5) (5.1,34.5)
(1e+300,Infinity) (1e+300,Infinity)
(NaN,NaN) (NaN,NaN)
(10 rows) (10 rows)
EXPLAIN (COSTS OFF) EXPLAIN (COSTS OFF)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment