Commit 804163bc authored by Heikki Linnakangas

Share transition state between different aggregates when possible.

If there are two different aggregates in the query with the same inputs, and
the aggregates have the same initial condition and transition function,
only calculate the state value once, and only call the final functions
separately. For example, AVG(x) and SUM(x) aggregates have the same
transition function, which accumulates the sum and number of input tuples.
For a query like "SELECT AVG(x), SUM(x) FROM x", we can therefore
accumulate the state value only once, which gives a nice speedup.

David Rowley, reviewed and edited by me.
parent dee0200f
...@@ -4487,35 +4487,15 @@ ExecInitExpr(Expr *node, PlanState *parent) ...@@ -4487,35 +4487,15 @@ ExecInitExpr(Expr *node, PlanState *parent)
break; break;
case T_Aggref: case T_Aggref:
{ {
Aggref *aggref = (Aggref *) node;
AggrefExprState *astate = makeNode(AggrefExprState); AggrefExprState *astate = makeNode(AggrefExprState);
astate->xprstate.evalfunc = (ExprStateEvalFunc) ExecEvalAggref; astate->xprstate.evalfunc = (ExprStateEvalFunc) ExecEvalAggref;
if (parent && IsA(parent, AggState)) if (parent && IsA(parent, AggState))
{ {
AggState *aggstate = (AggState *) parent; AggState *aggstate = (AggState *) parent;
int naggs;
aggstate->aggs = lcons(astate, aggstate->aggs); aggstate->aggs = lcons(astate, aggstate->aggs);
naggs = ++aggstate->numaggs; aggstate->numaggs++;
astate->aggdirectargs = (List *) ExecInitExpr((Expr *) aggref->aggdirectargs,
parent);
astate->args = (List *) ExecInitExpr((Expr *) aggref->args,
parent);
astate->aggfilter = ExecInitExpr(aggref->aggfilter,
parent);
/*
* Complain if the aggregate's arguments contain any
* aggregates; nested agg functions are semantically
* nonsensical. (This should have been caught earlier,
* but we defend against it here anyway.)
*/
if (naggs != aggstate->numaggs)
ereport(ERROR,
(errcode(ERRCODE_GROUPING_ERROR),
errmsg("aggregate function calls cannot be nested")));
} }
else else
{ {
......
...@@ -152,17 +152,28 @@ ...@@ -152,17 +152,28 @@
/* /*
* AggStatePerAggData - per-aggregate working state for the Agg scan * AggStatePerTransData - per aggregate state value information
*
* Working state for updating the aggregate's state value, by calling the
* transition function with an input row. This struct does not store the
* information needed to produce the final aggregate result from the transition
* state, that's stored in AggStatePerAggData instead. This separation allows
* multiple aggregate results to be produced from a single state value.
*/ */
typedef struct AggStatePerAggData typedef struct AggStatePerTransData
{ {
/* /*
* These values are set up during ExecInitAgg() and do not change * These values are set up during ExecInitAgg() and do not change
* thereafter: * thereafter:
*/ */
/* Links to Aggref expr and state nodes this working state is for */ /*
AggrefExprState *aggrefstate; * Link to an Aggref expr this state value is for.
*
* There can be multiple Aggref's sharing the same state value, as long as
* the inputs and transition function are identical. This points to the
* first one of them.
*/
Aggref *aggref; Aggref *aggref;
/* /*
...@@ -186,25 +197,22 @@ typedef struct AggStatePerAggData ...@@ -186,25 +197,22 @@ typedef struct AggStatePerAggData
*/ */
int numTransInputs; int numTransInputs;
/* /* Oid of the state transition function */
* Number of arguments to pass to the finalfn. This is always at least 1
* (the transition state value) plus any ordered-set direct args. If the
* finalfn wants extra args then we pass nulls corresponding to the
* aggregated input columns.
*/
int numFinalArgs;
/* Oids of transfer functions */
Oid transfn_oid; Oid transfn_oid;
Oid finalfn_oid; /* may be InvalidOid */
/* Oid of state value's datatype */
Oid aggtranstype;
/* ExprStates of the FILTER and argument expressions. */
ExprState *aggfilter; /* state of FILTER expression, if any */
List *args; /* states of aggregated-argument expressions */
List *aggdirectargs; /* states of direct-argument expressions */
/* /*
* fmgr lookup data for transfer functions --- only valid when * fmgr lookup data for transition function. Note in particular that the
* corresponding oid is not InvalidOid. Note in particular that fn_strict * fn_strict flag is kept here.
* flags are kept here.
*/ */
FmgrInfo transfn; FmgrInfo transfn;
FmgrInfo finalfn;
/* Input collation derived for aggregate */ /* Input collation derived for aggregate */
Oid aggCollation; Oid aggCollation;
...@@ -236,17 +244,15 @@ typedef struct AggStatePerAggData ...@@ -236,17 +244,15 @@ typedef struct AggStatePerAggData
bool initValueIsNull; bool initValueIsNull;
/* /*
* We need the len and byval info for the agg's input, result, and * We need the len and byval info for the agg's input and transition data
* transition data types in order to know how to copy/delete values. * types in order to know how to copy/delete values.
* *
* Note that the info for the input type is used only when handling * Note that the info for the input type is used only when handling
* DISTINCT aggs with just one argument, so there is only one input type. * DISTINCT aggs with just one argument, so there is only one input type.
*/ */
int16 inputtypeLen, int16 inputtypeLen,
resulttypeLen,
transtypeLen; transtypeLen;
bool inputtypeByVal, bool inputtypeByVal,
resulttypeByVal,
transtypeByVal; transtypeByVal;
/* /*
...@@ -288,6 +294,54 @@ typedef struct AggStatePerAggData ...@@ -288,6 +294,54 @@ typedef struct AggStatePerAggData
* worth the extra space consumption. * worth the extra space consumption.
*/ */
FunctionCallInfoData transfn_fcinfo; FunctionCallInfoData transfn_fcinfo;
} AggStatePerTransData;
/*
* AggStatePerAggData - per-aggregate information
*
* This contains the information needed to call the final function, to produce
* a final aggregate result from the state value. If there are multiple
* identical Aggrefs in the query, they can all share the same per-agg data.
*
* These values are set up during ExecInitAgg() and do not change thereafter.
*/
typedef struct AggStatePerAggData
{
/*
* Link to an Aggref expr this state value is for.
*
* There can be multiple identical Aggref's sharing the same per-agg. This
* points to the first one of them.
*/
Aggref *aggref;
/* index to the state value which this agg should use */
int transno;
/* Optional Oid of final function (may be InvalidOid) */
Oid finalfn_oid;
/*
* fmgr lookup data for final function --- only valid when finalfn_oid
* is not InvalidOid.
*/
FmgrInfo finalfn;
/*
* Number of arguments to pass to the finalfn. This is always at least 1
* (the transition state value) plus any ordered-set direct args. If the
* finalfn wants extra args then we pass nulls corresponding to the
* aggregated input columns.
*/
int numFinalArgs;
/*
* We need the len and byval info for the agg's result data type in order
* to know how to copy/delete values.
*/
int16 resulttypeLen;
bool resulttypeByVal;
} AggStatePerAggData; } AggStatePerAggData;
/* /*
...@@ -358,25 +412,23 @@ typedef struct AggHashEntryData ...@@ -358,25 +412,23 @@ typedef struct AggHashEntryData
AggStatePerGroupData pergroup[FLEXIBLE_ARRAY_MEMBER]; AggStatePerGroupData pergroup[FLEXIBLE_ARRAY_MEMBER];
} AggHashEntryData; } AggHashEntryData;
static void initialize_phase(AggState *aggstate, int newphase); static void initialize_phase(AggState *aggstate, int newphase);
static TupleTableSlot *fetch_input_tuple(AggState *aggstate); static TupleTableSlot *fetch_input_tuple(AggState *aggstate);
static void initialize_aggregates(AggState *aggstate, static void initialize_aggregates(AggState *aggstate,
AggStatePerAgg peragg,
AggStatePerGroup pergroup, AggStatePerGroup pergroup,
int numReset); int numReset);
static void advance_transition_function(AggState *aggstate, static void advance_transition_function(AggState *aggstate,
AggStatePerAgg peraggstate, AggStatePerTrans pertrans,
AggStatePerGroup pergroupstate); AggStatePerGroup pergroupstate);
static void advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup); static void advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup);
static void process_ordered_aggregate_single(AggState *aggstate, static void process_ordered_aggregate_single(AggState *aggstate,
AggStatePerAgg peraggstate, AggStatePerTrans pertrans,
AggStatePerGroup pergroupstate); AggStatePerGroup pergroupstate);
static void process_ordered_aggregate_multi(AggState *aggstate, static void process_ordered_aggregate_multi(AggState *aggstate,
AggStatePerAgg peraggstate, AggStatePerTrans pertrans,
AggStatePerGroup pergroupstate); AggStatePerGroup pergroupstate);
static void finalize_aggregate(AggState *aggstate, static void finalize_aggregate(AggState *aggstate,
AggStatePerAgg peraggstate, AggStatePerAgg peragg,
AggStatePerGroup pergroupstate, AggStatePerGroup pergroupstate,
Datum *resultVal, bool *resultIsNull); Datum *resultVal, bool *resultIsNull);
static void prepare_projection_slot(AggState *aggstate, static void prepare_projection_slot(AggState *aggstate,
...@@ -396,6 +448,17 @@ static TupleTableSlot *agg_retrieve_direct(AggState *aggstate); ...@@ -396,6 +448,17 @@ static TupleTableSlot *agg_retrieve_direct(AggState *aggstate);
static void agg_fill_hash_table(AggState *aggstate); static void agg_fill_hash_table(AggState *aggstate);
static TupleTableSlot *agg_retrieve_hash_table(AggState *aggstate); static TupleTableSlot *agg_retrieve_hash_table(AggState *aggstate);
static Datum GetAggInitVal(Datum textInitVal, Oid transtype); static Datum GetAggInitVal(Datum textInitVal, Oid transtype);
static void build_pertrans_for_aggref(AggStatePerTrans pertrans,
AggState *aggsate, EState *estate,
Aggref *aggref, Oid aggtransfn, Oid aggtranstype,
Datum initValue, bool initValueIsNull,
Oid *inputTypes, int numArguments);
static int find_compatible_peragg(Aggref *newagg, AggState *aggstate,
int lastaggno, List **same_input_transnos);
static int find_compatible_pertrans(AggState *aggstate, Aggref *newagg,
Oid aggtransfn, Oid aggtranstype,
Datum initValue, bool initValueIsNull,
List *possible_matches);
/* /*
...@@ -498,20 +561,20 @@ fetch_input_tuple(AggState *aggstate) ...@@ -498,20 +561,20 @@ fetch_input_tuple(AggState *aggstate)
* When called, CurrentMemoryContext should be the per-query context. * When called, CurrentMemoryContext should be the per-query context.
*/ */
static void static void
initialize_aggregate(AggState *aggstate, AggStatePerAgg peraggstate, initialize_aggregate(AggState *aggstate, AggStatePerTrans pertrans,
AggStatePerGroup pergroupstate) AggStatePerGroup pergroupstate)
{ {
/* /*
* Start a fresh sort operation for each DISTINCT/ORDER BY aggregate. * Start a fresh sort operation for each DISTINCT/ORDER BY aggregate.
*/ */
if (peraggstate->numSortCols > 0) if (pertrans->numSortCols > 0)
{ {
/* /*
* In case of rescan, maybe there could be an uncompleted sort * In case of rescan, maybe there could be an uncompleted sort
* operation? Clean it up if so. * operation? Clean it up if so.
*/ */
if (peraggstate->sortstates[aggstate->current_set]) if (pertrans->sortstates[aggstate->current_set])
tuplesort_end(peraggstate->sortstates[aggstate->current_set]); tuplesort_end(pertrans->sortstates[aggstate->current_set]);
/* /*
...@@ -519,21 +582,21 @@ initialize_aggregate(AggState *aggstate, AggStatePerAgg peraggstate, ...@@ -519,21 +582,21 @@ initialize_aggregate(AggState *aggstate, AggStatePerAgg peraggstate,
* otherwise sort the full tuple. (See comments for * otherwise sort the full tuple. (See comments for
* process_ordered_aggregate_single.) * process_ordered_aggregate_single.)
*/ */
if (peraggstate->numInputs == 1) if (pertrans->numInputs == 1)
peraggstate->sortstates[aggstate->current_set] = pertrans->sortstates[aggstate->current_set] =
tuplesort_begin_datum(peraggstate->evaldesc->attrs[0]->atttypid, tuplesort_begin_datum(pertrans->evaldesc->attrs[0]->atttypid,
peraggstate->sortOperators[0], pertrans->sortOperators[0],
peraggstate->sortCollations[0], pertrans->sortCollations[0],
peraggstate->sortNullsFirst[0], pertrans->sortNullsFirst[0],
work_mem, false); work_mem, false);
else else
peraggstate->sortstates[aggstate->current_set] = pertrans->sortstates[aggstate->current_set] =
tuplesort_begin_heap(peraggstate->evaldesc, tuplesort_begin_heap(pertrans->evaldesc,
peraggstate->numSortCols, pertrans->numSortCols,
peraggstate->sortColIdx, pertrans->sortColIdx,
peraggstate->sortOperators, pertrans->sortOperators,
peraggstate->sortCollations, pertrans->sortCollations,
peraggstate->sortNullsFirst, pertrans->sortNullsFirst,
work_mem, false); work_mem, false);
} }
...@@ -543,20 +606,20 @@ initialize_aggregate(AggState *aggstate, AggStatePerAgg peraggstate, ...@@ -543,20 +606,20 @@ initialize_aggregate(AggState *aggstate, AggStatePerAgg peraggstate,
* Note that when the initial value is pass-by-ref, we must copy it (into * Note that when the initial value is pass-by-ref, we must copy it (into
* the aggcontext) since we will pfree the transValue later. * the aggcontext) since we will pfree the transValue later.
*/ */
if (peraggstate->initValueIsNull) if (pertrans->initValueIsNull)
pergroupstate->transValue = peraggstate->initValue; pergroupstate->transValue = pertrans->initValue;
else else
{ {
MemoryContext oldContext; MemoryContext oldContext;
oldContext = MemoryContextSwitchTo( oldContext = MemoryContextSwitchTo(
aggstate->aggcontexts[aggstate->current_set]->ecxt_per_tuple_memory); aggstate->aggcontexts[aggstate->current_set]->ecxt_per_tuple_memory);
pergroupstate->transValue = datumCopy(peraggstate->initValue, pergroupstate->transValue = datumCopy(pertrans->initValue,
peraggstate->transtypeByVal, pertrans->transtypeByVal,
peraggstate->transtypeLen); pertrans->transtypeLen);
MemoryContextSwitchTo(oldContext); MemoryContextSwitchTo(oldContext);
} }
pergroupstate->transValueIsNull = peraggstate->initValueIsNull; pergroupstate->transValueIsNull = pertrans->initValueIsNull;
/* /*
* If the initial value for the transition state doesn't exist in the * If the initial value for the transition state doesn't exist in the
...@@ -565,11 +628,11 @@ initialize_aggregate(AggState *aggstate, AggStatePerAgg peraggstate, ...@@ -565,11 +628,11 @@ initialize_aggregate(AggState *aggstate, AggStatePerAgg peraggstate,
* aggregates like max() and min().) The noTransValue flag signals that we * aggregates like max() and min().) The noTransValue flag signals that we
* still need to do this. * still need to do this.
*/ */
pergroupstate->noTransValue = peraggstate->initValueIsNull; pergroupstate->noTransValue = pertrans->initValueIsNull;
} }
/* /*
* Initialize all aggregates for a new group of input values. * Initialize all aggregate transition states for a new group of input values.
* *
* If there are multiple grouping sets, we initialize only the first numReset * If there are multiple grouping sets, we initialize only the first numReset
* of them (the grouping sets are ordered so that the most specific one, which * of them (the grouping sets are ordered so that the most specific one, which
...@@ -580,61 +643,61 @@ initialize_aggregate(AggState *aggstate, AggStatePerAgg peraggstate, ...@@ -580,61 +643,61 @@ initialize_aggregate(AggState *aggstate, AggStatePerAgg peraggstate,
*/ */
static void static void
initialize_aggregates(AggState *aggstate, initialize_aggregates(AggState *aggstate,
AggStatePerAgg peragg,
AggStatePerGroup pergroup, AggStatePerGroup pergroup,
int numReset) int numReset)
{ {
int aggno; int transno;
int numGroupingSets = Max(aggstate->phase->numsets, 1); int numGroupingSets = Max(aggstate->phase->numsets, 1);
int setno = 0; int setno = 0;
AggStatePerTrans transstates = aggstate->pertrans;
if (numReset < 1) if (numReset < 1)
numReset = numGroupingSets; numReset = numGroupingSets;
for (aggno = 0; aggno < aggstate->numaggs; aggno++) for (transno = 0; transno < aggstate->numtrans; transno++)
{ {
AggStatePerAgg peraggstate = &peragg[aggno]; AggStatePerTrans pertrans = &transstates[transno];
for (setno = 0; setno < numReset; setno++) for (setno = 0; setno < numReset; setno++)
{ {
AggStatePerGroup pergroupstate; AggStatePerGroup pergroupstate;
pergroupstate = &pergroup[aggno + (setno * (aggstate->numaggs))]; pergroupstate = &pergroup[transno + (setno * (aggstate->numtrans))];
aggstate->current_set = setno; aggstate->current_set = setno;
initialize_aggregate(aggstate, peraggstate, pergroupstate); initialize_aggregate(aggstate, pertrans, pergroupstate);
} }
} }
} }
/* /*
* Given new input value(s), advance the transition function of one aggregate * Given new input value(s), advance the transition function of one aggregate
* within one grouping set only (already set in aggstate->current_set) * state within one grouping set only (already set in aggstate->current_set)
* *
* The new values (and null flags) have been preloaded into argument positions * The new values (and null flags) have been preloaded into argument positions
* 1 and up in peraggstate->transfn_fcinfo, so that we needn't copy them again * 1 and up in pertrans->transfn_fcinfo, so that we needn't copy them again to
* to pass to the transition function. We also expect that the static fields * pass to the transition function. We also expect that the static fields of
* of the fcinfo are already initialized; that was done by ExecInitAgg(). * the fcinfo are already initialized; that was done by ExecInitAgg().
* *
* It doesn't matter which memory context this is called in. * It doesn't matter which memory context this is called in.
*/ */
static void static void
advance_transition_function(AggState *aggstate, advance_transition_function(AggState *aggstate,
AggStatePerAgg peraggstate, AggStatePerTrans pertrans,
AggStatePerGroup pergroupstate) AggStatePerGroup pergroupstate)
{ {
FunctionCallInfo fcinfo = &peraggstate->transfn_fcinfo; FunctionCallInfo fcinfo = &pertrans->transfn_fcinfo;
MemoryContext oldContext; MemoryContext oldContext;
Datum newVal; Datum newVal;
if (peraggstate->transfn.fn_strict) if (pertrans->transfn.fn_strict)
{ {
/* /*
* For a strict transfn, nothing happens when there's a NULL input; we * For a strict transfn, nothing happens when there's a NULL input; we
* just keep the prior transValue. * just keep the prior transValue.
*/ */
int numTransInputs = peraggstate->numTransInputs; int numTransInputs = pertrans->numTransInputs;
int i; int i;
for (i = 1; i <= numTransInputs; i++) for (i = 1; i <= numTransInputs; i++)
...@@ -656,8 +719,8 @@ advance_transition_function(AggState *aggstate, ...@@ -656,8 +719,8 @@ advance_transition_function(AggState *aggstate,
oldContext = MemoryContextSwitchTo( oldContext = MemoryContextSwitchTo(
aggstate->aggcontexts[aggstate->current_set]->ecxt_per_tuple_memory); aggstate->aggcontexts[aggstate->current_set]->ecxt_per_tuple_memory);
pergroupstate->transValue = datumCopy(fcinfo->arg[1], pergroupstate->transValue = datumCopy(fcinfo->arg[1],
peraggstate->transtypeByVal, pertrans->transtypeByVal,
peraggstate->transtypeLen); pertrans->transtypeLen);
pergroupstate->transValueIsNull = false; pergroupstate->transValueIsNull = false;
pergroupstate->noTransValue = false; pergroupstate->noTransValue = false;
MemoryContextSwitchTo(oldContext); MemoryContextSwitchTo(oldContext);
...@@ -678,8 +741,8 @@ advance_transition_function(AggState *aggstate, ...@@ -678,8 +741,8 @@ advance_transition_function(AggState *aggstate,
/* We run the transition functions in per-input-tuple memory context */ /* We run the transition functions in per-input-tuple memory context */
oldContext = MemoryContextSwitchTo(aggstate->tmpcontext->ecxt_per_tuple_memory); oldContext = MemoryContextSwitchTo(aggstate->tmpcontext->ecxt_per_tuple_memory);
/* set up aggstate->curperagg for AggGetAggref() */ /* set up aggstate->curpertrans for AggGetAggref() */
aggstate->curperagg = peraggstate; aggstate->curpertrans = pertrans;
/* /*
* OK to call the transition function * OK to call the transition function
...@@ -690,22 +753,22 @@ advance_transition_function(AggState *aggstate, ...@@ -690,22 +753,22 @@ advance_transition_function(AggState *aggstate,
newVal = FunctionCallInvoke(fcinfo); newVal = FunctionCallInvoke(fcinfo);
aggstate->curperagg = NULL; aggstate->curpertrans = NULL;
/* /*
* If pass-by-ref datatype, must copy the new value into aggcontext and * If pass-by-ref datatype, must copy the new value into aggcontext and
* pfree the prior transValue. But if transfn returned a pointer to its * pfree the prior transValue. But if transfn returned a pointer to its
* first input, we don't need to do anything. * first input, we don't need to do anything.
*/ */
if (!peraggstate->transtypeByVal && if (!pertrans->transtypeByVal &&
DatumGetPointer(newVal) != DatumGetPointer(pergroupstate->transValue)) DatumGetPointer(newVal) != DatumGetPointer(pergroupstate->transValue))
{ {
if (!fcinfo->isnull) if (!fcinfo->isnull)
{ {
MemoryContextSwitchTo(aggstate->aggcontexts[aggstate->current_set]->ecxt_per_tuple_memory); MemoryContextSwitchTo(aggstate->aggcontexts[aggstate->current_set]->ecxt_per_tuple_memory);
newVal = datumCopy(newVal, newVal = datumCopy(newVal,
peraggstate->transtypeByVal, pertrans->transtypeByVal,
peraggstate->transtypeLen); pertrans->transtypeLen);
} }
if (!pergroupstate->transValueIsNull) if (!pergroupstate->transValueIsNull)
pfree(DatumGetPointer(pergroupstate->transValue)); pfree(DatumGetPointer(pergroupstate->transValue));
...@@ -718,26 +781,26 @@ advance_transition_function(AggState *aggstate, ...@@ -718,26 +781,26 @@ advance_transition_function(AggState *aggstate,
} }
/* /*
* Advance all the aggregates for one input tuple. The input tuple * Advance each aggregate transition state for one input tuple. The input
* has been stored in tmpcontext->ecxt_outertuple, so that it is accessible * tuple has been stored in tmpcontext->ecxt_outertuple, so that it is
* to ExecEvalExpr. pergroup is the array of per-group structs to use * accessible to ExecEvalExpr. pergroup is the array of per-group structs to
* (this might be in a hashtable entry). * use (this might be in a hashtable entry).
* *
* When called, CurrentMemoryContext should be the per-query context. * When called, CurrentMemoryContext should be the per-query context.
*/ */
static void static void
advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup) advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup)
{ {
int aggno; int transno;
int setno = 0; int setno = 0;
int numGroupingSets = Max(aggstate->phase->numsets, 1); int numGroupingSets = Max(aggstate->phase->numsets, 1);
int numAggs = aggstate->numaggs; int numTrans = aggstate->numtrans;
for (aggno = 0; aggno < numAggs; aggno++) for (transno = 0; transno < numTrans; transno++)
{ {
AggStatePerAgg peraggstate = &aggstate->peragg[aggno]; AggStatePerTrans pertrans = &aggstate->pertrans[transno];
ExprState *filter = peraggstate->aggrefstate->aggfilter; ExprState *filter = pertrans->aggfilter;
int numTransInputs = peraggstate->numTransInputs; int numTransInputs = pertrans->numTransInputs;
int i; int i;
TupleTableSlot *slot; TupleTableSlot *slot;
...@@ -754,12 +817,12 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup) ...@@ -754,12 +817,12 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup)
} }
/* Evaluate the current input expressions for this aggregate */ /* Evaluate the current input expressions for this aggregate */
slot = ExecProject(peraggstate->evalproj, NULL); slot = ExecProject(pertrans->evalproj, NULL);
if (peraggstate->numSortCols > 0) if (pertrans->numSortCols > 0)
{ {
/* DISTINCT and/or ORDER BY case */ /* DISTINCT and/or ORDER BY case */
Assert(slot->tts_nvalid == peraggstate->numInputs); Assert(slot->tts_nvalid == pertrans->numInputs);
/* /*
* If the transfn is strict, we want to check for nullity before * If the transfn is strict, we want to check for nullity before
...@@ -768,7 +831,7 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup) ...@@ -768,7 +831,7 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup)
* not numInputs, since nullity in columns used only for sorting * not numInputs, since nullity in columns used only for sorting
* is not relevant here. * is not relevant here.
*/ */
if (peraggstate->transfn.fn_strict) if (pertrans->transfn.fn_strict)
{ {
for (i = 0; i < numTransInputs; i++) for (i = 0; i < numTransInputs; i++)
{ {
...@@ -782,18 +845,18 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup) ...@@ -782,18 +845,18 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup)
for (setno = 0; setno < numGroupingSets; setno++) for (setno = 0; setno < numGroupingSets; setno++)
{ {
/* OK, put the tuple into the tuplesort object */ /* OK, put the tuple into the tuplesort object */
if (peraggstate->numInputs == 1) if (pertrans->numInputs == 1)
tuplesort_putdatum(peraggstate->sortstates[setno], tuplesort_putdatum(pertrans->sortstates[setno],
slot->tts_values[0], slot->tts_values[0],
slot->tts_isnull[0]); slot->tts_isnull[0]);
else else
tuplesort_puttupleslot(peraggstate->sortstates[setno], slot); tuplesort_puttupleslot(pertrans->sortstates[setno], slot);
} }
} }
else else
{ {
/* We can apply the transition function immediately */ /* We can apply the transition function immediately */
FunctionCallInfo fcinfo = &peraggstate->transfn_fcinfo; FunctionCallInfo fcinfo = &pertrans->transfn_fcinfo;
/* Load values into fcinfo */ /* Load values into fcinfo */
/* Start from 1, since the 0th arg will be the transition value */ /* Start from 1, since the 0th arg will be the transition value */
...@@ -806,11 +869,11 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup) ...@@ -806,11 +869,11 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup)
for (setno = 0; setno < numGroupingSets; setno++) for (setno = 0; setno < numGroupingSets; setno++)
{ {
AggStatePerGroup pergroupstate = &pergroup[aggno + (setno * numAggs)]; AggStatePerGroup pergroupstate = &pergroup[transno + (setno * numTrans)];
aggstate->current_set = setno; aggstate->current_set = setno;
advance_transition_function(aggstate, peraggstate, pergroupstate); advance_transition_function(aggstate, pertrans, pergroupstate);
} }
} }
} }
...@@ -841,7 +904,7 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup) ...@@ -841,7 +904,7 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup)
*/ */
static void static void
process_ordered_aggregate_single(AggState *aggstate, process_ordered_aggregate_single(AggState *aggstate,
AggStatePerAgg peraggstate, AggStatePerTrans pertrans,
AggStatePerGroup pergroupstate) AggStatePerGroup pergroupstate)
{ {
Datum oldVal = (Datum) 0; Datum oldVal = (Datum) 0;
...@@ -849,14 +912,14 @@ process_ordered_aggregate_single(AggState *aggstate, ...@@ -849,14 +912,14 @@ process_ordered_aggregate_single(AggState *aggstate,
bool haveOldVal = false; bool haveOldVal = false;
MemoryContext workcontext = aggstate->tmpcontext->ecxt_per_tuple_memory; MemoryContext workcontext = aggstate->tmpcontext->ecxt_per_tuple_memory;
MemoryContext oldContext; MemoryContext oldContext;
bool isDistinct = (peraggstate->numDistinctCols > 0); bool isDistinct = (pertrans->numDistinctCols > 0);
FunctionCallInfo fcinfo = &peraggstate->transfn_fcinfo; FunctionCallInfo fcinfo = &pertrans->transfn_fcinfo;
Datum *newVal; Datum *newVal;
bool *isNull; bool *isNull;
Assert(peraggstate->numDistinctCols < 2); Assert(pertrans->numDistinctCols < 2);
tuplesort_performsort(peraggstate->sortstates[aggstate->current_set]); tuplesort_performsort(pertrans->sortstates[aggstate->current_set]);
/* Load the column into argument 1 (arg 0 will be transition value) */ /* Load the column into argument 1 (arg 0 will be transition value) */
newVal = fcinfo->arg + 1; newVal = fcinfo->arg + 1;
...@@ -868,7 +931,7 @@ process_ordered_aggregate_single(AggState *aggstate, ...@@ -868,7 +931,7 @@ process_ordered_aggregate_single(AggState *aggstate,
* pfree them when they are no longer needed. * pfree them when they are no longer needed.
*/ */
while (tuplesort_getdatum(peraggstate->sortstates[aggstate->current_set], while (tuplesort_getdatum(pertrans->sortstates[aggstate->current_set],
true, newVal, isNull)) true, newVal, isNull))
{ {
/* /*
...@@ -887,18 +950,18 @@ process_ordered_aggregate_single(AggState *aggstate, ...@@ -887,18 +950,18 @@ process_ordered_aggregate_single(AggState *aggstate,
haveOldVal && haveOldVal &&
((oldIsNull && *isNull) || ((oldIsNull && *isNull) ||
(!oldIsNull && !*isNull && (!oldIsNull && !*isNull &&
DatumGetBool(FunctionCall2(&peraggstate->equalfns[0], DatumGetBool(FunctionCall2(&pertrans->equalfns[0],
oldVal, *newVal))))) oldVal, *newVal)))))
{ {
/* equal to prior, so forget this one */ /* equal to prior, so forget this one */
if (!peraggstate->inputtypeByVal && !*isNull) if (!pertrans->inputtypeByVal && !*isNull)
pfree(DatumGetPointer(*newVal)); pfree(DatumGetPointer(*newVal));
} }
else else
{ {
advance_transition_function(aggstate, peraggstate, pergroupstate); advance_transition_function(aggstate, pertrans, pergroupstate);
/* forget the old value, if any */ /* forget the old value, if any */
if (!oldIsNull && !peraggstate->inputtypeByVal) if (!oldIsNull && !pertrans->inputtypeByVal)
pfree(DatumGetPointer(oldVal)); pfree(DatumGetPointer(oldVal));
/* and remember the new one for subsequent equality checks */ /* and remember the new one for subsequent equality checks */
oldVal = *newVal; oldVal = *newVal;
...@@ -909,11 +972,11 @@ process_ordered_aggregate_single(AggState *aggstate, ...@@ -909,11 +972,11 @@ process_ordered_aggregate_single(AggState *aggstate,
MemoryContextSwitchTo(oldContext); MemoryContextSwitchTo(oldContext);
} }
if (!oldIsNull && !peraggstate->inputtypeByVal) if (!oldIsNull && !pertrans->inputtypeByVal)
pfree(DatumGetPointer(oldVal)); pfree(DatumGetPointer(oldVal));
tuplesort_end(peraggstate->sortstates[aggstate->current_set]); tuplesort_end(pertrans->sortstates[aggstate->current_set]);
peraggstate->sortstates[aggstate->current_set] = NULL; pertrans->sortstates[aggstate->current_set] = NULL;
} }
/* /*
...@@ -930,25 +993,25 @@ process_ordered_aggregate_single(AggState *aggstate, ...@@ -930,25 +993,25 @@ process_ordered_aggregate_single(AggState *aggstate,
*/ */
static void static void
process_ordered_aggregate_multi(AggState *aggstate, process_ordered_aggregate_multi(AggState *aggstate,
AggStatePerAgg peraggstate, AggStatePerTrans pertrans,
AggStatePerGroup pergroupstate) AggStatePerGroup pergroupstate)
{ {
MemoryContext workcontext = aggstate->tmpcontext->ecxt_per_tuple_memory; MemoryContext workcontext = aggstate->tmpcontext->ecxt_per_tuple_memory;
FunctionCallInfo fcinfo = &peraggstate->transfn_fcinfo; FunctionCallInfo fcinfo = &pertrans->transfn_fcinfo;
TupleTableSlot *slot1 = peraggstate->evalslot; TupleTableSlot *slot1 = pertrans->evalslot;
TupleTableSlot *slot2 = peraggstate->uniqslot; TupleTableSlot *slot2 = pertrans->uniqslot;
int numTransInputs = peraggstate->numTransInputs; int numTransInputs = pertrans->numTransInputs;
int numDistinctCols = peraggstate->numDistinctCols; int numDistinctCols = pertrans->numDistinctCols;
bool haveOldValue = false; bool haveOldValue = false;
int i; int i;
tuplesort_performsort(peraggstate->sortstates[aggstate->current_set]); tuplesort_performsort(pertrans->sortstates[aggstate->current_set]);
ExecClearTuple(slot1); ExecClearTuple(slot1);
if (slot2) if (slot2)
ExecClearTuple(slot2); ExecClearTuple(slot2);
while (tuplesort_gettupleslot(peraggstate->sortstates[aggstate->current_set], while (tuplesort_gettupleslot(pertrans->sortstates[aggstate->current_set],
true, slot1)) true, slot1))
{ {
/* /*
...@@ -962,8 +1025,8 @@ process_ordered_aggregate_multi(AggState *aggstate, ...@@ -962,8 +1025,8 @@ process_ordered_aggregate_multi(AggState *aggstate,
!haveOldValue || !haveOldValue ||
!execTuplesMatch(slot1, slot2, !execTuplesMatch(slot1, slot2,
numDistinctCols, numDistinctCols,
peraggstate->sortColIdx, pertrans->sortColIdx,
peraggstate->equalfns, pertrans->equalfns,
workcontext)) workcontext))
{ {
/* Load values into fcinfo */ /* Load values into fcinfo */
...@@ -974,7 +1037,7 @@ process_ordered_aggregate_multi(AggState *aggstate, ...@@ -974,7 +1037,7 @@ process_ordered_aggregate_multi(AggState *aggstate,
fcinfo->argnull[i + 1] = slot1->tts_isnull[i]; fcinfo->argnull[i + 1] = slot1->tts_isnull[i];
} }
advance_transition_function(aggstate, peraggstate, pergroupstate); advance_transition_function(aggstate, pertrans, pergroupstate);
if (numDistinctCols > 0) if (numDistinctCols > 0)
{ {
...@@ -997,8 +1060,8 @@ process_ordered_aggregate_multi(AggState *aggstate, ...@@ -997,8 +1060,8 @@ process_ordered_aggregate_multi(AggState *aggstate,
if (slot2) if (slot2)
ExecClearTuple(slot2); ExecClearTuple(slot2);
tuplesort_end(peraggstate->sortstates[aggstate->current_set]); tuplesort_end(pertrans->sortstates[aggstate->current_set]);
peraggstate->sortstates[aggstate->current_set] = NULL; pertrans->sortstates[aggstate->current_set] = NULL;
} }
/* /*
...@@ -1009,10 +1072,14 @@ process_ordered_aggregate_multi(AggState *aggstate, ...@@ -1009,10 +1072,14 @@ process_ordered_aggregate_multi(AggState *aggstate,
* *
* The finalfunction will be run, and the result delivered, in the * The finalfunction will be run, and the result delivered, in the
* output-tuple context; caller's CurrentMemoryContext does not matter. * output-tuple context; caller's CurrentMemoryContext does not matter.
*
* The finalfn uses the transition state identified by peragg->transno. That
* state might also be in use by another aggregate function, so it's
* important that we do nothing destructive here.
*/ */
static void static void
finalize_aggregate(AggState *aggstate, finalize_aggregate(AggState *aggstate,
AggStatePerAgg peraggstate, AggStatePerAgg peragg,
AggStatePerGroup pergroupstate, AggStatePerGroup pergroupstate,
Datum *resultVal, bool *resultIsNull) Datum *resultVal, bool *resultIsNull)
{ {
...@@ -1021,6 +1088,7 @@ finalize_aggregate(AggState *aggstate, ...@@ -1021,6 +1088,7 @@ finalize_aggregate(AggState *aggstate,
MemoryContext oldContext; MemoryContext oldContext;
int i; int i;
ListCell *lc; ListCell *lc;
AggStatePerTrans pertrans = &aggstate->pertrans[peragg->transno];
oldContext = MemoryContextSwitchTo(aggstate->ss.ps.ps_ExprContext->ecxt_per_tuple_memory); oldContext = MemoryContextSwitchTo(aggstate->ss.ps.ps_ExprContext->ecxt_per_tuple_memory);
...@@ -1031,7 +1099,7 @@ finalize_aggregate(AggState *aggstate, ...@@ -1031,7 +1099,7 @@ finalize_aggregate(AggState *aggstate,
* for the transition state value. * for the transition state value.
*/ */
i = 1; i = 1;
foreach(lc, peraggstate->aggrefstate->aggdirectargs) foreach(lc, pertrans->aggdirectargs)
{ {
ExprState *expr = (ExprState *) lfirst(lc); ExprState *expr = (ExprState *) lfirst(lc);
...@@ -1046,16 +1114,16 @@ finalize_aggregate(AggState *aggstate, ...@@ -1046,16 +1114,16 @@ finalize_aggregate(AggState *aggstate,
/* /*
* Apply the agg's finalfn if one is provided, else return transValue. * Apply the agg's finalfn if one is provided, else return transValue.
*/ */
if (OidIsValid(peraggstate->finalfn_oid)) if (OidIsValid(peragg->finalfn_oid))
{ {
int numFinalArgs = peraggstate->numFinalArgs; int numFinalArgs = peragg->numFinalArgs;
/* set up aggstate->curperagg for AggGetAggref() */ /* set up aggstate->curpertrans for AggGetAggref() */
aggstate->curperagg = peraggstate; aggstate->curpertrans = pertrans;
InitFunctionCallInfoData(fcinfo, &peraggstate->finalfn, InitFunctionCallInfoData(fcinfo, &peragg->finalfn,
numFinalArgs, numFinalArgs,
peraggstate->aggCollation, pertrans->aggCollation,
(void *) aggstate, NULL); (void *) aggstate, NULL);
/* Fill in the transition state value */ /* Fill in the transition state value */
...@@ -1082,7 +1150,7 @@ finalize_aggregate(AggState *aggstate, ...@@ -1082,7 +1150,7 @@ finalize_aggregate(AggState *aggstate,
*resultVal = FunctionCallInvoke(&fcinfo); *resultVal = FunctionCallInvoke(&fcinfo);
*resultIsNull = fcinfo.isnull; *resultIsNull = fcinfo.isnull;
} }
aggstate->curperagg = NULL; aggstate->curpertrans = NULL;
} }
else else
{ {
...@@ -1093,12 +1161,12 @@ finalize_aggregate(AggState *aggstate, ...@@ -1093,12 +1161,12 @@ finalize_aggregate(AggState *aggstate,
/* /*
* If result is pass-by-ref, make sure it is in the right context. * If result is pass-by-ref, make sure it is in the right context.
*/ */
if (!peraggstate->resulttypeByVal && !*resultIsNull && if (!peragg->resulttypeByVal && !*resultIsNull &&
!MemoryContextContains(CurrentMemoryContext, !MemoryContextContains(CurrentMemoryContext,
DatumGetPointer(*resultVal))) DatumGetPointer(*resultVal)))
*resultVal = datumCopy(*resultVal, *resultVal = datumCopy(*resultVal,
peraggstate->resulttypeByVal, peragg->resulttypeByVal,
peraggstate->resulttypeLen); peragg->resulttypeLen);
MemoryContextSwitchTo(oldContext); MemoryContextSwitchTo(oldContext);
} }
...@@ -1173,7 +1241,7 @@ prepare_projection_slot(AggState *aggstate, TupleTableSlot *slot, int currentSet ...@@ -1173,7 +1241,7 @@ prepare_projection_slot(AggState *aggstate, TupleTableSlot *slot, int currentSet
*/ */
static void static void
finalize_aggregates(AggState *aggstate, finalize_aggregates(AggState *aggstate,
AggStatePerAgg peragg, AggStatePerAgg peraggs,
AggStatePerGroup pergroup, AggStatePerGroup pergroup,
int currentSet) int currentSet)
{ {
...@@ -1189,26 +1257,28 @@ finalize_aggregates(AggState *aggstate, ...@@ -1189,26 +1257,28 @@ finalize_aggregates(AggState *aggstate,
for (aggno = 0; aggno < aggstate->numaggs; aggno++) for (aggno = 0; aggno < aggstate->numaggs; aggno++)
{ {
AggStatePerAgg peraggstate = &peragg[aggno]; AggStatePerAgg peragg = &peraggs[aggno];
int transno = peragg->transno;
AggStatePerTrans pertrans = &aggstate->pertrans[transno];
AggStatePerGroup pergroupstate; AggStatePerGroup pergroupstate;
pergroupstate = &pergroup[aggno + (currentSet * (aggstate->numaggs))]; pergroupstate = &pergroup[transno + (currentSet * (aggstate->numtrans))];
if (peraggstate->numSortCols > 0) if (pertrans->numSortCols > 0)
{ {
Assert(((Agg *) aggstate->ss.ps.plan)->aggstrategy != AGG_HASHED); Assert(((Agg *) aggstate->ss.ps.plan)->aggstrategy != AGG_HASHED);
if (peraggstate->numInputs == 1) if (pertrans->numInputs == 1)
process_ordered_aggregate_single(aggstate, process_ordered_aggregate_single(aggstate,
peraggstate, pertrans,
pergroupstate); pergroupstate);
else else
process_ordered_aggregate_multi(aggstate, process_ordered_aggregate_multi(aggstate,
peraggstate, pertrans,
pergroupstate); pergroupstate);
} }
finalize_aggregate(aggstate, peraggstate, pergroupstate, finalize_aggregate(aggstate, peragg, pergroupstate,
&aggvalues[aggno], &aggnulls[aggno]); &aggvalues[aggno], &aggnulls[aggno]);
} }
} }
...@@ -1428,7 +1498,7 @@ lookup_hash_entry(AggState *aggstate, TupleTableSlot *inputslot) ...@@ -1428,7 +1498,7 @@ lookup_hash_entry(AggState *aggstate, TupleTableSlot *inputslot)
if (isnew) if (isnew)
{ {
/* initialize aggregates for new tuple group */ /* initialize aggregates for new tuple group */
initialize_aggregates(aggstate, aggstate->peragg, entry->pergroup, 0); initialize_aggregates(aggstate, entry->pergroup, 0);
} }
return entry; return entry;
...@@ -1716,7 +1786,7 @@ agg_retrieve_direct(AggState *aggstate) ...@@ -1716,7 +1786,7 @@ agg_retrieve_direct(AggState *aggstate)
/* /*
* Initialize working state for a new input tuple group. * Initialize working state for a new input tuple group.
*/ */
initialize_aggregates(aggstate, peragg, pergroup, numReset); initialize_aggregates(aggstate, pergroup, numReset);
if (aggstate->grp_firstTuple != NULL) if (aggstate->grp_firstTuple != NULL)
{ {
...@@ -1945,17 +2015,18 @@ AggState * ...@@ -1945,17 +2015,18 @@ AggState *
ExecInitAgg(Agg *node, EState *estate, int eflags) ExecInitAgg(Agg *node, EState *estate, int eflags)
{ {
AggState *aggstate; AggState *aggstate;
AggStatePerAgg peragg; AggStatePerAgg peraggs;
AggStatePerTrans pertransstates;
Plan *outerPlan; Plan *outerPlan;
ExprContext *econtext; ExprContext *econtext;
int numaggs, int numaggs,
transno,
aggno; aggno;
int phase; int phase;
ListCell *l; ListCell *l;
Bitmapset *all_grouped_cols = NULL; Bitmapset *all_grouped_cols = NULL;
int numGroupingSets = 1; int numGroupingSets = 1;
int numPhases; int numPhases;
int currentsortno = 0;
int i = 0; int i = 0;
int j = 0; int j = 0;
...@@ -1971,12 +2042,14 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) ...@@ -1971,12 +2042,14 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
aggstate->aggs = NIL; aggstate->aggs = NIL;
aggstate->numaggs = 0; aggstate->numaggs = 0;
aggstate->numtrans = 0;
aggstate->maxsets = 0; aggstate->maxsets = 0;
aggstate->hashfunctions = NULL; aggstate->hashfunctions = NULL;
aggstate->projected_set = -1; aggstate->projected_set = -1;
aggstate->current_set = 0; aggstate->current_set = 0;
aggstate->peragg = NULL; aggstate->peragg = NULL;
aggstate->curperagg = NULL; aggstate->pertrans = NULL;
aggstate->curpertrans = NULL;
aggstate->agg_done = false; aggstate->agg_done = false;
aggstate->input_done = false; aggstate->input_done = false;
aggstate->pergroup = NULL; aggstate->pergroup = NULL;
...@@ -2209,8 +2282,11 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) ...@@ -2209,8 +2282,11 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
econtext->ecxt_aggvalues = (Datum *) palloc0(sizeof(Datum) * numaggs); econtext->ecxt_aggvalues = (Datum *) palloc0(sizeof(Datum) * numaggs);
econtext->ecxt_aggnulls = (bool *) palloc0(sizeof(bool) * numaggs); econtext->ecxt_aggnulls = (bool *) palloc0(sizeof(bool) * numaggs);
peragg = (AggStatePerAgg) palloc0(sizeof(AggStatePerAggData) * numaggs); peraggs = (AggStatePerAgg) palloc0(sizeof(AggStatePerAggData) * numaggs);
aggstate->peragg = peragg; pertransstates = (AggStatePerTrans) palloc0(sizeof(AggStatePerTransData) * numaggs);
aggstate->peragg = peraggs;
aggstate->pertrans = pertransstates;
if (node->aggstrategy == AGG_HASHED) if (node->aggstrategy == AGG_HASHED)
{ {
...@@ -2230,71 +2306,86 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) ...@@ -2230,71 +2306,86 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
aggstate->pergroup = pergroup; aggstate->pergroup = pergroup;
} }
/* /* -----------------
* Perform lookups of aggregate function info, and initialize the * Perform lookups of aggregate function info, and initialize the
* unchanging fields of the per-agg data. We also detect duplicate * unchanging fields of the per-agg and per-trans data.
* aggregates (for example, "SELECT sum(x) ... HAVING sum(x) > 0"). When *
* duplicates are detected, we only make an AggStatePerAgg struct for the * We try to optimize by detecting duplicate aggregate functions so that
* first one. The clones are simply pointed at the same result entry by * their state and final values are re-used, rather than needlessly being
* giving them duplicate aggno values. * re-calculated independently. We also detect aggregates that are not
* the same, but which can share the same transition state.
*
* Scenarios:
*
* 1. An aggregate function appears more than once in query:
*
* SELECT SUM(x) FROM ... HAVING SUM(x) > 0
*
* Since the aggregates are identical, we only need to calculate
* the value once. Both aggregates will share the same 'aggno'
* value.
*
* 2. Two different aggregate functions appear in the query, but the
* aggregates have the same transition function and initial value, but
* different final function:
*
* SELECT SUM(x), AVG(x) FROM ...
*
* In this case we must create a new peragg for the varying aggregate,
* and need to call the final functions separately, but can share the
* same transition state.
*
* For either of these optimizations to be valid, the aggregate's
* arguments must be the same, including any modifiers such as ORDER BY,
* DISTINCT and FILTER, and they mustn't contain any volatile functions.
* -----------------
*/ */
aggno = -1; aggno = -1;
transno = -1;
foreach(l, aggstate->aggs) foreach(l, aggstate->aggs)
{ {
AggrefExprState *aggrefstate = (AggrefExprState *) lfirst(l); AggrefExprState *aggrefstate = (AggrefExprState *) lfirst(l);
Aggref *aggref = (Aggref *) aggrefstate->xprstate.expr; Aggref *aggref = (Aggref *) aggrefstate->xprstate.expr;
AggStatePerAgg peraggstate; AggStatePerAgg peragg;
AggStatePerTrans pertrans;
int existing_aggno;
int existing_transno;
List *same_input_transnos;
Oid inputTypes[FUNC_MAX_ARGS]; Oid inputTypes[FUNC_MAX_ARGS];
int numArguments; int numArguments;
int numDirectArgs; int numDirectArgs;
int numInputs;
int numSortCols;
int numDistinctCols;
List *sortlist;
HeapTuple aggTuple; HeapTuple aggTuple;
Form_pg_aggregate aggform; Form_pg_aggregate aggform;
Oid aggtranstype;
AclResult aclresult; AclResult aclresult;
Oid transfn_oid, Oid transfn_oid,
finalfn_oid; finalfn_oid;
Expr *transfnexpr, Expr *finalfnexpr;
*finalfnexpr; Oid aggtranstype;
Datum textInitVal; Datum textInitVal;
int i; Datum initValue;
ListCell *lc; bool initValueIsNull;
/* Planner should have assigned aggregate to correct level */ /* Planner should have assigned aggregate to correct level */
Assert(aggref->agglevelsup == 0); Assert(aggref->agglevelsup == 0);
/* Look for a previous duplicate aggregate */ /* 1. Check for already processed aggs which can be re-used */
for (i = 0; i <= aggno; i++) existing_aggno = find_compatible_peragg(aggref, aggstate, aggno,
{ &same_input_transnos);
if (equal(aggref, peragg[i].aggref) && if (existing_aggno != -1)
!contain_volatile_functions((Node *) aggref))
break;
}
if (i <= aggno)
{ {
/* Found a match to an existing entry, so just mark it */ /*
aggrefstate->aggno = i; * Existing compatible agg found. so just point the Aggref to the
* same per-agg struct.
*/
aggrefstate->aggno = existing_aggno;
continue; continue;
} }
/* Nope, so assign a new PerAgg record */
peraggstate = &peragg[++aggno];
/* Mark Aggref state node with assigned index in the result array */ /* Mark Aggref state node with assigned index in the result array */
peragg = &peraggs[++aggno];
peragg->aggref = aggref;
aggrefstate->aggno = aggno; aggrefstate->aggno = aggno;
/* Begin filling in the peraggstate data */
peraggstate->aggrefstate = aggrefstate;
peraggstate->aggref = aggref;
peraggstate->sortstates = (Tuplesortstate **)
palloc0(sizeof(Tuplesortstate *) * numGroupingSets);
for (currentsortno = 0; currentsortno < numGroupingSets; currentsortno++)
peraggstate->sortstates[currentsortno] = NULL;
/* Fetch the pg_aggregate row */ /* Fetch the pg_aggregate row */
aggTuple = SearchSysCache1(AGGFNOID, aggTuple = SearchSysCache1(AGGFNOID,
ObjectIdGetDatum(aggref->aggfnoid)); ObjectIdGetDatum(aggref->aggfnoid));
...@@ -2311,8 +2402,8 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) ...@@ -2311,8 +2402,8 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
get_func_name(aggref->aggfnoid)); get_func_name(aggref->aggfnoid));
InvokeFunctionExecuteHook(aggref->aggfnoid); InvokeFunctionExecuteHook(aggref->aggfnoid);
peraggstate->transfn_oid = transfn_oid = aggform->aggtransfn; transfn_oid = aggform->aggtransfn;
peraggstate->finalfn_oid = finalfn_oid = aggform->aggfinalfn; peragg->finalfn_oid = finalfn_oid = aggform->aggfinalfn;
/* Check that aggregate owner has permission to call component fns */ /* Check that aggregate owner has permission to call component fns */
{ {
...@@ -2350,74 +2441,43 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) ...@@ -2350,74 +2441,43 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
* agg accepts ANY or a polymorphic type. * agg accepts ANY or a polymorphic type.
*/ */
numArguments = get_aggregate_argtypes(aggref, inputTypes); numArguments = get_aggregate_argtypes(aggref, inputTypes);
peraggstate->numArguments = numArguments;
/* Count the "direct" arguments, if any */ /* Count the "direct" arguments, if any */
numDirectArgs = list_length(aggref->aggdirectargs); numDirectArgs = list_length(aggref->aggdirectargs);
/* Count the number of aggregated input columns */
numInputs = list_length(aggref->args);
peraggstate->numInputs = numInputs;
/* Detect how many arguments to pass to the transfn */
if (AGGKIND_IS_ORDERED_SET(aggref->aggkind))
peraggstate->numTransInputs = numInputs;
else
peraggstate->numTransInputs = numArguments;
/* Detect how many arguments to pass to the finalfn */
if (aggform->aggfinalextra)
peraggstate->numFinalArgs = numArguments + 1;
else
peraggstate->numFinalArgs = numDirectArgs + 1;
/* resolve actual type of transition state, if polymorphic */ /* resolve actual type of transition state, if polymorphic */
aggtranstype = resolve_aggregate_transtype(aggref->aggfnoid, aggtranstype = resolve_aggregate_transtype(aggref->aggfnoid,
aggform->aggtranstype, aggform->aggtranstype,
inputTypes, inputTypes,
numArguments); numArguments);
/* build expression trees using actual argument & result types */ /* Detect how many arguments to pass to the finalfn */
build_aggregate_fnexprs(inputTypes, if (aggform->aggfinalextra)
numArguments, peragg->numFinalArgs = numArguments + 1;
numDirectArgs, else
peraggstate->numFinalArgs, peragg->numFinalArgs = numDirectArgs + 1;
aggref->aggvariadic,
/*
* build expression trees using actual argument & result types for the
* finalfn, if it exists
*/
if (OidIsValid(finalfn_oid))
{
build_aggregate_finalfn_expr(inputTypes,
peragg->numFinalArgs,
aggtranstype, aggtranstype,
aggref->aggtype, aggref->aggtype,
aggref->inputcollid, aggref->inputcollid,
transfn_oid,
InvalidOid, /* invtrans is not needed here */
finalfn_oid, finalfn_oid,
&transfnexpr,
NULL,
&finalfnexpr); &finalfnexpr);
fmgr_info(finalfn_oid, &peragg->finalfn);
/* set up infrastructure for calling the transfn and finalfn */ fmgr_info_set_expr((Node *) finalfnexpr, &peragg->finalfn);
fmgr_info(transfn_oid, &peraggstate->transfn);
fmgr_info_set_expr((Node *) transfnexpr, &peraggstate->transfn);
if (OidIsValid(finalfn_oid))
{
fmgr_info(finalfn_oid, &peraggstate->finalfn);
fmgr_info_set_expr((Node *) finalfnexpr, &peraggstate->finalfn);
} }
peraggstate->aggCollation = aggref->inputcollid; /* get info about the result type's datatype */
InitFunctionCallInfoData(peraggstate->transfn_fcinfo,
&peraggstate->transfn,
peraggstate->numTransInputs + 1,
peraggstate->aggCollation,
(void *) aggstate, NULL);
/* get info about relevant datatypes */
get_typlenbyval(aggref->aggtype, get_typlenbyval(aggref->aggtype,
&peraggstate->resulttypeLen, &peragg->resulttypeLen,
&peraggstate->resulttypeByVal); &peragg->resulttypeByVal);
get_typlenbyval(aggtranstype,
&peraggstate->transtypeLen,
&peraggstate->transtypeByVal);
/* /*
* initval is potentially null, so don't try to access it as a struct * initval is potentially null, so don't try to access it as a struct
...@@ -2425,46 +2485,182 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) ...@@ -2425,46 +2485,182 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
*/ */
textInitVal = SysCacheGetAttr(AGGFNOID, aggTuple, textInitVal = SysCacheGetAttr(AGGFNOID, aggTuple,
Anum_pg_aggregate_agginitval, Anum_pg_aggregate_agginitval,
&peraggstate->initValueIsNull); &initValueIsNull);
if (initValueIsNull)
initValue = (Datum) 0;
else
initValue = GetAggInitVal(textInitVal, aggtranstype);
/*
* 2. Build working state for invoking the transition function, or
* look up previously initialized working state, if we can share it.
*
* find_compatible_peragg() already collected a list of per-Trans's
* with the same inputs. Check if any of them have the same transition
* function and initial value.
*/
existing_transno = find_compatible_pertrans(aggstate, aggref,
transfn_oid, aggtranstype,
initValue, initValueIsNull,
same_input_transnos);
if (existing_transno != -1)
{
/*
* Existing compatible trans found, so just point the 'peragg' to
* the same per-trans struct.
*/
pertrans = &pertransstates[existing_transno];
peragg->transno = existing_transno;
}
else
{
pertrans = &pertransstates[++transno];
build_pertrans_for_aggref(pertrans, aggstate, estate,
aggref, transfn_oid, aggtranstype,
initValue, initValueIsNull,
inputTypes, numArguments);
peragg->transno = transno;
}
ReleaseSysCache(aggTuple);
}
/*
* Update numaggs to match the number of unique aggregates found. Also set
* numtrans to the number of unique transition states found.
*/
aggstate->numaggs = aggno + 1;
aggstate->numtrans = transno + 1;
return aggstate;
}
/*
* Build the state needed to calculate a state value for an aggregate.
*
* This initializes all the fields in 'pertrans'. 'aggref' is the aggregate
* to initialize the state for. 'aggtransfn', 'aggtranstype', and the rest
* of the arguments could be calculated from 'aggref', but the caller has
* calculated them already, so might as well pass them.
*/
static void
build_pertrans_for_aggref(AggStatePerTrans pertrans,
AggState *aggstate, EState *estate,
Aggref *aggref,
Oid aggtransfn, Oid aggtranstype,
Datum initValue, bool initValueIsNull,
Oid *inputTypes, int numArguments)
{
int numGroupingSets = Max(aggstate->maxsets, 1);
Expr *transfnexpr;
ListCell *lc;
int numInputs;
int numDirectArgs;
List *sortlist;
int numSortCols;
int numDistinctCols;
int naggs;
int i;
/* Begin filling in the pertrans data */
pertrans->aggref = aggref;
pertrans->aggCollation = aggref->inputcollid;
pertrans->transfn_oid = aggtransfn;
pertrans->initValue = initValue;
pertrans->initValueIsNull = initValueIsNull;
/* Count the "direct" arguments, if any */
numDirectArgs = list_length(aggref->aggdirectargs);
if (peraggstate->initValueIsNull) /* Count the number of aggregated input columns */
peraggstate->initValue = (Datum) 0; pertrans->numInputs = numInputs = list_length(aggref->args);
pertrans->aggtranstype = aggtranstype;
/* Detect how many arguments to pass to the transfn */
if (AGGKIND_IS_ORDERED_SET(aggref->aggkind))
pertrans->numTransInputs = numInputs;
else else
peraggstate->initValue = GetAggInitVal(textInitVal, pertrans->numTransInputs = numArguments;
aggtranstype);
/*
* Set up infrastructure for calling the transfn
*/
build_aggregate_transfn_expr(inputTypes,
numArguments,
numDirectArgs,
aggref->aggvariadic,
aggtranstype,
aggref->inputcollid,
aggtransfn,
InvalidOid, /* invtrans is not needed here */
&transfnexpr,
NULL);
fmgr_info(aggtransfn, &pertrans->transfn);
fmgr_info_set_expr((Node *) transfnexpr, &pertrans->transfn);
InitFunctionCallInfoData(pertrans->transfn_fcinfo,
&pertrans->transfn,
pertrans->numTransInputs + 1,
pertrans->aggCollation,
(void *) aggstate, NULL);
/* /*
* If the transfn is strict and the initval is NULL, make sure input * If the transfn is strict and the initval is NULL, make sure input type
* type and transtype are the same (or at least binary-compatible), so * and transtype are the same (or at least binary-compatible), so that
* that it's OK to use the first aggregated input value as the initial * it's OK to use the first aggregated input value as the initial
* transValue. This should have been checked at agg definition time, * transValue. This should have been checked at agg definition time, but
* but we must check again in case the transfn's strictness property * we must check again in case the transfn's strictness property has been
* has been changed. * changed.
*/ */
if (peraggstate->transfn.fn_strict && peraggstate->initValueIsNull) if (pertrans->transfn.fn_strict && pertrans->initValueIsNull)
{ {
if (numArguments <= numDirectArgs || if (numArguments <= numDirectArgs ||
!IsBinaryCoercible(inputTypes[numDirectArgs], aggtranstype)) !IsBinaryCoercible(inputTypes[numDirectArgs],
aggtranstype))
ereport(ERROR, ereport(ERROR,
(errcode(ERRCODE_INVALID_FUNCTION_DEFINITION), (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
errmsg("aggregate %u needs to have compatible input type and transition type", errmsg("aggregate %u needs to have compatible input type and transition type",
aggref->aggfnoid))); aggref->aggfnoid)));
} }
/* get info about the state value's datatype */
get_typlenbyval(aggtranstype,
&pertrans->transtypeLen,
&pertrans->transtypeByVal);
/* /*
* Get a tupledesc corresponding to the aggregated inputs (including * Get a tupledesc corresponding to the aggregated inputs (including sort
* sort expressions) of the agg. * expressions) of the agg.
*/ */
peraggstate->evaldesc = ExecTypeFromTL(aggref->args, false); pertrans->evaldesc = ExecTypeFromTL(aggref->args, false);
/* Create slot we're going to do argument evaluation in */ /* Create slot we're going to do argument evaluation in */
peraggstate->evalslot = ExecInitExtraTupleSlot(estate); pertrans->evalslot = ExecInitExtraTupleSlot(estate);
ExecSetSlotDescriptor(peraggstate->evalslot, peraggstate->evaldesc); ExecSetSlotDescriptor(pertrans->evalslot, pertrans->evaldesc);
/* Initialize the input and FILTER expressions */
naggs = aggstate->numaggs;
pertrans->aggfilter = ExecInitExpr(aggref->aggfilter,
(PlanState *) aggstate);
pertrans->aggdirectargs = (List *) ExecInitExpr((Expr *) aggref->aggdirectargs,
(PlanState *) aggstate);
pertrans->args = (List *) ExecInitExpr((Expr *) aggref->args,
(PlanState *) aggstate);
/*
* Complain if the aggregate's arguments contain any aggregates; nested
* agg functions are semantically nonsensical. (This should have been
* caught earlier, but we defend against it here anyway.)
*/
if (naggs != aggstate->numaggs)
ereport(ERROR,
(errcode(ERRCODE_GROUPING_ERROR),
errmsg("aggregate function calls cannot be nested")));
/* Set up projection info for evaluation */ /* Set up projection info for evaluation */
peraggstate->evalproj = ExecBuildProjectionInfo(aggrefstate->args, pertrans->evalproj = ExecBuildProjectionInfo(pertrans->args,
aggstate->tmpcontext, aggstate->tmpcontext,
peraggstate->evalslot, pertrans->evalslot,
NULL); NULL);
/* /*
...@@ -2473,8 +2669,8 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) ...@@ -2473,8 +2669,8 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
* stick them into arrays. We ignore ORDER BY for an ordered-set agg, * stick them into arrays. We ignore ORDER BY for an ordered-set agg,
* however; the agg's transfn and finalfn are responsible for that. * however; the agg's transfn and finalfn are responsible for that.
* *
* Note that by construction, if there is a DISTINCT clause then the * Note that by construction, if there is a DISTINCT clause then the ORDER
* ORDER BY clause is a prefix of it (see transformDistinctClause). * BY clause is a prefix of it (see transformDistinctClause).
*/ */
if (AGGKIND_IS_ORDERED_SET(aggref->aggkind)) if (AGGKIND_IS_ORDERED_SET(aggref->aggkind))
{ {
...@@ -2494,8 +2690,8 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) ...@@ -2494,8 +2690,8 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
numDistinctCols = 0; numDistinctCols = 0;
} }
peraggstate->numSortCols = numSortCols; pertrans->numSortCols = numSortCols;
peraggstate->numDistinctCols = numDistinctCols; pertrans->numDistinctCols = numDistinctCols;
if (numSortCols > 0) if (numSortCols > 0)
{ {
...@@ -2503,47 +2699,46 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) ...@@ -2503,47 +2699,46 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
* We don't implement DISTINCT or ORDER BY aggs in the HASHED case * We don't implement DISTINCT or ORDER BY aggs in the HASHED case
* (yet) * (yet)
*/ */
Assert(node->aggstrategy != AGG_HASHED); Assert(((Agg *) aggstate->ss.ps.plan)->aggstrategy != AGG_HASHED);
/* If we have only one input, we need its len/byval info. */ /* If we have only one input, we need its len/byval info. */
if (numInputs == 1) if (numInputs == 1)
{ {
get_typlenbyval(inputTypes[numDirectArgs], get_typlenbyval(inputTypes[numDirectArgs],
&peraggstate->inputtypeLen, &pertrans->inputtypeLen,
&peraggstate->inputtypeByVal); &pertrans->inputtypeByVal);
} }
else if (numDistinctCols > 0) else if (numDistinctCols > 0)
{ {
/* we will need an extra slot to store prior values */ /* we will need an extra slot to store prior values */
peraggstate->uniqslot = ExecInitExtraTupleSlot(estate); pertrans->uniqslot = ExecInitExtraTupleSlot(estate);
ExecSetSlotDescriptor(peraggstate->uniqslot, ExecSetSlotDescriptor(pertrans->uniqslot,
peraggstate->evaldesc); pertrans->evaldesc);
} }
/* Extract the sort information for use later */ /* Extract the sort information for use later */
peraggstate->sortColIdx = pertrans->sortColIdx =
(AttrNumber *) palloc(numSortCols * sizeof(AttrNumber)); (AttrNumber *) palloc(numSortCols * sizeof(AttrNumber));
peraggstate->sortOperators = pertrans->sortOperators =
(Oid *) palloc(numSortCols * sizeof(Oid)); (Oid *) palloc(numSortCols * sizeof(Oid));
peraggstate->sortCollations = pertrans->sortCollations =
(Oid *) palloc(numSortCols * sizeof(Oid)); (Oid *) palloc(numSortCols * sizeof(Oid));
peraggstate->sortNullsFirst = pertrans->sortNullsFirst =
(bool *) palloc(numSortCols * sizeof(bool)); (bool *) palloc(numSortCols * sizeof(bool));
i = 0; i = 0;
foreach(lc, sortlist) foreach(lc, sortlist)
{ {
SortGroupClause *sortcl = (SortGroupClause *) lfirst(lc); SortGroupClause *sortcl = (SortGroupClause *) lfirst(lc);
TargetEntry *tle = get_sortgroupclause_tle(sortcl, TargetEntry *tle = get_sortgroupclause_tle(sortcl, aggref->args);
aggref->args);
/* the parser should have made sure of this */ /* the parser should have made sure of this */
Assert(OidIsValid(sortcl->sortop)); Assert(OidIsValid(sortcl->sortop));
peraggstate->sortColIdx[i] = tle->resno; pertrans->sortColIdx[i] = tle->resno;
peraggstate->sortOperators[i] = sortcl->sortop; pertrans->sortOperators[i] = sortcl->sortop;
peraggstate->sortCollations[i] = exprCollation((Node *) tle->expr); pertrans->sortCollations[i] = exprCollation((Node *) tle->expr);
peraggstate->sortNullsFirst[i] = sortcl->nulls_first; pertrans->sortNullsFirst[i] = sortcl->nulls_first;
i++; i++;
} }
Assert(i == numSortCols); Assert(i == numSortCols);
...@@ -2557,7 +2752,7 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) ...@@ -2557,7 +2752,7 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
* We need the equal function for each DISTINCT comparison we will * We need the equal function for each DISTINCT comparison we will
* make. * make.
*/ */
peraggstate->equalfns = pertrans->equalfns =
(FmgrInfo *) palloc(numDistinctCols * sizeof(FmgrInfo)); (FmgrInfo *) palloc(numDistinctCols * sizeof(FmgrInfo));
i = 0; i = 0;
...@@ -2565,21 +2760,17 @@ ExecInitAgg(Agg *node, EState *estate, int eflags) ...@@ -2565,21 +2760,17 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
{ {
SortGroupClause *sortcl = (SortGroupClause *) lfirst(lc); SortGroupClause *sortcl = (SortGroupClause *) lfirst(lc);
fmgr_info(get_opcode(sortcl->eqop), &peraggstate->equalfns[i]); fmgr_info(get_opcode(sortcl->eqop), &pertrans->equalfns[i]);
i++; i++;
} }
Assert(i == numDistinctCols); Assert(i == numDistinctCols);
} }
ReleaseSysCache(aggTuple); pertrans->sortstates = (Tuplesortstate **)
} palloc0(sizeof(Tuplesortstate *) * numGroupingSets);
/* Update numaggs to match number of unique aggregates found */
aggstate->numaggs = aggno + 1;
return aggstate;
} }
static Datum static Datum
GetAggInitVal(Datum textInitVal, Oid transtype) GetAggInitVal(Datum textInitVal, Oid transtype)
{ {
...@@ -2596,11 +2787,130 @@ GetAggInitVal(Datum textInitVal, Oid transtype) ...@@ -2596,11 +2787,130 @@ GetAggInitVal(Datum textInitVal, Oid transtype)
return initVal; return initVal;
} }
/*
 * find_compatible_peragg - look for an already-initialized per-Agg struct
 * that the new Aggref can simply reuse.
 *
 * newagg is the Aggref being set up; aggstate->peragg[0 .. lastaggno] are
 * the per-Agg entries examined.  Returns the index of an entry whose Aggref
 * has the same inputs and the same aggregate function, result type and
 * result collation as newagg, or -1 if there is none.
 *
 * As a side-effect, *same_input_transnos is set to a list of the transno
 * values of the examined entries whose inputs match newagg but which are
 * not an exact match.  When -1 is returned, the caller can hand that list to
 * find_compatible_pertrans to see whether at least the transition state of
 * one of them can be shared.  The list is always NIL when an exact match is
 * returned, and also when newagg contains volatile functions.
 */
static int
find_compatible_peragg(Aggref *newagg, AggState *aggstate,
					   int lastaggno, List **same_input_transnos)
{
	AggStatePerAgg known;
	int			idx;

	/* start with an empty result list, whatever happens below */
	*same_input_transnos = NIL;

	/* we mustn't reuse the aggref if it contains volatile function calls */
	if (contain_volatile_functions((Node *) newagg))
		return -1;

	known = aggstate->peragg;

	/*
	 * Walk over the aggregates seen so far.  An entry with identical inputs
	 * and the identical aggregate function is returned right away.  An entry
	 * that merely has identical inputs is remembered, since the caller may
	 * still be able to share its transition state.
	 */
	for (idx = 0; idx <= lastaggno; idx++)
	{
		AggStatePerAgg candidate = &known[idx];
		Aggref	   *seen = candidate->aggref;

		/* every one of these must agree, otherwise the inputs differ */
		if (!(newagg->inputcollid == seen->inputcollid &&
			  newagg->aggstar == seen->aggstar &&
			  newagg->aggvariadic == seen->aggvariadic &&
			  newagg->aggkind == seen->aggkind &&
			  equal(newagg->aggdirectargs, seen->aggdirectargs) &&
			  equal(newagg->args, seen->args) &&
			  equal(newagg->aggorder, seen->aggorder) &&
			  equal(newagg->aggdistinct, seen->aggdistinct) &&
			  equal(newagg->aggfilter, seen->aggfilter)))
			continue;

		if (newagg->aggfnoid != seen->aggfnoid ||
			newagg->aggtype != seen->aggtype ||
			newagg->aggcollid != seen->aggcollid)
		{
			/*
			 * Same inputs but a different aggregate.  Remember its transno
			 * so the caller can try to reuse its per-trans state.
			 */
			*same_input_transnos = lappend_int(*same_input_transnos,
											   candidate->transno);
			continue;
		}

		/* exact match: the collected list is of no further use */
		list_free(*same_input_transnos);
		*same_input_transnos = NIL;
		return idx;
	}

	return -1;
}
/*
 * find_compatible_pertrans - look for an already-initialized per-Trans
 * struct whose transition state the new aggregate can share.
 *
 * transnos lists the indexes into aggstate->pertrans to consider; the caller
 * is expected to have verified already that their inputs match.  An entry
 * qualifies if its transition function and transition type equal
 * aggtransfn/aggtranstype and its initial condition equals
 * initValue/initValueIsNull.
 *
 * Returns the transno of the first qualifying entry, or -1 if none does.
 * newagg is not consulted here.
 */
static int
find_compatible_pertrans(AggState *aggstate, Aggref *newagg,
						 Oid aggtransfn, Oid aggtranstype,
						 Datum initValue, bool initValueIsNull,
						 List *transnos)
{
	ListCell   *cell;

	foreach(cell, transnos)
	{
		int			candidate = lfirst_int(cell);
		AggStatePerTrans shared = &aggstate->pertrans[candidate];

		/* a different transition function rules out sharing */
		if (shared->transfn_oid != aggtransfn)
			continue;

		/* so does a different transition state type */
		if (shared->aggtranstype != aggtranstype)
			continue;

		/* Finally, the initial conditions have to agree as well. */
		if (initValueIsNull)
		{
			/* both start out NULL */
			if (shared->initValueIsNull)
				return candidate;
		}
		else if (!shared->initValueIsNull &&
				 datumIsEqual(initValue, shared->initValue,
							  shared->transtypeByVal, shared->transtypeLen))
		{
			/* both start out non-NULL with equal datums */
			return candidate;
		}
	}

	return -1;
}
void void
ExecEndAgg(AggState *node) ExecEndAgg(AggState *node)
{ {
PlanState *outerPlan; PlanState *outerPlan;
int aggno; int transno;
int numGroupingSets = Max(node->maxsets, 1); int numGroupingSets = Max(node->maxsets, 1);
int setno; int setno;
...@@ -2611,14 +2921,14 @@ ExecEndAgg(AggState *node) ...@@ -2611,14 +2921,14 @@ ExecEndAgg(AggState *node)
if (node->sort_out) if (node->sort_out)
tuplesort_end(node->sort_out); tuplesort_end(node->sort_out);
for (aggno = 0; aggno < node->numaggs; aggno++) for (transno = 0; transno < node->numtrans; transno++)
{ {
AggStatePerAgg peraggstate = &node->peragg[aggno]; AggStatePerTrans pertrans = &node->pertrans[transno];
for (setno = 0; setno < numGroupingSets; setno++) for (setno = 0; setno < numGroupingSets; setno++)
{ {
if (peraggstate->sortstates[setno]) if (pertrans->sortstates[setno])
tuplesort_end(peraggstate->sortstates[setno]); tuplesort_end(pertrans->sortstates[setno]);
} }
} }
...@@ -2646,7 +2956,7 @@ ExecReScanAgg(AggState *node) ...@@ -2646,7 +2956,7 @@ ExecReScanAgg(AggState *node)
ExprContext *econtext = node->ss.ps.ps_ExprContext; ExprContext *econtext = node->ss.ps.ps_ExprContext;
PlanState *outerPlan = outerPlanState(node); PlanState *outerPlan = outerPlanState(node);
Agg *aggnode = (Agg *) node->ss.ps.plan; Agg *aggnode = (Agg *) node->ss.ps.plan;
int aggno; int transno;
int numGroupingSets = Max(node->maxsets, 1); int numGroupingSets = Max(node->maxsets, 1);
int setno; int setno;
...@@ -2678,16 +2988,16 @@ ExecReScanAgg(AggState *node) ...@@ -2678,16 +2988,16 @@ ExecReScanAgg(AggState *node)
} }
/* Make sure we have closed any open tuplesorts */ /* Make sure we have closed any open tuplesorts */
for (aggno = 0; aggno < node->numaggs; aggno++) for (transno = 0; transno < node->numtrans; transno++)
{ {
for (setno = 0; setno < numGroupingSets; setno++) for (setno = 0; setno < numGroupingSets; setno++)
{ {
AggStatePerAgg peraggstate = &node->peragg[aggno]; AggStatePerTrans pertrans = &node->pertrans[transno];
if (peraggstate->sortstates[setno]) if (pertrans->sortstates[setno])
{ {
tuplesort_end(peraggstate->sortstates[setno]); tuplesort_end(pertrans->sortstates[setno]);
peraggstate->sortstates[setno] = NULL; pertrans->sortstates[setno] = NULL;
} }
} }
} }
...@@ -2811,10 +3121,12 @@ AggGetAggref(FunctionCallInfo fcinfo) ...@@ -2811,10 +3121,12 @@ AggGetAggref(FunctionCallInfo fcinfo)
{ {
if (fcinfo->context && IsA(fcinfo->context, AggState)) if (fcinfo->context && IsA(fcinfo->context, AggState))
{ {
AggStatePerAgg curperagg = ((AggState *) fcinfo->context)->curperagg; AggStatePerTrans curpertrans;
curpertrans = ((AggState *) fcinfo->context)->curpertrans;
if (curperagg) if (curpertrans)
return curperagg->aggref; return curpertrans->aggref;
} }
return NULL; return NULL;
} }
......
...@@ -2218,20 +2218,16 @@ initialize_peragg(WindowAggState *winstate, WindowFunc *wfunc, ...@@ -2218,20 +2218,16 @@ initialize_peragg(WindowAggState *winstate, WindowFunc *wfunc,
numArguments); numArguments);
/* build expression trees using actual argument & result types */ /* build expression trees using actual argument & result types */
build_aggregate_fnexprs(inputTypes, build_aggregate_transfn_expr(inputTypes,
numArguments, numArguments,
0, /* no ordered-set window functions yet */ 0, /* no ordered-set window functions yet */
peraggstate->numFinalArgs,
false, /* no variadic window functions yet */ false, /* no variadic window functions yet */
aggtranstype,
wfunc->wintype, wfunc->wintype,
wfunc->inputcollid, wfunc->inputcollid,
transfn_oid, transfn_oid,
invtransfn_oid, invtransfn_oid,
finalfn_oid,
&transfnexpr, &transfnexpr,
&invtransfnexpr, &invtransfnexpr);
&finalfnexpr);
/* set up infrastructure for calling the transfn(s) and finalfn */ /* set up infrastructure for calling the transfn(s) and finalfn */
fmgr_info(transfn_oid, &peraggstate->transfn); fmgr_info(transfn_oid, &peraggstate->transfn);
...@@ -2245,6 +2241,13 @@ initialize_peragg(WindowAggState *winstate, WindowFunc *wfunc, ...@@ -2245,6 +2241,13 @@ initialize_peragg(WindowAggState *winstate, WindowFunc *wfunc,
if (OidIsValid(finalfn_oid)) if (OidIsValid(finalfn_oid))
{ {
build_aggregate_finalfn_expr(inputTypes,
peraggstate->numFinalArgs,
aggtranstype,
wfunc->wintype,
wfunc->inputcollid,
finalfn_oid,
&finalfnexpr);
fmgr_info(finalfn_oid, &peraggstate->finalfn); fmgr_info(finalfn_oid, &peraggstate->finalfn);
fmgr_info_set_expr((Node *) finalfnexpr, &peraggstate->finalfn); fmgr_info_set_expr((Node *) finalfnexpr, &peraggstate->finalfn);
} }
......
...@@ -1829,44 +1829,40 @@ resolve_aggregate_transtype(Oid aggfuncid, ...@@ -1829,44 +1829,40 @@ resolve_aggregate_transtype(Oid aggfuncid,
} }
/* /*
* Create expression trees for the transition and final functions * Create an expression tree for the transition function of an aggregate.
* of an aggregate. These are needed so that polymorphic functions * This is needed so that polymorphic functions can be used within an
* can be used within an aggregate --- without the expression trees, * aggregate --- without the expression tree, such functions would not know
* such functions would not know the datatypes they are supposed to use. * the datatypes they are supposed to use. (The trees will never actually
* (The trees will never actually be executed, however, so we can skimp * be executed, however, so we can skimp a bit on correctness.)
* a bit on correctness.)
* *
* agg_input_types, agg_state_type, agg_result_type identify the input, * agg_input_types and agg_state_type identifies the input types of the
* transition, and result types of the aggregate. These should all be * aggregate. These should be resolved to actual types (ie, none should
* resolved to actual types (ie, none should ever be ANYELEMENT etc). * ever be ANYELEMENT etc).
* agg_input_collation is the aggregate function's input collation. * agg_input_collation is the aggregate function's input collation.
* *
* For an ordered-set aggregate, remember that agg_input_types describes * For an ordered-set aggregate, remember that agg_input_types describes
* the direct arguments followed by the aggregated arguments. * the direct arguments followed by the aggregated arguments.
* *
* transfn_oid, invtransfn_oid and finalfn_oid identify the funcs to be * transfn_oid and invtransfn_oid identify the funcs to be called; the
* called; the latter two may be InvalidOid. * latter may be InvalidOid, however if invtransfn_oid is set then
* transfn_oid must also be set.
* *
* Pointers to the constructed trees are returned into *transfnexpr, * Pointers to the constructed trees are returned into *transfnexpr,
* *invtransfnexpr and *finalfnexpr. If there is no invtransfn or finalfn, * *invtransfnexpr. If there is no invtransfn, the respective pointer is set
* the respective pointers are set to NULL. Since use of the invtransfn is * to NULL. Since use of the invtransfn is optional, NULL may be passed for
* optional, NULL may be passed for invtransfnexpr. * invtransfnexpr.
*/ */
void void
build_aggregate_fnexprs(Oid *agg_input_types, build_aggregate_transfn_expr(Oid *agg_input_types,
int agg_num_inputs, int agg_num_inputs,
int agg_num_direct_inputs, int agg_num_direct_inputs,
int num_finalfn_inputs,
bool agg_variadic, bool agg_variadic,
Oid agg_state_type, Oid agg_state_type,
Oid agg_result_type,
Oid agg_input_collation, Oid agg_input_collation,
Oid transfn_oid, Oid transfn_oid,
Oid invtransfn_oid, Oid invtransfn_oid,
Oid finalfn_oid,
Expr **transfnexpr, Expr **transfnexpr,
Expr **invtransfnexpr, Expr **invtransfnexpr)
Expr **finalfnexpr)
{ {
Param *argp; Param *argp;
List *args; List *args;
...@@ -1929,13 +1925,24 @@ build_aggregate_fnexprs(Oid *agg_input_types, ...@@ -1929,13 +1925,24 @@ build_aggregate_fnexprs(Oid *agg_input_types,
else else
*invtransfnexpr = NULL; *invtransfnexpr = NULL;
} }
}
/* see if we have a final function */ /*
if (!OidIsValid(finalfn_oid)) * Like build_aggregate_transfn_expr, but creates an expression tree for the
{ * final function of an aggregate, rather than the transition function.
*finalfnexpr = NULL; */
return; void
} build_aggregate_finalfn_expr(Oid *agg_input_types,
int num_finalfn_inputs,
Oid agg_state_type,
Oid agg_result_type,
Oid agg_input_collation,
Oid finalfn_oid,
Expr **finalfnexpr)
{
Param *argp;
List *args;
int i;
/* /*
* Build expr tree for final function * Build expr tree for final function
......
...@@ -609,9 +609,6 @@ typedef struct WholeRowVarExprState ...@@ -609,9 +609,6 @@ typedef struct WholeRowVarExprState
typedef struct AggrefExprState typedef struct AggrefExprState
{ {
ExprState xprstate; ExprState xprstate;
List *aggdirectargs; /* states of direct-argument expressions */
List *args; /* states of aggregated-argument expressions */
ExprState *aggfilter; /* state of FILTER expression, if any */
int aggno; /* ID number for agg within its plan node */ int aggno; /* ID number for agg within its plan node */
} AggrefExprState; } AggrefExprState;
...@@ -1825,6 +1822,7 @@ typedef struct GroupState ...@@ -1825,6 +1822,7 @@ typedef struct GroupState
*/ */
/* these structs are private in nodeAgg.c: */ /* these structs are private in nodeAgg.c: */
typedef struct AggStatePerAggData *AggStatePerAgg; typedef struct AggStatePerAggData *AggStatePerAgg;
typedef struct AggStatePerTransData *AggStatePerTrans;
typedef struct AggStatePerGroupData *AggStatePerGroup; typedef struct AggStatePerGroupData *AggStatePerGroup;
typedef struct AggStatePerPhaseData *AggStatePerPhase; typedef struct AggStatePerPhaseData *AggStatePerPhase;
...@@ -1833,14 +1831,16 @@ typedef struct AggState ...@@ -1833,14 +1831,16 @@ typedef struct AggState
ScanState ss; /* its first field is NodeTag */ ScanState ss; /* its first field is NodeTag */
List *aggs; /* all Aggref nodes in targetlist & quals */ List *aggs; /* all Aggref nodes in targetlist & quals */
int numaggs; /* length of list (could be zero!) */ int numaggs; /* length of list (could be zero!) */
int numtrans; /* number of pertrans items */
AggStatePerPhase phase; /* pointer to current phase data */ AggStatePerPhase phase; /* pointer to current phase data */
int numphases; /* number of phases */ int numphases; /* number of phases */
int current_phase; /* current phase number */ int current_phase; /* current phase number */
FmgrInfo *hashfunctions; /* per-grouping-field hash fns */ FmgrInfo *hashfunctions; /* per-grouping-field hash fns */
AggStatePerAgg peragg; /* per-Aggref information */ AggStatePerAgg peragg; /* per-Aggref information */
AggStatePerTrans pertrans; /* per-Trans state information */
ExprContext **aggcontexts; /* econtexts for long-lived data (per GS) */ ExprContext **aggcontexts; /* econtexts for long-lived data (per GS) */
ExprContext *tmpcontext; /* econtext for input expressions */ ExprContext *tmpcontext; /* econtext for input expressions */
AggStatePerAgg curperagg; /* identifies currently active aggregate */ AggStatePerTrans curpertrans; /* currently active trans state */
bool input_done; /* indicates end of input */ bool input_done; /* indicates end of input */
bool agg_done; /* indicates completion of Agg scan */ bool agg_done; /* indicates completion of Agg scan */
int projected_set; /* The last projected grouping set */ int projected_set; /* The last projected grouping set */
......
...@@ -35,19 +35,23 @@ extern Oid resolve_aggregate_transtype(Oid aggfuncid, ...@@ -35,19 +35,23 @@ extern Oid resolve_aggregate_transtype(Oid aggfuncid,
Oid *inputTypes, Oid *inputTypes,
int numArguments); int numArguments);
extern void build_aggregate_fnexprs(Oid *agg_input_types, extern void build_aggregate_transfn_expr(Oid *agg_input_types,
int agg_num_inputs, int agg_num_inputs,
int agg_num_direct_inputs, int agg_num_direct_inputs,
int num_finalfn_inputs,
bool agg_variadic, bool agg_variadic,
Oid agg_state_type, Oid agg_state_type,
Oid agg_result_type,
Oid agg_input_collation, Oid agg_input_collation,
Oid transfn_oid, Oid transfn_oid,
Oid invtransfn_oid, Oid invtransfn_oid,
Oid finalfn_oid,
Expr **transfnexpr, Expr **transfnexpr,
Expr **invtransfnexpr, Expr **invtransfnexpr);
extern void build_aggregate_finalfn_expr(Oid *agg_input_types,
int num_finalfn_inputs,
Oid agg_state_type,
Oid agg_result_type,
Oid agg_input_collation,
Oid finalfn_oid,
Expr **finalfnexpr); Expr **finalfnexpr);
#endif /* PARSE_AGG_H */ #endif /* PARSE_AGG_H */
...@@ -1580,3 +1580,207 @@ select least_agg(variadic array[q1,q2]) from int8_tbl; ...@@ -1580,3 +1580,207 @@ select least_agg(variadic array[q1,q2]) from int8_tbl;
-4567890123456789 -4567890123456789
(1 row) (1 row)
-- test aggregates with common transition functions share the same states
begin work;
create type avg_state as (total bigint, count bigint);
create or replace function avg_transfn(state avg_state, n int) returns avg_state as
$$
declare new_state avg_state;
begin
raise notice 'avg_transfn called with %', n;
if state is null then
if n is not null then
new_state.total := n;
new_state.count := 1;
return new_state;
end if;
return null;
elsif n is not null then
state.total := state.total + n;
state.count := state.count + 1;
return state;
end if;
return null;
end
$$ language plpgsql;
create function avg_finalfn(state avg_state) returns int4 as
$$
begin
if state is null then
return NULL;
else
return state.total / state.count;
end if;
end
$$ language plpgsql;
create function sum_finalfn(state avg_state) returns int4 as
$$
begin
if state is null then
return NULL;
else
return state.total;
end if;
end
$$ language plpgsql;
create aggregate my_avg(int4)
(
stype = avg_state,
sfunc = avg_transfn,
finalfunc = avg_finalfn
);
create aggregate my_sum(int4)
(
stype = avg_state,
sfunc = avg_transfn,
finalfunc = sum_finalfn
);
-- aggregate state should be shared as aggs are the same.
select my_avg(one),my_avg(one) from (values(1),(3)) t(one);
NOTICE: avg_transfn called with 1
NOTICE: avg_transfn called with 3
my_avg | my_avg
--------+--------
2 | 2
(1 row)
-- aggregate state should be shared as transfn is the same for both aggs.
select my_avg(one),my_sum(one) from (values(1),(3)) t(one);
NOTICE: avg_transfn called with 1
NOTICE: avg_transfn called with 3
my_avg | my_sum
--------+--------
2 | 4
(1 row)
-- shouldn't share states due to the distinctness not matching.
select my_avg(distinct one),my_sum(one) from (values(1),(3)) t(one);
NOTICE: avg_transfn called with 1
NOTICE: avg_transfn called with 3
NOTICE: avg_transfn called with 1
NOTICE: avg_transfn called with 3
my_avg | my_sum
--------+--------
2 | 4
(1 row)
-- shouldn't share states due to the filter clause not matching.
select my_avg(one) filter (where one > 1),my_sum(one) from (values(1),(3)) t(one);
NOTICE: avg_transfn called with 1
NOTICE: avg_transfn called with 3
NOTICE: avg_transfn called with 3
my_avg | my_sum
--------+--------
3 | 4
(1 row)
-- this should not share the state due to different input columns.
select my_avg(one),my_sum(two) from (values(1,2),(3,4)) t(one,two);
NOTICE: avg_transfn called with 2
NOTICE: avg_transfn called with 1
NOTICE: avg_transfn called with 4
NOTICE: avg_transfn called with 3
my_avg | my_sum
--------+--------
2 | 6
(1 row)
-- test that aggs with the same sfunc and initcond share the same agg state
create aggregate my_sum_init(int4)
(
stype = avg_state,
sfunc = avg_transfn,
finalfunc = sum_finalfn,
initcond = '(10,0)'
);
create aggregate my_avg_init(int4)
(
stype = avg_state,
sfunc = avg_transfn,
finalfunc = avg_finalfn,
initcond = '(10,0)'
);
create aggregate my_avg_init2(int4)
(
stype = avg_state,
sfunc = avg_transfn,
finalfunc = avg_finalfn,
initcond = '(4,0)'
);
-- state should be shared if INITCONDs are matching
select my_sum_init(one),my_avg_init(one) from (values(1),(3)) t(one);
NOTICE: avg_transfn called with 1
NOTICE: avg_transfn called with 3
my_sum_init | my_avg_init
-------------+-------------
14 | 7
(1 row)
-- Varying INITCONDs should cause the states not to be shared.
select my_sum_init(one),my_avg_init2(one) from (values(1),(3)) t(one);
NOTICE: avg_transfn called with 1
NOTICE: avg_transfn called with 1
NOTICE: avg_transfn called with 3
NOTICE: avg_transfn called with 3
my_sum_init | my_avg_init2
-------------+--------------
14 | 4
(1 row)
rollback;
-- test aggregate state sharing to ensure it works if one aggregate has a
-- finalfn and the other one has none.
begin work;
create or replace function sum_transfn(state int4, n int4) returns int4 as
$$
declare new_state int4;
begin
raise notice 'sum_transfn called with %', n;
if state is null then
if n is not null then
new_state := n;
return new_state;
end if;
return null;
elsif n is not null then
state := state + n;
return state;
end if;
return null;
end
$$ language plpgsql;
create function halfsum_finalfn(state int4) returns int4 as
$$
begin
if state is null then
return NULL;
else
return state / 2;
end if;
end
$$ language plpgsql;
create aggregate my_sum(int4)
(
stype = int4,
sfunc = sum_transfn
);
create aggregate my_half_sum(int4)
(
stype = int4,
sfunc = sum_transfn,
finalfunc = halfsum_finalfn
);
-- Agg state should be shared even though my_sum has no finalfn
select my_sum(one),my_half_sum(one) from (values(1),(2),(3),(4)) t(one);
NOTICE: sum_transfn called with 1
NOTICE: sum_transfn called with 2
NOTICE: sum_transfn called with 3
NOTICE: sum_transfn called with 4
my_sum | my_half_sum
--------+-------------
10 | 5
(1 row)
rollback;
...@@ -590,3 +590,168 @@ drop view aggordview1; ...@@ -590,3 +590,168 @@ drop view aggordview1;
-- variadic aggregates -- variadic aggregates
select least_agg(q1,q2) from int8_tbl; select least_agg(q1,q2) from int8_tbl;
select least_agg(variadic array[q1,q2]) from int8_tbl; select least_agg(variadic array[q1,q2]) from int8_tbl;
-- test aggregates with common transition functions share the same states
begin work;
create type avg_state as (total bigint, count bigint);
create or replace function avg_transfn(state avg_state, n int) returns avg_state as
$$
declare new_state avg_state;
begin
raise notice 'avg_transfn called with %', n;
if state is null then
if n is not null then
new_state.total := n;
new_state.count := 1;
return new_state;
end if;
return null;
elsif n is not null then
state.total := state.total + n;
state.count := state.count + 1;
return state;
end if;
return null;
end
$$ language plpgsql;
create function avg_finalfn(state avg_state) returns int4 as
$$
begin
if state is null then
return NULL;
else
return state.total / state.count;
end if;
end
$$ language plpgsql;
create function sum_finalfn(state avg_state) returns int4 as
$$
begin
if state is null then
return NULL;
else
return state.total;
end if;
end
$$ language plpgsql;
create aggregate my_avg(int4)
(
stype = avg_state,
sfunc = avg_transfn,
finalfunc = avg_finalfn
);
create aggregate my_sum(int4)
(
stype = avg_state,
sfunc = avg_transfn,
finalfunc = sum_finalfn
);
-- aggregate state should be shared as aggs are the same.
select my_avg(one),my_avg(one) from (values(1),(3)) t(one);
-- aggregate state should be shared as transfn is the same for both aggs.
select my_avg(one),my_sum(one) from (values(1),(3)) t(one);
-- shouldn't share states due to the distinctness not matching.
select my_avg(distinct one),my_sum(one) from (values(1),(3)) t(one);
-- shouldn't share states due to the filter clause not matching.
select my_avg(one) filter (where one > 1),my_sum(one) from (values(1),(3)) t(one);
-- this should not share the state due to different input columns.
select my_avg(one),my_sum(two) from (values(1,2),(3,4)) t(one,two);
-- test that aggs with the same sfunc and initcond share the same agg state
create aggregate my_sum_init(int4)
(
stype = avg_state,
sfunc = avg_transfn,
finalfunc = sum_finalfn,
initcond = '(10,0)'
);
create aggregate my_avg_init(int4)
(
stype = avg_state,
sfunc = avg_transfn,
finalfunc = avg_finalfn,
initcond = '(10,0)'
);
create aggregate my_avg_init2(int4)
(
stype = avg_state,
sfunc = avg_transfn,
finalfunc = avg_finalfn,
initcond = '(4,0)'
);
-- state should be shared if INITCONDs are matching
select my_sum_init(one),my_avg_init(one) from (values(1),(3)) t(one);
-- Varying INITCONDs should cause the states not to be shared.
select my_sum_init(one),my_avg_init2(one) from (values(1),(3)) t(one);
rollback;
-- test aggregate state sharing to ensure it works if one aggregate has a
-- finalfn and the other one has none.
begin work;
create or replace function sum_transfn(state int4, n int4) returns int4 as
$$
declare new_state int4;
begin
raise notice 'sum_transfn called with %', n;
if state is null then
if n is not null then
new_state := n;
return new_state;
end if;
return null;
elsif n is not null then
state := state + n;
return state;
end if;
return null;
end
$$ language plpgsql;
create function halfsum_finalfn(state int4) returns int4 as
$$
begin
if state is null then
return NULL;
else
return state / 2;
end if;
end
$$ language plpgsql;
create aggregate my_sum(int4)
(
stype = int4,
sfunc = sum_transfn
);
create aggregate my_half_sum(int4)
(
stype = int4,
sfunc = sum_transfn,
finalfunc = halfsum_finalfn
);
-- Agg state should be shared even though my_sum has no finalfn
select my_sum(one),my_half_sum(one) from (values(1),(2),(3),(4)) t(one);
rollback;
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment