Commit c3dfe0fe authored by Tom Lane's avatar Tom Lane

Repair breakage of aggregate FILTER option.

An aggregate's input expression(s) are not supposed to be evaluated
at all for a row where its FILTER test fails ... but commit 8ed3f11b
overlooked that requirement.  Reshuffle so that aggregates having a
filter clause evaluate their arguments separately from those without.
This still gets the benefit of doing only one ExecProject in the
common case of multiple Aggrefs, none of which have filters.

While at it, arrange for filter clauses to be included in the common
ExecProject evaluation, thus perhaps buying a little bit even when
there are filters.

Back-patch to v10 where the bug was introduced.

Discussion: https://postgr.es/m/30065.1508161354@sss.pgh.pa.us
parent 60a1d96e
...@@ -268,14 +268,6 @@ typedef struct AggStatePerTransData ...@@ -268,14 +268,6 @@ typedef struct AggStatePerTransData
*/ */
int numInputs; int numInputs;
/*
* At each input row, we evaluate all argument expressions needed for all
* the aggregates in this Agg node in a single ExecProject call. inputoff
* is the starting index of this aggregate's argument expressions in the
* resulting tuple (in AggState->evalslot).
*/
int inputoff;
/* /*
* Number of aggregated input columns to pass to the transfn. This * Number of aggregated input columns to pass to the transfn. This
* includes the ORDER BY columns for ordered-set aggs, but not for plain * includes the ORDER BY columns for ordered-set aggs, but not for plain
...@@ -283,6 +275,16 @@ typedef struct AggStatePerTransData ...@@ -283,6 +275,16 @@ typedef struct AggStatePerTransData
*/ */
int numTransInputs; int numTransInputs;
/*
* At each input row, we perform a single ExecProject call to evaluate all
* argument expressions that will certainly be needed at this row; that
* includes this aggregate's filter expression if it has one, or its
* regular argument expressions (including any ORDER BY columns) if it
* doesn't. inputoff is the starting index of this aggregate's required
* expressions in the resulting tuple.
*/
int inputoff;
/* Oid of the state transition or combine function */ /* Oid of the state transition or combine function */
Oid transfn_oid; Oid transfn_oid;
...@@ -295,9 +297,8 @@ typedef struct AggStatePerTransData ...@@ -295,9 +297,8 @@ typedef struct AggStatePerTransData
/* Oid of state value's datatype */ /* Oid of state value's datatype */
Oid aggtranstype; Oid aggtranstype;
/* ExprStates of the FILTER and argument expressions. */ /* ExprStates for any direct-argument expressions */
ExprState *aggfilter; /* state of FILTER expression, if any */ List *aggdirectargs;
List *aggdirectargs; /* states of direct-argument expressions */
/* /*
* fmgr lookup data for transition function or combine function. Note in * fmgr lookup data for transition function or combine function. Note in
...@@ -353,20 +354,21 @@ typedef struct AggStatePerTransData ...@@ -353,20 +354,21 @@ typedef struct AggStatePerTransData
transtypeByVal; transtypeByVal;
/* /*
* Stuff for evaluation of aggregate inputs in cases where the aggregate * Stuff for evaluation of aggregate inputs, when they must be evaluated
* requires sorted input. The arguments themselves will be evaluated via * separately because there's a FILTER expression. In such cases we will
* AggState->evalslot/evalproj for all aggregates at once, but we only * create a sortslot and the result will be stored there, whether or not
* want to sort the relevant columns for individual aggregates. * we're actually sorting.
*/ */
TupleDesc sortdesc; /* descriptor of input tuples */ ProjectionInfo *evalproj; /* projection machinery */
/* /*
* Slots for holding the evaluated input arguments. These are set up * Slots for holding the evaluated input arguments. These are set up
* during ExecInitAgg() and then used for each input row requiring * during ExecInitAgg() and then used for each input row requiring either
* processing besides what's done in AggState->evalproj. * FILTER or ORDER BY/DISTINCT processing.
*/ */
TupleTableSlot *sortslot; /* current input tuple */ TupleTableSlot *sortslot; /* current input tuple */
TupleTableSlot *uniqslot; /* used for multi-column DISTINCT */ TupleTableSlot *uniqslot; /* used for multi-column DISTINCT */
TupleDesc sortdesc; /* descriptor of input tuples */
/* /*
* These values are working state that is initialized at the start of an * These values are working state that is initialized at the start of an
...@@ -983,30 +985,36 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup, AggStatePerGro ...@@ -983,30 +985,36 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup, AggStatePerGro
int numGroupingSets = Max(aggstate->phase->numsets, 1); int numGroupingSets = Max(aggstate->phase->numsets, 1);
int numHashes = aggstate->num_hashes; int numHashes = aggstate->num_hashes;
int numTrans = aggstate->numtrans; int numTrans = aggstate->numtrans;
TupleTableSlot *slot = aggstate->evalslot; TupleTableSlot *combinedslot;
/* compute input for all aggregates */ /* compute required inputs for all aggregates */
if (aggstate->evalproj) combinedslot = ExecProject(aggstate->combinedproj);
aggstate->evalslot = ExecProject(aggstate->evalproj);
for (transno = 0; transno < numTrans; transno++) for (transno = 0; transno < numTrans; transno++)
{ {
AggStatePerTrans pertrans = &aggstate->pertrans[transno]; AggStatePerTrans pertrans = &aggstate->pertrans[transno];
ExprState *filter = pertrans->aggfilter;
int numTransInputs = pertrans->numTransInputs; int numTransInputs = pertrans->numTransInputs;
int i;
int inputoff = pertrans->inputoff; int inputoff = pertrans->inputoff;
TupleTableSlot *slot;
int i;
/* Skip anything FILTERed out */ /* Skip anything FILTERed out */
if (filter) if (pertrans->aggref->aggfilter)
{ {
Datum res; /* Check the result of the filter expression */
bool isnull; if (combinedslot->tts_isnull[inputoff] ||
!DatumGetBool(combinedslot->tts_values[inputoff]))
res = ExecEvalExprSwitchContext(filter, aggstate->tmpcontext,
&isnull);
if (isnull || !DatumGetBool(res))
continue; continue;
/* Now it's safe to evaluate this agg's arguments */
slot = ExecProject(pertrans->evalproj);
/* There's no offset needed in this slot, of course */
inputoff = 0;
}
else
{
/* arguments are already evaluated into combinedslot @ inputoff */
slot = combinedslot;
} }
if (pertrans->numSortCols > 0) if (pertrans->numSortCols > 0)
...@@ -1040,11 +1048,21 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup, AggStatePerGro ...@@ -1040,11 +1048,21 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup, AggStatePerGro
tuplesort_putdatum(pertrans->sortstates[setno], tuplesort_putdatum(pertrans->sortstates[setno],
slot->tts_values[inputoff], slot->tts_values[inputoff],
slot->tts_isnull[inputoff]); slot->tts_isnull[inputoff]);
else if (pertrans->aggref->aggfilter)
{
/*
* When filtering and ordering, we already have a slot
* containing just the argument columns.
*/
Assert(slot == pertrans->sortslot);
tuplesort_puttupleslot(pertrans->sortstates[setno], slot);
}
else else
{ {
/* /*
* Copy slot contents, starting from inputoff, into sort * Copy argument columns from combined slot, starting at
* slot. * inputoff, into sortslot, so that we can store just the
* columns we want.
*/ */
ExecClearTuple(pertrans->sortslot); ExecClearTuple(pertrans->sortslot);
memcpy(pertrans->sortslot->tts_values, memcpy(pertrans->sortslot->tts_values,
...@@ -1053,9 +1071,9 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup, AggStatePerGro ...@@ -1053,9 +1071,9 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup, AggStatePerGro
memcpy(pertrans->sortslot->tts_isnull, memcpy(pertrans->sortslot->tts_isnull,
&slot->tts_isnull[inputoff], &slot->tts_isnull[inputoff],
pertrans->numInputs * sizeof(bool)); pertrans->numInputs * sizeof(bool));
pertrans->sortslot->tts_nvalid = pertrans->numInputs;
ExecStoreVirtualTuple(pertrans->sortslot); ExecStoreVirtualTuple(pertrans->sortslot);
tuplesort_puttupleslot(pertrans->sortstates[setno], pertrans->sortslot); tuplesort_puttupleslot(pertrans->sortstates[setno],
pertrans->sortslot);
} }
} }
} }
...@@ -1127,7 +1145,7 @@ combine_aggregates(AggState *aggstate, AggStatePerGroup pergroup) ...@@ -1127,7 +1145,7 @@ combine_aggregates(AggState *aggstate, AggStatePerGroup pergroup)
Assert(aggstate->phase->numsets <= 1); Assert(aggstate->phase->numsets <= 1);
/* compute input for all aggregates */ /* compute input for all aggregates */
slot = ExecProject(aggstate->evalproj); slot = ExecProject(aggstate->combinedproj);
for (transno = 0; transno < numTrans; transno++) for (transno = 0; transno < numTrans; transno++)
{ {
...@@ -2691,6 +2709,8 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) ...@@ -2691,6 +2709,8 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
int phase; int phase;
int phaseidx; int phaseidx;
List *combined_inputeval; List *combined_inputeval;
TupleDesc combineddesc;
TupleTableSlot *combinedslot;
ListCell *l; ListCell *l;
Bitmapset *all_grouped_cols = NULL; Bitmapset *all_grouped_cols = NULL;
int numGroupingSets = 1; int numGroupingSets = 1;
...@@ -3366,19 +3386,17 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) ...@@ -3366,19 +3386,17 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
aggstate->numtrans = transno + 1; aggstate->numtrans = transno + 1;
/* /*
* Build a single projection computing the aggregate arguments for all * Build a single projection computing the required arguments for all
* aggregates at once; if there's more than one, that's considerably * aggregates at once; if there's more than one, that's considerably
* faster than doing it separately for each. * faster than doing it separately for each.
* *
* First create a targetlist combining the targetlists of all the * First create a targetlist representing the values to compute.
* per-trans states.
*/ */
combined_inputeval = NIL; combined_inputeval = NIL;
column_offset = 0; column_offset = 0;
for (transno = 0; transno < aggstate->numtrans; transno++) for (transno = 0; transno < aggstate->numtrans; transno++)
{ {
AggStatePerTrans pertrans = &pertransstates[transno]; AggStatePerTrans pertrans = &pertransstates[transno];
ListCell *arg;
/* /*
* Mark this per-trans state with its starting column in the combined * Mark this per-trans state with its starting column in the combined
...@@ -3387,38 +3405,70 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) ...@@ -3387,38 +3405,70 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
pertrans->inputoff = column_offset; pertrans->inputoff = column_offset;
/* /*
* Adjust resnos in the copied target entries to match the combined * If the aggregate has a FILTER, we can only evaluate the filter
* slot. * expression, not the actual input expressions, during the combined
* eval step --- unless we're ignoring the filter because this node is
* running combinefns not transfns.
*/ */
foreach(arg, pertrans->aggref->args) if (pertrans->aggref->aggfilter &&
!DO_AGGSPLIT_COMBINE(aggstate->aggsplit))
{ {
TargetEntry *source_tle = lfirst_node(TargetEntry, arg);
TargetEntry *tle; TargetEntry *tle;
tle = flatCopyTargetEntry(source_tle); tle = makeTargetEntry(pertrans->aggref->aggfilter,
tle->resno += column_offset; column_offset + 1, NULL, false);
combined_inputeval = lappend(combined_inputeval, tle); combined_inputeval = lappend(combined_inputeval, tle);
column_offset++;
/*
* We'll need separate projection machinery for the real args.
* Arrange to evaluate them into the sortslot previously created.
*/
Assert(pertrans->sortslot);
pertrans->evalproj = ExecBuildProjectionInfo(pertrans->aggref->args,
aggstate->tmpcontext,
pertrans->sortslot,
&aggstate->ss.ps,
NULL);
} }
else
{
/*
* Add agg's input expressions to combined_inputeval, adjusting
* resnos in the copied target entries to match the combined slot.
*/
ListCell *arg;
foreach(arg, pertrans->aggref->args)
{
TargetEntry *source_tle = lfirst_node(TargetEntry, arg);
TargetEntry *tle;
tle = flatCopyTargetEntry(source_tle);
tle->resno += column_offset;
column_offset += list_length(pertrans->aggref->args); combined_inputeval = lappend(combined_inputeval, tle);
}
column_offset += list_length(pertrans->aggref->args);
}
} }
/* Now create a projection for the combined targetlist */ /* Now create a projection for the combined targetlist */
aggstate->evaldesc = ExecTypeFromTL(combined_inputeval, false); combineddesc = ExecTypeFromTL(combined_inputeval, false);
aggstate->evalslot = ExecInitExtraTupleSlot(estate); combinedslot = ExecInitExtraTupleSlot(estate);
aggstate->evalproj = ExecBuildProjectionInfo(combined_inputeval, ExecSetSlotDescriptor(combinedslot, combineddesc);
aggstate->tmpcontext, aggstate->combinedproj = ExecBuildProjectionInfo(combined_inputeval,
aggstate->evalslot, aggstate->tmpcontext,
&aggstate->ss.ps, combinedslot,
NULL); &aggstate->ss.ps,
ExecSetSlotDescriptor(aggstate->evalslot, aggstate->evaldesc); NULL);
/* /*
* Last, check whether any more aggregates got added onto the node while * Last, check whether any more aggregates got added onto the node while
* we processed the expressions for the aggregate arguments (including not * we processed the expressions for the aggregate arguments (including not
* only the regular arguments handled immediately above, but any FILTER * only the regular arguments and FILTER expressions handled immediately
* expressions and direct arguments we might've handled earlier). If so, * above, but any direct arguments we might've handled earlier). If so,
* we have nested aggregate functions, which is semantically nonsensical, * we have nested aggregate functions, which is semantically nonsensical,
* so complain. (This should have been caught by the parser, so we don't * so complain. (This should have been caught by the parser, so we don't
* need to work hard on a helpful error message; but we defend against it * need to work hard on a helpful error message; but we defend against it
...@@ -3483,6 +3533,8 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans, ...@@ -3483,6 +3533,8 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans,
else else
pertrans->numTransInputs = numArguments; pertrans->numTransInputs = numArguments;
/* inputoff and evalproj will be set up later, in ExecInitAgg */
/* /*
* When combining states, we have no use at all for the aggregate * When combining states, we have no use at all for the aggregate
* function's transfn. Instead we use the combinefn. In this case, the * function's transfn. Instead we use the combinefn. In this case, the
...@@ -3598,9 +3650,7 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans, ...@@ -3598,9 +3650,7 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans,
} }
/* Initialize the input and FILTER expressions */ /* Initialize any direct-argument expressions */
pertrans->aggfilter = ExecInitExpr(aggref->aggfilter,
(PlanState *) aggstate);
pertrans->aggdirectargs = ExecInitExprList(aggref->aggdirectargs, pertrans->aggdirectargs = ExecInitExprList(aggref->aggdirectargs,
(PlanState *) aggstate); (PlanState *) aggstate);
...@@ -3634,16 +3684,20 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans, ...@@ -3634,16 +3684,20 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans,
pertrans->numSortCols = numSortCols; pertrans->numSortCols = numSortCols;
pertrans->numDistinctCols = numDistinctCols; pertrans->numDistinctCols = numDistinctCols;
if (numSortCols > 0) /*
* If we have either sorting or filtering to do, create a tupledesc and
* slot corresponding to the aggregated inputs (including sort
* expressions) of the agg.
*/
if (numSortCols > 0 || aggref->aggfilter)
{ {
/*
* Get a tupledesc and slot corresponding to the aggregated inputs
* (including sort expressions) of the agg.
*/
pertrans->sortdesc = ExecTypeFromTL(aggref->args, false); pertrans->sortdesc = ExecTypeFromTL(aggref->args, false);
pertrans->sortslot = ExecInitExtraTupleSlot(estate); pertrans->sortslot = ExecInitExtraTupleSlot(estate);
ExecSetSlotDescriptor(pertrans->sortslot, pertrans->sortdesc); ExecSetSlotDescriptor(pertrans->sortslot, pertrans->sortdesc);
}
if (numSortCols > 0)
{
/* /*
* We don't implement DISTINCT or ORDER BY aggs in the HASHED case * We don't implement DISTINCT or ORDER BY aggs in the HASHED case
* (yet) * (yet)
......
...@@ -1830,10 +1830,8 @@ typedef struct AggState ...@@ -1830,10 +1830,8 @@ typedef struct AggState
int num_hashes; int num_hashes;
AggStatePerHash perhash; AggStatePerHash perhash;
AggStatePerGroup *hash_pergroup; /* array of per-group pointers */ AggStatePerGroup *hash_pergroup; /* array of per-group pointers */
/* support for evaluation of agg inputs */ /* support for evaluation of agg input expressions: */
TupleTableSlot *evalslot; /* slot for agg inputs */ ProjectionInfo *combinedproj; /* projection machinery */
ProjectionInfo *evalproj; /* projection machinery */
TupleDesc evaldesc; /* descriptor of input tuples */
} AggState; } AggState;
/* ---------------- /* ----------------
......
...@@ -1388,6 +1388,12 @@ select min(unique1) filter (where unique1 > 100) from tenk1; ...@@ -1388,6 +1388,12 @@ select min(unique1) filter (where unique1 > 100) from tenk1;
101 101
(1 row) (1 row)
select sum(1/ten) filter (where ten > 0) from tenk1;
sum
------
1000
(1 row)
select ten, sum(distinct four) filter (where four::text ~ '123') from onek a select ten, sum(distinct four) filter (where four::text ~ '123') from onek a
group by ten; group by ten;
ten | sum ten | sum
......
...@@ -524,6 +524,8 @@ drop table bytea_test_table; ...@@ -524,6 +524,8 @@ drop table bytea_test_table;
select min(unique1) filter (where unique1 > 100) from tenk1; select min(unique1) filter (where unique1 > 100) from tenk1;
select sum(1/ten) filter (where ten > 0) from tenk1;
select ten, sum(distinct four) filter (where four::text ~ '123') from onek a select ten, sum(distinct four) filter (where four::text ~ '123') from onek a
group by ten; group by ten;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment