Commit 25ca5a9a authored by Teodor Sigaev's avatar Teodor Sigaev

Replace plain-memory ordered array by binary tree in ts_stat() function.

Performance is increased from 50% up to 10^3 times depending on data.
parent 18004101
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
* *
* *
* IDENTIFICATION * IDENTIFICATION
* $PostgreSQL: pgsql/src/backend/utils/adt/tsvector_op.c,v 1.17 2008/11/10 21:49:16 alvherre Exp $ * $PostgreSQL: pgsql/src/backend/utils/adt/tsvector_op.c,v 1.18 2008/11/17 12:17:09 teodor Exp $
* *
*------------------------------------------------------------------------- *-------------------------------------------------------------------------
*/ */
...@@ -34,34 +34,33 @@ typedef struct ...@@ -34,34 +34,33 @@ typedef struct
char *operand; char *operand;
} CHKVAL; } CHKVAL;
typedef struct
{
uint32 cur;
TSVector stat;
} StatStorage;
typedef struct typedef struct StatEntry
{ {
uint32 len; uint32 ndoc; /* zero indicates that we already was here while
uint32 pos; walking throug the tree */
uint32 ndoc;
uint32 nentry; uint32 nentry;
struct StatEntry *left;
struct StatEntry *right;
uint32 lenlexeme;
char lexeme[1];
} StatEntry; } StatEntry;
#define STATENTRYHDRSZ (offsetof(StatEntry, lexeme))
typedef struct typedef struct
{ {
int32 vl_len_; /* varlena header (do not touch directly!) */
int4 size;
int4 weight; int4 weight;
char data[1];
} tsstat;
#define STATHDRSIZE (sizeof(int4) * 4) uint32 maxdepth;
#define CALCSTATSIZE(x, lenstr) ( (x) * sizeof(StatEntry) + STATHDRSIZE + (lenstr) )
#define STATPTR(x) ( (StatEntry*) ( (char*)(x) + STATHDRSIZE ) ) StatEntry **stack;
#define STATSTRPTR(x) ( (char*)(x) + STATHDRSIZE + ( sizeof(StatEntry) * ((TSVector)(x))->size ) ) uint32 stackpos;
#define STATSTRSIZE(x) ( VARSIZE((TSVector)(x)) - STATHDRSIZE - ( sizeof(StatEntry) * ((TSVector)(x))->size ) )
StatEntry* root;
} TSVectorStat;
#define STATHDRSIZE (offsetof(TSVectorStat, data))
static Datum tsvector_update_trigger(PG_FUNCTION_ARGS, bool config_column); static Datum tsvector_update_trigger(PG_FUNCTION_ARGS, bool config_column);
...@@ -801,92 +800,95 @@ check_weight(TSVector txt, WordEntry *wptr, int8 weight) ...@@ -801,92 +800,95 @@ check_weight(TSVector txt, WordEntry *wptr, int8 weight)
return num; return num;
} }
#define compareStatWord(a,e,s,t) \ #define compareStatWord(a,e,t) \
tsCompareString(STATSTRPTR(s) + (a)->pos, (a)->len, \ tsCompareString((a)->lexeme, (a)->lenlexeme, \
STRPTR(t) + (e)->pos, (e)->len, \ STRPTR(t) + (e)->pos, (e)->len, \
false) false)
typedef struct WordEntryMark static void
insertStatEntry(MemoryContext persistentContext, TSVectorStat *stat, TSVector txt, uint32 off)
{ {
WordEntry *newentry; WordEntry *we = ARRPTR(txt) + off;
StatEntry *pos; StatEntry *node = stat->root,
} WordEntryMark; *pnode=NULL;
int n,
res;
uint32 depth=1;
if (stat->weight == 0)
n = (we->haspos) ? POSDATALEN(txt, we) : 1;
else
n = (we->haspos) ? check_weight(txt, we, stat->weight) : 0;
static tsstat * if ( n == 0 )
formstat(tsstat *stat, TSVector txt, List *entries) return; /* nothing to insert */
{
tsstat *newstat;
uint32 totallen,
nentry,
len = list_length(entries);
uint32 slen = 0;
WordEntry *ptr;
char *curptr;
StatEntry *sptr,
*nptr;
ListCell *entry;
StatEntry *PosSE = STATPTR(stat),
*prevPosSE;
WordEntryMark *mark;
foreach( entry, entries )
{
mark = (WordEntryMark*)lfirst(entry);
slen += mark->newentry->len;
}
nentry = stat->size + len; while( node )
slen += STATSTRSIZE(stat); {
totallen = CALCSTATSIZE(nentry, slen); res = compareStatWord(node, we, txt);
newstat = palloc(totallen);
SET_VARSIZE(newstat, totallen);
newstat->weight = stat->weight;
newstat->size = nentry;
memcpy(STATSTRPTR(newstat), STATSTRPTR(stat), STATSTRSIZE(stat)); if (res == 0)
curptr = STATSTRPTR(newstat) + STATSTRSIZE(stat); {
break;
}
else
{
pnode = node;
node = ( res < 0 ) ? node->left : node->right;
}
depth++;
}
sptr = STATPTR(stat); if (depth > stat->maxdepth)
nptr = STATPTR(newstat); stat->maxdepth = depth;
foreach(entry, entries) if (node == NULL)
{ {
prevPosSE = PosSE; node = MemoryContextAlloc(persistentContext, STATENTRYHDRSZ + we->len );
node->left = node->right = NULL;
mark = (WordEntryMark*)lfirst(entry); node->ndoc = 1;
ptr = mark->newentry; node->nentry = n;
PosSE = mark->pos; node->lenlexeme = we->len;
memcpy(node->lexeme, STRPTR(txt) + we->pos, node->lenlexeme);
/*
* Copy missed entries if ( pnode==NULL )
*/
if ( PosSE > prevPosSE )
{ {
memcpy( nptr, prevPosSE, sizeof(StatEntry) * (PosSE-prevPosSE) ); stat->root = node;
nptr += PosSE-prevPosSE;
} }
/*
* Copy new entry
*/
if (ptr->haspos)
nptr->nentry = (stat->weight) ? check_weight(txt, ptr, stat->weight) : POSDATALEN(txt, ptr);
else else
nptr->nentry = 1; {
nptr->ndoc = 1; if (res < 0)
nptr->len = ptr->len; pnode->left = node;
memcpy(curptr, STRPTR(txt) + ptr->pos, nptr->len); else
nptr->pos = curptr - STATSTRPTR(newstat); pnode->right = node;
curptr += nptr->len; }
nptr++;
pfree(mark);
} }
else
{
node->ndoc++;
node->nentry += n;
}
}
if ( PosSE < (StatEntry *) STATSTRPTR(stat) ) static void
memcpy(nptr, PosSE, sizeof(StatEntry) * (stat->size - (PosSE - STATPTR(stat)))); chooseNextStatEntry(MemoryContext persistentContext, TSVectorStat *stat, TSVector txt,
uint32 low, uint32 high, uint32 offset)
return newstat; {
uint32 pos;
uint32 middle = (low + high) >> 1;
pos = (low + middle) >> 1;
if (low != middle && pos >= offset && pos - offset < txt->size)
insertStatEntry( persistentContext, stat, txt, pos - offset );
pos = (high + middle + 1) >> 1;
if (middle + 1 != high && pos >= offset && pos - offset < txt->size)
insertStatEntry( persistentContext, stat, txt, pos - offset );
if (low != middle)
chooseNextStatEntry(persistentContext, stat, txt, low, middle, offset);
if (high != middle + 1)
chooseNextStatEntry(persistentContext, stat, txt, middle + 1, high, offset);
} }
/* /*
...@@ -901,115 +903,69 @@ formstat(tsstat *stat, TSVector txt, List *entries) ...@@ -901,115 +903,69 @@ formstat(tsstat *stat, TSVector txt, List *entries)
* where vector_column is a tsvector-type column in vector_table. * where vector_column is a tsvector-type column in vector_table.
*/ */
static tsstat * static TSVectorStat *
ts_accum(tsstat *stat, Datum data) ts_accum(MemoryContext persistentContext, TSVectorStat *stat, Datum data)
{ {
tsstat *newstat; TSVector txt = DatumGetTSVector(data);
TSVector txt = DatumGetTSVector(data); uint32 i,
StatEntry *sptr; nbit = 0,
WordEntry *wptr; offset;
int n = 0;
List *newentries=NIL;
StatEntry *StopLow;
if (stat == NULL) if (stat == NULL)
{ /* Init in first */ { /* Init in first */
stat = palloc(STATHDRSIZE); stat = MemoryContextAllocZero(persistentContext, sizeof(TSVectorStat));
SET_VARSIZE(stat, STATHDRSIZE); stat->maxdepth = 1;
stat->size = 0;
stat->weight = 0;
} }
/* simple check of correctness */ /* simple check of correctness */
if (txt == NULL || txt->size == 0) if (txt == NULL || txt->size == 0)
{ {
if (txt != (TSVector) DatumGetPointer(data)) if (txt && txt != (TSVector) DatumGetPointer(data))
pfree(txt); pfree(txt);
return stat; return stat;
} }
sptr = STATPTR(stat); i = txt->size - 1;
wptr = ARRPTR(txt); for (; i > 0; i >>= 1)
StopLow = STATPTR(stat); nbit++;
while (wptr - ARRPTR(txt) < txt->size)
{
StatEntry *StopHigh = (StatEntry *) STATSTRPTR(stat);
int cmp;
/*
* We do not set StopLow to begin of array because tsvector is ordered
* with the sames rule, so we can search from last stopped position
*/
while (StopLow < StopHigh)
{
sptr = StopLow + (StopHigh - StopLow) / 2;
cmp = compareStatWord(sptr, wptr, stat, txt);
if (cmp == 0)
{
if (stat->weight == 0)
{
sptr->ndoc++;
sptr->nentry += (wptr->haspos) ? POSDATALEN(txt, wptr) : 1;
}
else if (wptr->haspos && (n = check_weight(txt, wptr, stat->weight)) != 0)
{
sptr->ndoc++;
sptr->nentry += n;
}
break;
}
else if (cmp < 0)
StopLow = sptr + 1;
else
StopHigh = sptr;
}
if (StopLow >= StopHigh)
{ /* not found */
if (stat->weight == 0 || check_weight(txt, wptr, stat->weight) != 0)
{
WordEntryMark *mark = (WordEntryMark*)palloc(sizeof(WordEntryMark));
mark->newentry = wptr; nbit = 1 << nbit;
mark->pos = StopLow; offset = (nbit - txt->size) / 2;
newentries = lappend( newentries, mark );
} insertStatEntry( persistentContext, stat, txt, (nbit >> 1) - offset );
} chooseNextStatEntry(persistentContext, stat, txt, 0, nbit, offset);
wptr++;
}
if (list_length(newentries) == 0) return stat;
{ /* no new words */
if (txt != (TSVector) DatumGetPointer(data))
pfree(txt);
return stat;
}
newstat = formstat(stat, txt, newentries);
list_free(newentries);
if (txt != (TSVector) DatumGetPointer(data))
pfree(txt);
return newstat;
} }
static void static void
ts_setup_firstcall(FunctionCallInfo fcinfo, FuncCallContext *funcctx, ts_setup_firstcall(FunctionCallInfo fcinfo, FuncCallContext *funcctx,
tsstat *stat) TSVectorStat *stat)
{ {
TupleDesc tupdesc; TupleDesc tupdesc;
MemoryContext oldcontext; MemoryContext oldcontext;
StatStorage *st; StatEntry *node;
funcctx->user_fctx = (void *) stat;
oldcontext = MemoryContextSwitchTo(funcctx->multi_call_memory_ctx); oldcontext = MemoryContextSwitchTo(funcctx->multi_call_memory_ctx);
st = palloc(sizeof(StatStorage));
st->cur = 0; stat->stack = palloc0(sizeof(StatEntry *) * (stat->maxdepth + 1));
st->stat = palloc(VARSIZE(stat)); stat->stackpos = 0;
memcpy(st->stat, stat, VARSIZE(stat));
funcctx->user_fctx = (void *) st; node = stat->root;
/* find leftmost value */
for (;;)
{
stat->stack[ stat->stackpos ] = node;
if (node->left)
{
stat->stackpos++;
node = node->left;
}
else
break;
}
tupdesc = CreateTemplateTupleDesc(3, false); tupdesc = CreateTemplateTupleDesc(3, false);
TupleDescInitEntry(tupdesc, (AttrNumber) 1, "word", TupleDescInitEntry(tupdesc, (AttrNumber) 1, "word",
...@@ -1024,26 +980,72 @@ ts_setup_firstcall(FunctionCallInfo fcinfo, FuncCallContext *funcctx, ...@@ -1024,26 +980,72 @@ ts_setup_firstcall(FunctionCallInfo fcinfo, FuncCallContext *funcctx,
MemoryContextSwitchTo(oldcontext); MemoryContextSwitchTo(oldcontext);
} }
static StatEntry *
walkStatEntryTree(TSVectorStat *stat)
{
StatEntry *node = stat->stack[ stat->stackpos ];
if ( node == NULL )
return NULL;
if ( node->ndoc != 0 )
{
/* return entry itself: we already was at left sublink */
return node;
}
else if (node->right && node->right != stat->stack[stat->stackpos + 1])
{
/* go on right sublink */
stat->stackpos++;
node = node->right;
/* find most-left value */
for (;;)
{
stat->stack[stat->stackpos] = node;
if (node->left)
{
stat->stackpos++;
node = node->left;
}
else
break;
}
}
else
{
/* we already return all left subtree, itself and right subtree */
if (stat->stackpos == 0)
return NULL;
stat->stackpos--;
return walkStatEntryTree(stat);
}
return node;
}
static Datum static Datum
ts_process_call(FuncCallContext *funcctx) ts_process_call(FuncCallContext *funcctx)
{ {
StatStorage *st; TSVectorStat *st;
StatEntry *entry;
st = (TSVectorStat *) funcctx->user_fctx;
st = (StatStorage *) funcctx->user_fctx; entry = walkStatEntryTree(st);
if (st->cur < st->stat->size) if (entry != NULL)
{ {
Datum result; Datum result;
char *values[3]; char *values[3];
char ndoc[16]; char ndoc[16];
char nentry[16]; char nentry[16];
StatEntry *entry = STATPTR(st->stat) + st->cur;
HeapTuple tuple; HeapTuple tuple;
values[0] = palloc(entry->len + 1); values[0] = palloc(entry->lenlexeme + 1);
memcpy(values[0], STATSTRPTR(st->stat) + entry->pos, entry->len); memcpy(values[0], entry->lexeme, entry->lenlexeme);
(values[0])[entry->len] = '\0'; (values[0])[entry->lenlexeme] = '\0';
sprintf(ndoc, "%d", entry->ndoc); sprintf(ndoc, "%d", entry->ndoc);
values[1] = ndoc; values[1] = ndoc;
sprintf(nentry, "%d", entry->nentry); sprintf(nentry, "%d", entry->nentry);
...@@ -1053,25 +1055,22 @@ ts_process_call(FuncCallContext *funcctx) ...@@ -1053,25 +1055,22 @@ ts_process_call(FuncCallContext *funcctx)
result = HeapTupleGetDatum(tuple); result = HeapTupleGetDatum(tuple);
pfree(values[0]); pfree(values[0]);
st->cur++;
/* mark entry as already visited */
entry->ndoc = 0;
return result; return result;
} }
else
{
pfree(st->stat);
pfree(st);
}
return (Datum) 0; return (Datum) 0;
} }
static tsstat * static TSVectorStat *
ts_stat_sql(text *txt, text *ws) ts_stat_sql(MemoryContext persistentContext, text *txt, text *ws)
{ {
char *query = text_to_cstring(txt); char *query = text_to_cstring(txt);
int i; int i;
tsstat *newstat, TSVectorStat *stat;
*stat;
bool isnull; bool isnull;
Portal portal; Portal portal;
SPIPlanPtr plan; SPIPlanPtr plan;
...@@ -1094,10 +1093,8 @@ ts_stat_sql(text *txt, text *ws) ...@@ -1094,10 +1093,8 @@ ts_stat_sql(text *txt, text *ws)
(errcode(ERRCODE_INVALID_PARAMETER_VALUE), (errcode(ERRCODE_INVALID_PARAMETER_VALUE),
errmsg("ts_stat query must return one tsvector column"))); errmsg("ts_stat query must return one tsvector column")));
stat = palloc(STATHDRSIZE); stat = MemoryContextAllocZero(persistentContext, sizeof(TSVectorStat));
SET_VARSIZE(stat, STATHDRSIZE); stat->maxdepth = 1;
stat->size = 0;
stat->weight = 0;
if (ws) if (ws)
{ {
...@@ -1141,12 +1138,7 @@ ts_stat_sql(text *txt, text *ws) ...@@ -1141,12 +1138,7 @@ ts_stat_sql(text *txt, text *ws)
Datum data = SPI_getbinval(SPI_tuptable->vals[i], SPI_tuptable->tupdesc, 1, &isnull); Datum data = SPI_getbinval(SPI_tuptable->vals[i], SPI_tuptable->tupdesc, 1, &isnull);
if (!isnull) if (!isnull)
{ stat = ts_accum(persistentContext, stat, data);
newstat = ts_accum(stat, data);
if (stat != newstat && stat)
pfree(stat);
stat = newstat;
}
} }
SPI_freetuptable(SPI_tuptable); SPI_freetuptable(SPI_tuptable);
...@@ -1169,12 +1161,12 @@ ts_stat1(PG_FUNCTION_ARGS) ...@@ -1169,12 +1161,12 @@ ts_stat1(PG_FUNCTION_ARGS)
if (SRF_IS_FIRSTCALL()) if (SRF_IS_FIRSTCALL())
{ {
tsstat *stat; TSVectorStat *stat;
text *txt = PG_GETARG_TEXT_P(0); text *txt = PG_GETARG_TEXT_P(0);
funcctx = SRF_FIRSTCALL_INIT(); funcctx = SRF_FIRSTCALL_INIT();
SPI_connect(); SPI_connect();
stat = ts_stat_sql(txt, NULL); stat = ts_stat_sql(funcctx->multi_call_memory_ctx, txt, NULL);
PG_FREE_IF_COPY(txt, 0); PG_FREE_IF_COPY(txt, 0);
ts_setup_firstcall(fcinfo, funcctx, stat); ts_setup_firstcall(fcinfo, funcctx, stat);
SPI_finish(); SPI_finish();
...@@ -1194,13 +1186,13 @@ ts_stat2(PG_FUNCTION_ARGS) ...@@ -1194,13 +1186,13 @@ ts_stat2(PG_FUNCTION_ARGS)
if (SRF_IS_FIRSTCALL()) if (SRF_IS_FIRSTCALL())
{ {
tsstat *stat; TSVectorStat *stat;
text *txt = PG_GETARG_TEXT_P(0); text *txt = PG_GETARG_TEXT_P(0);
text *ws = PG_GETARG_TEXT_P(1); text *ws = PG_GETARG_TEXT_P(1);
funcctx = SRF_FIRSTCALL_INIT(); funcctx = SRF_FIRSTCALL_INIT();
SPI_connect(); SPI_connect();
stat = ts_stat_sql(txt, ws); stat = ts_stat_sql(funcctx->multi_call_memory_ctx, txt, ws);
PG_FREE_IF_COPY(txt, 0); PG_FREE_IF_COPY(txt, 0);
PG_FREE_IF_COPY(ws, 1); PG_FREE_IF_COPY(ws, 1);
ts_setup_firstcall(fcinfo, funcctx, stat); ts_setup_firstcall(fcinfo, funcctx, stat);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment