Commit 804163bc authored by Heikki Linnakangas's avatar Heikki Linnakangas

Share transition state between different aggregates when possible.

If there are two different aggregates in the query with same inputs, and
the aggregates have the same initial condition and transition function,
only calculate the state value once, and only call the final functions
separately. For example, AVG(x) and SUM(x) aggregates have the same
transition function, which accumulates the sum and number of input tuples.
For a query like "SELECT AVG(x), SUM(x) FROM x", we can therefore
accumulate the state function only once, which gives a nice speedup.

David Rowley, reviewed and edited by me.
parent dee0200f
......@@ -4487,35 +4487,15 @@ ExecInitExpr(Expr *node, PlanState *parent)
break;
case T_Aggref:
{
Aggref *aggref = (Aggref *) node;
AggrefExprState *astate = makeNode(AggrefExprState);
astate->xprstate.evalfunc = (ExprStateEvalFunc) ExecEvalAggref;
if (parent && IsA(parent, AggState))
{
AggState *aggstate = (AggState *) parent;
int naggs;
aggstate->aggs = lcons(astate, aggstate->aggs);
naggs = ++aggstate->numaggs;
astate->aggdirectargs = (List *) ExecInitExpr((Expr *) aggref->aggdirectargs,
parent);
astate->args = (List *) ExecInitExpr((Expr *) aggref->args,
parent);
astate->aggfilter = ExecInitExpr(aggref->aggfilter,
parent);
/*
* Complain if the aggregate's arguments contain any
* aggregates; nested agg functions are semantically
* nonsensical. (This should have been caught earlier,
* but we defend against it here anyway.)
*/
if (naggs != aggstate->numaggs)
ereport(ERROR,
(errcode(ERRCODE_GROUPING_ERROR),
errmsg("aggregate function calls cannot be nested")));
aggstate->numaggs++;
}
else
{
......
This diff is collapsed.
......@@ -2218,20 +2218,16 @@ initialize_peragg(WindowAggState *winstate, WindowFunc *wfunc,
numArguments);
/* build expression trees using actual argument & result types */
build_aggregate_fnexprs(inputTypes,
numArguments,
0, /* no ordered-set window functions yet */
peraggstate->numFinalArgs,
false, /* no variadic window functions yet */
aggtranstype,
wfunc->wintype,
wfunc->inputcollid,
transfn_oid,
invtransfn_oid,
finalfn_oid,
&transfnexpr,
&invtransfnexpr,
&finalfnexpr);
build_aggregate_transfn_expr(inputTypes,
numArguments,
0, /* no ordered-set window functions yet */
false, /* no variadic window functions yet */
wfunc->wintype,
wfunc->inputcollid,
transfn_oid,
invtransfn_oid,
&transfnexpr,
&invtransfnexpr);
/* set up infrastructure for calling the transfn(s) and finalfn */
fmgr_info(transfn_oid, &peraggstate->transfn);
......@@ -2245,6 +2241,13 @@ initialize_peragg(WindowAggState *winstate, WindowFunc *wfunc,
if (OidIsValid(finalfn_oid))
{
build_aggregate_finalfn_expr(inputTypes,
peraggstate->numFinalArgs,
aggtranstype,
wfunc->wintype,
wfunc->inputcollid,
finalfn_oid,
&finalfnexpr);
fmgr_info(finalfn_oid, &peraggstate->finalfn);
fmgr_info_set_expr((Node *) finalfnexpr, &peraggstate->finalfn);
}
......
......@@ -1829,44 +1829,40 @@ resolve_aggregate_transtype(Oid aggfuncid,
}
/*
* Create expression trees for the transition and final functions
* of an aggregate. These are needed so that polymorphic functions
* can be used within an aggregate --- without the expression trees,
* such functions would not know the datatypes they are supposed to use.
* (The trees will never actually be executed, however, so we can skimp
* a bit on correctness.)
* Create an expression tree for the transition function of an aggregate.
* This is needed so that polymorphic functions can be used within an
* aggregate --- without the expression tree, such functions would not know
* the datatypes they are supposed to use. (The trees will never actually
* be executed, however, so we can skimp a bit on correctness.)
*
* agg_input_types, agg_state_type, agg_result_type identify the input,
* transition, and result types of the aggregate. These should all be
* resolved to actual types (ie, none should ever be ANYELEMENT etc).
* agg_input_types and agg_state_type identifies the input types of the
* aggregate. These should be resolved to actual types (ie, none should
* ever be ANYELEMENT etc).
* agg_input_collation is the aggregate function's input collation.
*
* For an ordered-set aggregate, remember that agg_input_types describes
* the direct arguments followed by the aggregated arguments.
*
* transfn_oid, invtransfn_oid and finalfn_oid identify the funcs to be
* called; the latter two may be InvalidOid.
* transfn_oid and invtransfn_oid identify the funcs to be called; the
* latter may be InvalidOid, however if invtransfn_oid is set then
* transfn_oid must also be set.
*
* Pointers to the constructed trees are returned into *transfnexpr,
* *invtransfnexpr and *finalfnexpr. If there is no invtransfn or finalfn,
* the respective pointers are set to NULL. Since use of the invtransfn is
* optional, NULL may be passed for invtransfnexpr.
* *invtransfnexpr. If there is no invtransfn, the respective pointer is set
* to NULL. Since use of the invtransfn is optional, NULL may be passed for
* invtransfnexpr.
*/
void
build_aggregate_fnexprs(Oid *agg_input_types,
int agg_num_inputs,
int agg_num_direct_inputs,
int num_finalfn_inputs,
bool agg_variadic,
Oid agg_state_type,
Oid agg_result_type,
Oid agg_input_collation,
Oid transfn_oid,
Oid invtransfn_oid,
Oid finalfn_oid,
Expr **transfnexpr,
Expr **invtransfnexpr,
Expr **finalfnexpr)
build_aggregate_transfn_expr(Oid *agg_input_types,
int agg_num_inputs,
int agg_num_direct_inputs,
bool agg_variadic,
Oid agg_state_type,
Oid agg_input_collation,
Oid transfn_oid,
Oid invtransfn_oid,
Expr **transfnexpr,
Expr **invtransfnexpr)
{
Param *argp;
List *args;
......@@ -1929,13 +1925,24 @@ build_aggregate_fnexprs(Oid *agg_input_types,
else
*invtransfnexpr = NULL;
}
}
/* see if we have a final function */
if (!OidIsValid(finalfn_oid))
{
*finalfnexpr = NULL;
return;
}
/*
* Like build_aggregate_transfn_expr, but creates an expression tree for the
* final function of an aggregate, rather than the transition function.
*/
void
build_aggregate_finalfn_expr(Oid *agg_input_types,
int num_finalfn_inputs,
Oid agg_state_type,
Oid agg_result_type,
Oid agg_input_collation,
Oid finalfn_oid,
Expr **finalfnexpr)
{
Param *argp;
List *args;
int i;
/*
* Build expr tree for final function
......
......@@ -609,9 +609,6 @@ typedef struct WholeRowVarExprState
typedef struct AggrefExprState
{
ExprState xprstate;
List *aggdirectargs; /* states of direct-argument expressions */
List *args; /* states of aggregated-argument expressions */
ExprState *aggfilter; /* state of FILTER expression, if any */
int aggno; /* ID number for agg within its plan node */
} AggrefExprState;
......@@ -1825,6 +1822,7 @@ typedef struct GroupState
*/
/* these structs are private in nodeAgg.c: */
typedef struct AggStatePerAggData *AggStatePerAgg;
typedef struct AggStatePerTransData *AggStatePerTrans;
typedef struct AggStatePerGroupData *AggStatePerGroup;
typedef struct AggStatePerPhaseData *AggStatePerPhase;
......@@ -1833,14 +1831,16 @@ typedef struct AggState
ScanState ss; /* its first field is NodeTag */
List *aggs; /* all Aggref nodes in targetlist & quals */
int numaggs; /* length of list (could be zero!) */
int numtrans; /* number of pertrans items */
AggStatePerPhase phase; /* pointer to current phase data */
int numphases; /* number of phases */
int current_phase; /* current phase number */
FmgrInfo *hashfunctions; /* per-grouping-field hash fns */
AggStatePerAgg peragg; /* per-Aggref information */
AggStatePerTrans pertrans; /* per-Trans state information */
ExprContext **aggcontexts; /* econtexts for long-lived data (per GS) */
ExprContext *tmpcontext; /* econtext for input expressions */
AggStatePerAgg curperagg; /* identifies currently active aggregate */
AggStatePerTrans curpertrans; /* currently active trans state */
bool input_done; /* indicates end of input */
bool agg_done; /* indicates completion of Agg scan */
int projected_set; /* The last projected grouping set */
......
......@@ -35,19 +35,23 @@ extern Oid resolve_aggregate_transtype(Oid aggfuncid,
Oid *inputTypes,
int numArguments);
extern void build_aggregate_fnexprs(Oid *agg_input_types,
extern void build_aggregate_transfn_expr(Oid *agg_input_types,
int agg_num_inputs,
int agg_num_direct_inputs,
int num_finalfn_inputs,
bool agg_variadic,
Oid agg_state_type,
Oid agg_result_type,
Oid agg_input_collation,
Oid transfn_oid,
Oid invtransfn_oid,
Oid finalfn_oid,
Expr **transfnexpr,
Expr **invtransfnexpr,
Expr **invtransfnexpr);
extern void build_aggregate_finalfn_expr(Oid *agg_input_types,
int num_finalfn_inputs,
Oid agg_state_type,
Oid agg_result_type,
Oid agg_input_collation,
Oid finalfn_oid,
Expr **finalfnexpr);
#endif /* PARSE_AGG_H */
......@@ -1580,3 +1580,207 @@ select least_agg(variadic array[q1,q2]) from int8_tbl;
-4567890123456789
(1 row)
-- test aggregates with common transition functions share the same states
begin work;
create type avg_state as (total bigint, count bigint);
create or replace function avg_transfn(state avg_state, n int) returns avg_state as
$$
declare new_state avg_state;
begin
raise notice 'avg_transfn called with %', n;
if state is null then
if n is not null then
new_state.total := n;
new_state.count := 1;
return new_state;
end if;
return null;
elsif n is not null then
state.total := state.total + n;
state.count := state.count + 1;
return state;
end if;
return null;
end
$$ language plpgsql;
create function avg_finalfn(state avg_state) returns int4 as
$$
begin
if state is null then
return NULL;
else
return state.total / state.count;
end if;
end
$$ language plpgsql;
create function sum_finalfn(state avg_state) returns int4 as
$$
begin
if state is null then
return NULL;
else
return state.total;
end if;
end
$$ language plpgsql;
create aggregate my_avg(int4)
(
stype = avg_state,
sfunc = avg_transfn,
finalfunc = avg_finalfn
);
create aggregate my_sum(int4)
(
stype = avg_state,
sfunc = avg_transfn,
finalfunc = sum_finalfn
);
-- aggregate state should be shared as aggs are the same.
select my_avg(one),my_avg(one) from (values(1),(3)) t(one);
NOTICE: avg_transfn called with 1
NOTICE: avg_transfn called with 3
my_avg | my_avg
--------+--------
2 | 2
(1 row)
-- aggregate state should be shared as transfn is the same for both aggs.
select my_avg(one),my_sum(one) from (values(1),(3)) t(one);
NOTICE: avg_transfn called with 1
NOTICE: avg_transfn called with 3
my_avg | my_sum
--------+--------
2 | 4
(1 row)
-- shouldn't share states due to the distinctness not matching.
select my_avg(distinct one),my_sum(one) from (values(1),(3)) t(one);
NOTICE: avg_transfn called with 1
NOTICE: avg_transfn called with 3
NOTICE: avg_transfn called with 1
NOTICE: avg_transfn called with 3
my_avg | my_sum
--------+--------
2 | 4
(1 row)
-- shouldn't share states due to the filter clause not matching.
select my_avg(one) filter (where one > 1),my_sum(one) from (values(1),(3)) t(one);
NOTICE: avg_transfn called with 1
NOTICE: avg_transfn called with 3
NOTICE: avg_transfn called with 3
my_avg | my_sum
--------+--------
3 | 4
(1 row)
-- this should not share the state due to different input columns.
select my_avg(one),my_sum(two) from (values(1,2),(3,4)) t(one,two);
NOTICE: avg_transfn called with 2
NOTICE: avg_transfn called with 1
NOTICE: avg_transfn called with 4
NOTICE: avg_transfn called with 3
my_avg | my_sum
--------+--------
2 | 6
(1 row)
-- test that aggs with the same sfunc and initcond share the same agg state
create aggregate my_sum_init(int4)
(
stype = avg_state,
sfunc = avg_transfn,
finalfunc = sum_finalfn,
initcond = '(10,0)'
);
create aggregate my_avg_init(int4)
(
stype = avg_state,
sfunc = avg_transfn,
finalfunc = avg_finalfn,
initcond = '(10,0)'
);
create aggregate my_avg_init2(int4)
(
stype = avg_state,
sfunc = avg_transfn,
finalfunc = avg_finalfn,
initcond = '(4,0)'
);
-- state should be shared if INITCONDs are matching
select my_sum_init(one),my_avg_init(one) from (values(1),(3)) t(one);
NOTICE: avg_transfn called with 1
NOTICE: avg_transfn called with 3
my_sum_init | my_avg_init
-------------+-------------
14 | 7
(1 row)
-- Varying INITCONDs should cause the states not to be shared.
select my_sum_init(one),my_avg_init2(one) from (values(1),(3)) t(one);
NOTICE: avg_transfn called with 1
NOTICE: avg_transfn called with 1
NOTICE: avg_transfn called with 3
NOTICE: avg_transfn called with 3
my_sum_init | my_avg_init2
-------------+--------------
14 | 4
(1 row)
rollback;
-- test aggregate state sharing to ensure it works if one aggregate has a
-- finalfn and the other one has none.
begin work;
create or replace function sum_transfn(state int4, n int4) returns int4 as
$$
declare new_state int4;
begin
raise notice 'sum_transfn called with %', n;
if state is null then
if n is not null then
new_state := n;
return new_state;
end if;
return null;
elsif n is not null then
state := state + n;
return state;
end if;
return null;
end
$$ language plpgsql;
create function halfsum_finalfn(state int4) returns int4 as
$$
begin
if state is null then
return NULL;
else
return state / 2;
end if;
end
$$ language plpgsql;
create aggregate my_sum(int4)
(
stype = int4,
sfunc = sum_transfn
);
create aggregate my_half_sum(int4)
(
stype = int4,
sfunc = sum_transfn,
finalfunc = halfsum_finalfn
);
-- Agg state should be shared even though my_sum has no finalfn
select my_sum(one),my_half_sum(one) from (values(1),(2),(3),(4)) t(one);
NOTICE: sum_transfn called with 1
NOTICE: sum_transfn called with 2
NOTICE: sum_transfn called with 3
NOTICE: sum_transfn called with 4
my_sum | my_half_sum
--------+-------------
10 | 5
(1 row)
rollback;
......@@ -590,3 +590,168 @@ drop view aggordview1;
-- variadic aggregates
select least_agg(q1,q2) from int8_tbl;
select least_agg(variadic array[q1,q2]) from int8_tbl;
-- test aggregates with common transition functions share the same states
begin work;
create type avg_state as (total bigint, count bigint);
create or replace function avg_transfn(state avg_state, n int) returns avg_state as
$$
declare new_state avg_state;
begin
raise notice 'avg_transfn called with %', n;
if state is null then
if n is not null then
new_state.total := n;
new_state.count := 1;
return new_state;
end if;
return null;
elsif n is not null then
state.total := state.total + n;
state.count := state.count + 1;
return state;
end if;
return null;
end
$$ language plpgsql;
create function avg_finalfn(state avg_state) returns int4 as
$$
begin
if state is null then
return NULL;
else
return state.total / state.count;
end if;
end
$$ language plpgsql;
create function sum_finalfn(state avg_state) returns int4 as
$$
begin
if state is null then
return NULL;
else
return state.total;
end if;
end
$$ language plpgsql;
create aggregate my_avg(int4)
(
stype = avg_state,
sfunc = avg_transfn,
finalfunc = avg_finalfn
);
create aggregate my_sum(int4)
(
stype = avg_state,
sfunc = avg_transfn,
finalfunc = sum_finalfn
);
-- aggregate state should be shared as aggs are the same.
select my_avg(one),my_avg(one) from (values(1),(3)) t(one);
-- aggregate state should be shared as transfn is the same for both aggs.
select my_avg(one),my_sum(one) from (values(1),(3)) t(one);
-- shouldn't share states due to the distinctness not matching.
select my_avg(distinct one),my_sum(one) from (values(1),(3)) t(one);
-- shouldn't share states due to the filter clause not matching.
select my_avg(one) filter (where one > 1),my_sum(one) from (values(1),(3)) t(one);
-- this should not share the state due to different input columns.
select my_avg(one),my_sum(two) from (values(1,2),(3,4)) t(one,two);
-- test that aggs with the same sfunc and initcond share the same agg state
create aggregate my_sum_init(int4)
(
stype = avg_state,
sfunc = avg_transfn,
finalfunc = sum_finalfn,
initcond = '(10,0)'
);
create aggregate my_avg_init(int4)
(
stype = avg_state,
sfunc = avg_transfn,
finalfunc = avg_finalfn,
initcond = '(10,0)'
);
create aggregate my_avg_init2(int4)
(
stype = avg_state,
sfunc = avg_transfn,
finalfunc = avg_finalfn,
initcond = '(4,0)'
);
-- state should be shared if INITCONDs are matching
select my_sum_init(one),my_avg_init(one) from (values(1),(3)) t(one);
-- Varying INITCONDs should cause the states not to be shared.
select my_sum_init(one),my_avg_init2(one) from (values(1),(3)) t(one);
rollback;
-- test aggregate state sharing to ensure it works if one aggregate has a
-- finalfn and the other one has none.
begin work;
create or replace function sum_transfn(state int4, n int4) returns int4 as
$$
declare new_state int4;
begin
raise notice 'sum_transfn called with %', n;
if state is null then
if n is not null then
new_state := n;
return new_state;
end if;
return null;
elsif n is not null then
state := state + n;
return state;
end if;
return null;
end
$$ language plpgsql;
create function halfsum_finalfn(state int4) returns int4 as
$$
begin
if state is null then
return NULL;
else
return state / 2;
end if;
end
$$ language plpgsql;
create aggregate my_sum(int4)
(
stype = int4,
sfunc = sum_transfn
);
create aggregate my_half_sum(int4)
(
stype = int4,
sfunc = sum_transfn,
finalfunc = halfsum_finalfn
);
-- Agg state should be shared even though my_sum has no finalfn
select my_sum(one),my_half_sum(one) from (values(1),(2),(3),(4)) t(one);
rollback;
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment