Commit b560ec1b authored by Noah Misch's avatar Noah Misch

Implement the FILTER clause for aggregate function calls.

This is SQL-standard with a few extensions, namely support for
subqueries and outer references in clause expressions.

catversion bump due to change in Aggref and WindowFunc.

David Fetter, reviewed by Dean Rasheed.
parent 7a8e9f29
...@@ -1546,6 +1546,7 @@ JumbleExpr(pgssJumbleState *jstate, Node *node) ...@@ -1546,6 +1546,7 @@ JumbleExpr(pgssJumbleState *jstate, Node *node)
JumbleExpr(jstate, (Node *) expr->args); JumbleExpr(jstate, (Node *) expr->args);
JumbleExpr(jstate, (Node *) expr->aggorder); JumbleExpr(jstate, (Node *) expr->aggorder);
JumbleExpr(jstate, (Node *) expr->aggdistinct); JumbleExpr(jstate, (Node *) expr->aggdistinct);
JumbleExpr(jstate, (Node *) expr->aggfilter);
} }
break; break;
case T_WindowFunc: case T_WindowFunc:
...@@ -1555,6 +1556,7 @@ JumbleExpr(pgssJumbleState *jstate, Node *node) ...@@ -1555,6 +1556,7 @@ JumbleExpr(pgssJumbleState *jstate, Node *node)
APP_JUMB(expr->winfnoid); APP_JUMB(expr->winfnoid);
APP_JUMB(expr->winref); APP_JUMB(expr->winref);
JumbleExpr(jstate, (Node *) expr->args); JumbleExpr(jstate, (Node *) expr->args);
JumbleExpr(jstate, (Node *) expr->aggfilter);
} }
break; break;
case T_ArrayRef: case T_ArrayRef:
......
...@@ -1786,7 +1786,7 @@ ...@@ -1786,7 +1786,7 @@
</row> </row>
<row> <row>
<entry><token>FILTER</token></entry> <entry><token>FILTER</token></entry>
<entry></entry> <entry>non-reserved</entry>
<entry>reserved</entry> <entry>reserved</entry>
<entry>reserved</entry> <entry>reserved</entry>
<entry></entry> <entry></entry>
......
...@@ -598,6 +598,11 @@ GROUP BY <replaceable class="parameter">expression</replaceable> [, ...] ...@@ -598,6 +598,11 @@ GROUP BY <replaceable class="parameter">expression</replaceable> [, ...]
making up each group, producing a separate value for each group making up each group, producing a separate value for each group
(whereas without <literal>GROUP BY</literal>, an aggregate (whereas without <literal>GROUP BY</literal>, an aggregate
produces a single value computed across all the selected rows). produces a single value computed across all the selected rows).
The set of rows fed to the aggregate function can be further filtered by
attaching a <literal>FILTER</literal> clause to the aggregate function
call; see <xref linkend="syntax-aggregates"> for more information. When
a <literal>FILTER</literal> clause is present, only those rows matching it
are included.
When <literal>GROUP BY</literal> is present, it is not valid for When <literal>GROUP BY</literal> is present, it is not valid for
the <command>SELECT</command> list expressions to refer to the <command>SELECT</command> list expressions to refer to
ungrouped columns except within aggregate functions or if the ungrouped columns except within aggregate functions or if the
......
...@@ -1554,6 +1554,10 @@ sqrt(2) ...@@ -1554,6 +1554,10 @@ sqrt(2)
<secondary>invocation</secondary> <secondary>invocation</secondary>
</indexterm> </indexterm>
<indexterm zone="syntax-aggregates">
<primary>filter</primary>
</indexterm>
<para> <para>
An <firstterm>aggregate expression</firstterm> represents the An <firstterm>aggregate expression</firstterm> represents the
application of an aggregate function across the rows selected by a application of an aggregate function across the rows selected by a
...@@ -1562,19 +1566,19 @@ sqrt(2) ...@@ -1562,19 +1566,19 @@ sqrt(2)
syntax of an aggregate expression is one of the following: syntax of an aggregate expression is one of the following:
<synopsis> <synopsis>
<replaceable>aggregate_name</replaceable> (<replaceable>expression</replaceable> [ , ... ] [ <replaceable>order_by_clause</replaceable> ] ) <replaceable>aggregate_name</replaceable> (<replaceable>expression</replaceable> [ , ... ] [ <replaceable>order_by_clause</replaceable> ] ) [ FILTER ( WHERE <replaceable>filter_clause</replaceable> ) ]
<replaceable>aggregate_name</replaceable> (ALL <replaceable>expression</replaceable> [ , ... ] [ <replaceable>order_by_clause</replaceable> ] ) <replaceable>aggregate_name</replaceable> (ALL <replaceable>expression</replaceable> [ , ... ] [ <replaceable>order_by_clause</replaceable> ] ) [ FILTER ( WHERE <replaceable>filter_clause</replaceable> ) ]
<replaceable>aggregate_name</replaceable> (DISTINCT <replaceable>expression</replaceable> [ , ... ] [ <replaceable>order_by_clause</replaceable> ] ) <replaceable>aggregate_name</replaceable> (DISTINCT <replaceable>expression</replaceable> [ , ... ] [ <replaceable>order_by_clause</replaceable> ] ) [ FILTER ( WHERE <replaceable>filter_clause</replaceable> ) ]
<replaceable>aggregate_name</replaceable> ( * ) <replaceable>aggregate_name</replaceable> ( * ) [ FILTER ( WHERE <replaceable>filter_clause</replaceable> ) ]
</synopsis> </synopsis>
where <replaceable>aggregate_name</replaceable> is a previously where <replaceable>aggregate_name</replaceable> is a previously
defined aggregate (possibly qualified with a schema name), defined aggregate (possibly qualified with a schema name) and
<replaceable>expression</replaceable> is <replaceable>expression</replaceable> is
any value expression that does not itself contain an aggregate any value expression that does not itself contain an aggregate
expression or a window function call, and expression or a window function call. The optional
<replaceable>order_by_clause</replaceable> is a optional <replaceable>order_by_clause</replaceable> and
<literal>ORDER BY</> clause as described below. <replaceable>filter_clause</replaceable> are described below.
</para> </para>
<para> <para>
...@@ -1606,6 +1610,23 @@ sqrt(2) ...@@ -1606,6 +1610,23 @@ sqrt(2)
distinct non-null values of <literal>f1</literal>. distinct non-null values of <literal>f1</literal>.
</para> </para>
<para>
If <literal>FILTER</literal> is specified, then only the input
rows for which the <replaceable>filter_clause</replaceable>
evaluates to true are fed to the aggregate function; other rows
are discarded. For example:
<programlisting>
SELECT
count(*) AS unfiltered,
count(*) FILTER (WHERE i < 5) AS filtered
FROM generate_series(1,10) AS s(i);
unfiltered | filtered
------------+----------
10 | 4
(1 row)
</programlisting>
</para>
<para> <para>
Ordinarily, the input rows are fed to the aggregate function in an Ordinarily, the input rows are fed to the aggregate function in an
unspecified order. In many cases this does not matter; for example, unspecified order. In many cases this does not matter; for example,
...@@ -1709,10 +1730,10 @@ SELECT string_agg(a ORDER BY a, ',') FROM table; -- incorrect ...@@ -1709,10 +1730,10 @@ SELECT string_agg(a ORDER BY a, ',') FROM table; -- incorrect
The syntax of a window function call is one of the following: The syntax of a window function call is one of the following:
<synopsis> <synopsis>
<replaceable>function_name</replaceable> (<optional><replaceable>expression</replaceable> <optional>, <replaceable>expression</replaceable> ... </optional></optional>) OVER ( <replaceable class="parameter">window_definition</replaceable> ) <replaceable>function_name</replaceable> (<optional><replaceable>expression</replaceable> <optional>, <replaceable>expression</replaceable> ... </optional></optional>) [ FILTER ( WHERE <replaceable>filter_clause</replaceable> ) ] OVER ( <replaceable class="parameter">window_definition</replaceable> )
<replaceable>function_name</replaceable> (<optional><replaceable>expression</replaceable> <optional>, <replaceable>expression</replaceable> ... </optional></optional>) OVER <replaceable>window_name</replaceable> <replaceable>function_name</replaceable> (<optional><replaceable>expression</replaceable> <optional>, <replaceable>expression</replaceable> ... </optional></optional>) [ FILTER ( WHERE <replaceable>filter_clause</replaceable> ) ] OVER <replaceable>window_name</replaceable>
<replaceable>function_name</replaceable> ( * ) OVER ( <replaceable class="parameter">window_definition</replaceable> ) <replaceable>function_name</replaceable> ( * ) [ FILTER ( WHERE <replaceable>filter_clause</replaceable> ) ] OVER ( <replaceable class="parameter">window_definition</replaceable> )
<replaceable>function_name</replaceable> ( * ) OVER <replaceable>window_name</replaceable> <replaceable>function_name</replaceable> ( * ) [ FILTER ( WHERE <replaceable>filter_clause</replaceable> ) ] OVER <replaceable>window_name</replaceable>
</synopsis> </synopsis>
where <replaceable class="parameter">window_definition</replaceable> where <replaceable class="parameter">window_definition</replaceable>
has the syntax has the syntax
...@@ -1836,7 +1857,8 @@ UNBOUNDED FOLLOWING ...@@ -1836,7 +1857,8 @@ UNBOUNDED FOLLOWING
The built-in window functions are described in <xref The built-in window functions are described in <xref
linkend="functions-window-table">. Other window functions can be added by linkend="functions-window-table">. Other window functions can be added by
the user. Also, any built-in or user-defined aggregate function can be the user. Also, any built-in or user-defined aggregate function can be
used as a window function. used as a window function. Only aggregate window functions accept
a <literal>FILTER</literal> clause.
</para> </para>
<para> <para>
......
...@@ -4410,6 +4410,8 @@ ExecInitExpr(Expr *node, PlanState *parent) ...@@ -4410,6 +4410,8 @@ ExecInitExpr(Expr *node, PlanState *parent)
astate->args = (List *) ExecInitExpr((Expr *) aggref->args, astate->args = (List *) ExecInitExpr((Expr *) aggref->args,
parent); parent);
astate->aggfilter = ExecInitExpr(aggref->aggfilter,
parent);
/* /*
* Complain if the aggregate's arguments contain any * Complain if the aggregate's arguments contain any
...@@ -4448,6 +4450,8 @@ ExecInitExpr(Expr *node, PlanState *parent) ...@@ -4448,6 +4450,8 @@ ExecInitExpr(Expr *node, PlanState *parent)
wfstate->args = (List *) ExecInitExpr((Expr *) wfunc->args, wfstate->args = (List *) ExecInitExpr((Expr *) wfunc->args,
parent); parent);
wfstate->aggfilter = ExecInitExpr(wfunc->aggfilter,
parent);
/* /*
* Complain if the windowfunc's arguments contain any * Complain if the windowfunc's arguments contain any
......
...@@ -649,9 +649,9 @@ get_last_attnums(Node *node, ProjectionInfo *projInfo) ...@@ -649,9 +649,9 @@ get_last_attnums(Node *node, ProjectionInfo *projInfo)
} }
/* /*
* Don't examine the arguments of Aggrefs or WindowFuncs, because those do * Don't examine the arguments or filters of Aggrefs or WindowFuncs,
* not represent expressions to be evaluated within the overall * because those do not represent expressions to be evaluated within the
* targetlist's econtext. * overall targetlist's econtext.
*/ */
if (IsA(node, Aggref)) if (IsA(node, Aggref))
return false; return false;
......
...@@ -380,7 +380,7 @@ sql_fn_post_column_ref(ParseState *pstate, ColumnRef *cref, Node *var) ...@@ -380,7 +380,7 @@ sql_fn_post_column_ref(ParseState *pstate, ColumnRef *cref, Node *var)
param = ParseFuncOrColumn(pstate, param = ParseFuncOrColumn(pstate,
list_make1(subfield), list_make1(subfield),
list_make1(param), list_make1(param),
NIL, false, false, false, NIL, NULL, false, false, false,
NULL, true, cref->location); NULL, true, cref->location);
} }
......
...@@ -484,10 +484,23 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup) ...@@ -484,10 +484,23 @@ advance_aggregates(AggState *aggstate, AggStatePerGroup pergroup)
{ {
AggStatePerAgg peraggstate = &aggstate->peragg[aggno]; AggStatePerAgg peraggstate = &aggstate->peragg[aggno];
AggStatePerGroup pergroupstate = &pergroup[aggno]; AggStatePerGroup pergroupstate = &pergroup[aggno];
ExprState *filter = peraggstate->aggrefstate->aggfilter;
int nargs = peraggstate->numArguments; int nargs = peraggstate->numArguments;
int i; int i;
TupleTableSlot *slot; TupleTableSlot *slot;
/* Skip anything FILTERed out */
if (filter)
{
bool isnull;
Datum res;
res = ExecEvalExprSwitchContext(filter, aggstate->tmpcontext,
&isnull, NULL);
if (isnull || !DatumGetBool(res))
continue;
}
/* Evaluate the current input expressions for this aggregate */ /* Evaluate the current input expressions for this aggregate */
slot = ExecProject(peraggstate->evalproj, NULL); slot = ExecProject(peraggstate->evalproj, NULL);
......
...@@ -227,9 +227,23 @@ advance_windowaggregate(WindowAggState *winstate, ...@@ -227,9 +227,23 @@ advance_windowaggregate(WindowAggState *winstate,
int i; int i;
MemoryContext oldContext; MemoryContext oldContext;
ExprContext *econtext = winstate->tmpcontext; ExprContext *econtext = winstate->tmpcontext;
ExprState *filter = wfuncstate->aggfilter;
oldContext = MemoryContextSwitchTo(econtext->ecxt_per_tuple_memory); oldContext = MemoryContextSwitchTo(econtext->ecxt_per_tuple_memory);
/* Skip anything FILTERed out */
if (filter)
{
bool isnull;
Datum res = ExecEvalExpr(filter, econtext, &isnull, NULL);
if (isnull || !DatumGetBool(res))
{
MemoryContextSwitchTo(oldContext);
return;
}
}
/* We start from 1, since the 0th arg will be the transition value */ /* We start from 1, since the 0th arg will be the transition value */
i = 1; i = 1;
foreach(arg, wfuncstate->args) foreach(arg, wfuncstate->args)
......
...@@ -1137,6 +1137,7 @@ _copyAggref(const Aggref *from) ...@@ -1137,6 +1137,7 @@ _copyAggref(const Aggref *from)
COPY_NODE_FIELD(args); COPY_NODE_FIELD(args);
COPY_NODE_FIELD(aggorder); COPY_NODE_FIELD(aggorder);
COPY_NODE_FIELD(aggdistinct); COPY_NODE_FIELD(aggdistinct);
COPY_NODE_FIELD(aggfilter);
COPY_SCALAR_FIELD(aggstar); COPY_SCALAR_FIELD(aggstar);
COPY_SCALAR_FIELD(agglevelsup); COPY_SCALAR_FIELD(agglevelsup);
COPY_LOCATION_FIELD(location); COPY_LOCATION_FIELD(location);
...@@ -1157,6 +1158,7 @@ _copyWindowFunc(const WindowFunc *from) ...@@ -1157,6 +1158,7 @@ _copyWindowFunc(const WindowFunc *from)
COPY_SCALAR_FIELD(wincollid); COPY_SCALAR_FIELD(wincollid);
COPY_SCALAR_FIELD(inputcollid); COPY_SCALAR_FIELD(inputcollid);
COPY_NODE_FIELD(args); COPY_NODE_FIELD(args);
COPY_NODE_FIELD(aggfilter);
COPY_SCALAR_FIELD(winref); COPY_SCALAR_FIELD(winref);
COPY_SCALAR_FIELD(winstar); COPY_SCALAR_FIELD(winstar);
COPY_SCALAR_FIELD(winagg); COPY_SCALAR_FIELD(winagg);
...@@ -2152,6 +2154,7 @@ _copyFuncCall(const FuncCall *from) ...@@ -2152,6 +2154,7 @@ _copyFuncCall(const FuncCall *from)
COPY_NODE_FIELD(funcname); COPY_NODE_FIELD(funcname);
COPY_NODE_FIELD(args); COPY_NODE_FIELD(args);
COPY_NODE_FIELD(agg_order); COPY_NODE_FIELD(agg_order);
COPY_NODE_FIELD(agg_filter);
COPY_SCALAR_FIELD(agg_star); COPY_SCALAR_FIELD(agg_star);
COPY_SCALAR_FIELD(agg_distinct); COPY_SCALAR_FIELD(agg_distinct);
COPY_SCALAR_FIELD(func_variadic); COPY_SCALAR_FIELD(func_variadic);
......
...@@ -196,6 +196,7 @@ _equalAggref(const Aggref *a, const Aggref *b) ...@@ -196,6 +196,7 @@ _equalAggref(const Aggref *a, const Aggref *b)
COMPARE_NODE_FIELD(args); COMPARE_NODE_FIELD(args);
COMPARE_NODE_FIELD(aggorder); COMPARE_NODE_FIELD(aggorder);
COMPARE_NODE_FIELD(aggdistinct); COMPARE_NODE_FIELD(aggdistinct);
COMPARE_NODE_FIELD(aggfilter);
COMPARE_SCALAR_FIELD(aggstar); COMPARE_SCALAR_FIELD(aggstar);
COMPARE_SCALAR_FIELD(agglevelsup); COMPARE_SCALAR_FIELD(agglevelsup);
COMPARE_LOCATION_FIELD(location); COMPARE_LOCATION_FIELD(location);
...@@ -211,6 +212,7 @@ _equalWindowFunc(const WindowFunc *a, const WindowFunc *b) ...@@ -211,6 +212,7 @@ _equalWindowFunc(const WindowFunc *a, const WindowFunc *b)
COMPARE_SCALAR_FIELD(wincollid); COMPARE_SCALAR_FIELD(wincollid);
COMPARE_SCALAR_FIELD(inputcollid); COMPARE_SCALAR_FIELD(inputcollid);
COMPARE_NODE_FIELD(args); COMPARE_NODE_FIELD(args);
COMPARE_NODE_FIELD(aggfilter);
COMPARE_SCALAR_FIELD(winref); COMPARE_SCALAR_FIELD(winref);
COMPARE_SCALAR_FIELD(winstar); COMPARE_SCALAR_FIELD(winstar);
COMPARE_SCALAR_FIELD(winagg); COMPARE_SCALAR_FIELD(winagg);
...@@ -1993,6 +1995,7 @@ _equalFuncCall(const FuncCall *a, const FuncCall *b) ...@@ -1993,6 +1995,7 @@ _equalFuncCall(const FuncCall *a, const FuncCall *b)
COMPARE_NODE_FIELD(funcname); COMPARE_NODE_FIELD(funcname);
COMPARE_NODE_FIELD(args); COMPARE_NODE_FIELD(args);
COMPARE_NODE_FIELD(agg_order); COMPARE_NODE_FIELD(agg_order);
COMPARE_NODE_FIELD(agg_filter);
COMPARE_SCALAR_FIELD(agg_star); COMPARE_SCALAR_FIELD(agg_star);
COMPARE_SCALAR_FIELD(agg_distinct); COMPARE_SCALAR_FIELD(agg_distinct);
COMPARE_SCALAR_FIELD(func_variadic); COMPARE_SCALAR_FIELD(func_variadic);
......
...@@ -526,6 +526,7 @@ makeFuncCall(List *name, List *args, int location) ...@@ -526,6 +526,7 @@ makeFuncCall(List *name, List *args, int location)
n->args = args; n->args = args;
n->location = location; n->location = location;
n->agg_order = NIL; n->agg_order = NIL;
n->agg_filter = NULL;
n->agg_star = FALSE; n->agg_star = FALSE;
n->agg_distinct = FALSE; n->agg_distinct = FALSE;
n->func_variadic = FALSE; n->func_variadic = FALSE;
......
...@@ -1570,6 +1570,8 @@ expression_tree_walker(Node *node, ...@@ -1570,6 +1570,8 @@ expression_tree_walker(Node *node,
if (expression_tree_walker((Node *) expr->aggdistinct, if (expression_tree_walker((Node *) expr->aggdistinct,
walker, context)) walker, context))
return true; return true;
if (walker((Node *) expr->aggfilter, context))
return true;
} }
break; break;
case T_WindowFunc: case T_WindowFunc:
...@@ -1580,6 +1582,8 @@ expression_tree_walker(Node *node, ...@@ -1580,6 +1582,8 @@ expression_tree_walker(Node *node,
if (expression_tree_walker((Node *) expr->args, if (expression_tree_walker((Node *) expr->args,
walker, context)) walker, context))
return true; return true;
if (walker((Node *) expr->aggfilter, context))
return true;
} }
break; break;
case T_ArrayRef: case T_ArrayRef:
...@@ -2079,6 +2083,7 @@ expression_tree_mutator(Node *node, ...@@ -2079,6 +2083,7 @@ expression_tree_mutator(Node *node,
MUTATE(newnode->args, aggref->args, List *); MUTATE(newnode->args, aggref->args, List *);
MUTATE(newnode->aggorder, aggref->aggorder, List *); MUTATE(newnode->aggorder, aggref->aggorder, List *);
MUTATE(newnode->aggdistinct, aggref->aggdistinct, List *); MUTATE(newnode->aggdistinct, aggref->aggdistinct, List *);
MUTATE(newnode->aggfilter, aggref->aggfilter, Expr *);
return (Node *) newnode; return (Node *) newnode;
} }
break; break;
...@@ -2089,6 +2094,7 @@ expression_tree_mutator(Node *node, ...@@ -2089,6 +2094,7 @@ expression_tree_mutator(Node *node,
FLATCOPY(newnode, wfunc, WindowFunc); FLATCOPY(newnode, wfunc, WindowFunc);
MUTATE(newnode->args, wfunc->args, List *); MUTATE(newnode->args, wfunc->args, List *);
MUTATE(newnode->aggfilter, wfunc->aggfilter, Expr *);
return (Node *) newnode; return (Node *) newnode;
} }
break; break;
...@@ -2951,6 +2957,8 @@ raw_expression_tree_walker(Node *node, ...@@ -2951,6 +2957,8 @@ raw_expression_tree_walker(Node *node,
return true; return true;
if (walker(fcall->agg_order, context)) if (walker(fcall->agg_order, context))
return true; return true;
if (walker(fcall->agg_filter, context))
return true;
if (walker(fcall->over, context)) if (walker(fcall->over, context))
return true; return true;
/* function name is deemed uninteresting */ /* function name is deemed uninteresting */
......
...@@ -958,6 +958,7 @@ _outAggref(StringInfo str, const Aggref *node) ...@@ -958,6 +958,7 @@ _outAggref(StringInfo str, const Aggref *node)
WRITE_NODE_FIELD(args); WRITE_NODE_FIELD(args);
WRITE_NODE_FIELD(aggorder); WRITE_NODE_FIELD(aggorder);
WRITE_NODE_FIELD(aggdistinct); WRITE_NODE_FIELD(aggdistinct);
WRITE_NODE_FIELD(aggfilter);
WRITE_BOOL_FIELD(aggstar); WRITE_BOOL_FIELD(aggstar);
WRITE_UINT_FIELD(agglevelsup); WRITE_UINT_FIELD(agglevelsup);
WRITE_LOCATION_FIELD(location); WRITE_LOCATION_FIELD(location);
...@@ -973,6 +974,7 @@ _outWindowFunc(StringInfo str, const WindowFunc *node) ...@@ -973,6 +974,7 @@ _outWindowFunc(StringInfo str, const WindowFunc *node)
WRITE_OID_FIELD(wincollid); WRITE_OID_FIELD(wincollid);
WRITE_OID_FIELD(inputcollid); WRITE_OID_FIELD(inputcollid);
WRITE_NODE_FIELD(args); WRITE_NODE_FIELD(args);
WRITE_NODE_FIELD(aggfilter);
WRITE_UINT_FIELD(winref); WRITE_UINT_FIELD(winref);
WRITE_BOOL_FIELD(winstar); WRITE_BOOL_FIELD(winstar);
WRITE_BOOL_FIELD(winagg); WRITE_BOOL_FIELD(winagg);
...@@ -2080,6 +2082,7 @@ _outFuncCall(StringInfo str, const FuncCall *node) ...@@ -2080,6 +2082,7 @@ _outFuncCall(StringInfo str, const FuncCall *node)
WRITE_NODE_FIELD(funcname); WRITE_NODE_FIELD(funcname);
WRITE_NODE_FIELD(args); WRITE_NODE_FIELD(args);
WRITE_NODE_FIELD(agg_order); WRITE_NODE_FIELD(agg_order);
WRITE_NODE_FIELD(agg_filter);
WRITE_BOOL_FIELD(agg_star); WRITE_BOOL_FIELD(agg_star);
WRITE_BOOL_FIELD(agg_distinct); WRITE_BOOL_FIELD(agg_distinct);
WRITE_BOOL_FIELD(func_variadic); WRITE_BOOL_FIELD(func_variadic);
......
...@@ -479,6 +479,7 @@ _readAggref(void) ...@@ -479,6 +479,7 @@ _readAggref(void)
READ_NODE_FIELD(args); READ_NODE_FIELD(args);
READ_NODE_FIELD(aggorder); READ_NODE_FIELD(aggorder);
READ_NODE_FIELD(aggdistinct); READ_NODE_FIELD(aggdistinct);
READ_NODE_FIELD(aggfilter);
READ_BOOL_FIELD(aggstar); READ_BOOL_FIELD(aggstar);
READ_UINT_FIELD(agglevelsup); READ_UINT_FIELD(agglevelsup);
READ_LOCATION_FIELD(location); READ_LOCATION_FIELD(location);
...@@ -499,6 +500,7 @@ _readWindowFunc(void) ...@@ -499,6 +500,7 @@ _readWindowFunc(void)
READ_OID_FIELD(wincollid); READ_OID_FIELD(wincollid);
READ_OID_FIELD(inputcollid); READ_OID_FIELD(inputcollid);
READ_NODE_FIELD(args); READ_NODE_FIELD(args);
READ_NODE_FIELD(aggfilter);
READ_UINT_FIELD(winref); READ_UINT_FIELD(winref);
READ_BOOL_FIELD(winstar); READ_BOOL_FIELD(winstar);
READ_BOOL_FIELD(winagg); READ_BOOL_FIELD(winagg);
......
...@@ -1590,6 +1590,14 @@ cost_windowagg(Path *path, PlannerInfo *root, ...@@ -1590,6 +1590,14 @@ cost_windowagg(Path *path, PlannerInfo *root,
startup_cost += argcosts.startup; startup_cost += argcosts.startup;
wfunccost += argcosts.per_tuple; wfunccost += argcosts.per_tuple;
/*
* Add the filter's cost to per-input-row costs. XXX We should reduce
* input expression costs according to filter selectivity.
*/
cost_qual_eval_node(&argcosts, (Node *) wfunc->aggfilter, root);
startup_cost += argcosts.startup;
wfunccost += argcosts.per_tuple;
total_cost += wfunccost * input_tuples; total_cost += wfunccost * input_tuples;
} }
......
...@@ -329,6 +329,12 @@ find_minmax_aggs_walker(Node *node, List **context) ...@@ -329,6 +329,12 @@ find_minmax_aggs_walker(Node *node, List **context)
*/ */
if (aggref->aggorder != NIL) if (aggref->aggorder != NIL)
return true; return true;
/*
* We might implement the optimization when a FILTER clause is present
* by adding the filter to the quals of the generated subquery.
*/
if (aggref->aggfilter != NULL)
return true;
/* note: we do not care if DISTINCT is mentioned ... */ /* note: we do not care if DISTINCT is mentioned ... */
aggsortop = fetch_agg_sort_op(aggref->aggfnoid); aggsortop = fetch_agg_sort_op(aggref->aggfnoid);
......
...@@ -495,6 +495,15 @@ count_agg_clauses_walker(Node *node, count_agg_clauses_context *context) ...@@ -495,6 +495,15 @@ count_agg_clauses_walker(Node *node, count_agg_clauses_context *context)
costs->transCost.startup += argcosts.startup; costs->transCost.startup += argcosts.startup;
costs->transCost.per_tuple += argcosts.per_tuple; costs->transCost.per_tuple += argcosts.per_tuple;
/*
* Add the filter's cost to per-input-row costs. XXX We should reduce
* input expression costs according to filter selectivity.
*/
cost_qual_eval_node(&argcosts, (Node *) aggref->aggfilter,
context->root);
costs->transCost.startup += argcosts.startup;
costs->transCost.per_tuple += argcosts.per_tuple;
/* extract argument types (ignoring any ORDER BY expressions) */ /* extract argument types (ignoring any ORDER BY expressions) */
inputTypes = (Oid *) palloc(sizeof(Oid) * list_length(aggref->args)); inputTypes = (Oid *) palloc(sizeof(Oid) * list_length(aggref->args));
numArguments = 0; numArguments = 0;
...@@ -565,7 +574,8 @@ count_agg_clauses_walker(Node *node, count_agg_clauses_context *context) ...@@ -565,7 +574,8 @@ count_agg_clauses_walker(Node *node, count_agg_clauses_context *context)
/* /*
* Complain if the aggregate's arguments contain any aggregates; * Complain if the aggregate's arguments contain any aggregates;
* nested agg functions are semantically nonsensical. * nested agg functions are semantically nonsensical. Aggregates in
* the FILTER clause are detected in transformAggregateCall().
*/ */
if (contain_agg_clause((Node *) aggref->args)) if (contain_agg_clause((Node *) aggref->args))
ereport(ERROR, ereport(ERROR,
...@@ -639,7 +649,8 @@ find_window_functions_walker(Node *node, WindowFuncLists *lists) ...@@ -639,7 +649,8 @@ find_window_functions_walker(Node *node, WindowFuncLists *lists)
/* /*
* Complain if the window function's arguments contain window * Complain if the window function's arguments contain window
* functions * functions. Window functions in the FILTER clause are detected in
* transformAggregateCall().
*/ */
if (contain_window_function((Node *) wfunc->args)) if (contain_window_function((Node *) wfunc->args))
ereport(ERROR, ereport(ERROR,
......
...@@ -492,6 +492,7 @@ static Node *makeRecursiveViewSelect(char *relname, List *aliases, Node *query); ...@@ -492,6 +492,7 @@ static Node *makeRecursiveViewSelect(char *relname, List *aliases, Node *query);
opt_frame_clause frame_extent frame_bound opt_frame_clause frame_extent frame_bound
%type <str> opt_existing_window_name %type <str> opt_existing_window_name
%type <boolean> opt_if_not_exists %type <boolean> opt_if_not_exists
%type <node> filter_clause
/* /*
* Non-keyword token types. These are hard-wired into the "flex" lexer. * Non-keyword token types. These are hard-wired into the "flex" lexer.
...@@ -538,8 +539,8 @@ static Node *makeRecursiveViewSelect(char *relname, List *aliases, Node *query); ...@@ -538,8 +539,8 @@ static Node *makeRecursiveViewSelect(char *relname, List *aliases, Node *query);
EXCLUDE EXCLUDING EXCLUSIVE EXECUTE EXISTS EXPLAIN EXCLUDE EXCLUDING EXCLUSIVE EXECUTE EXISTS EXPLAIN
EXTENSION EXTERNAL EXTRACT EXTENSION EXTERNAL EXTRACT
FALSE_P FAMILY FETCH FIRST_P FLOAT_P FOLLOWING FOR FORCE FOREIGN FORWARD FALSE_P FAMILY FETCH FILTER FIRST_P FLOAT_P FOLLOWING FOR
FREEZE FROM FULL FUNCTION FUNCTIONS FORCE FOREIGN FORWARD FREEZE FROM FULL FUNCTION FUNCTIONS
GLOBAL GRANT GRANTED GREATEST GROUP_P GLOBAL GRANT GRANTED GREATEST GROUP_P
...@@ -11112,10 +11113,11 @@ func_application: func_name '(' ')' ...@@ -11112,10 +11113,11 @@ func_application: func_name '(' ')'
* (Note that many of the special SQL functions wouldn't actually make any * (Note that many of the special SQL functions wouldn't actually make any
* sense as functional index entries, but we ignore that consideration here.) * sense as functional index entries, but we ignore that consideration here.)
*/ */
func_expr: func_application over_clause func_expr: func_application filter_clause over_clause
{ {
FuncCall *n = (FuncCall*)$1; FuncCall *n = (FuncCall*)$1;
n->over = $2; n->agg_filter = $2;
n->over = $3;
$$ = (Node*)n; $$ = (Node*)n;
} }
| func_expr_common_subexpr | func_expr_common_subexpr
...@@ -11526,6 +11528,11 @@ window_definition: ...@@ -11526,6 +11528,11 @@ window_definition:
} }
; ;
filter_clause:
FILTER '(' WHERE a_expr ')' { $$ = $4; }
| /*EMPTY*/ { $$ = NULL; }
;
over_clause: OVER window_specification over_clause: OVER window_specification
{ $$ = $2; } { $$ = $2; }
| OVER ColId | OVER ColId
...@@ -12500,6 +12507,7 @@ unreserved_keyword: ...@@ -12500,6 +12507,7 @@ unreserved_keyword:
| EXTENSION | EXTENSION
| EXTERNAL | EXTERNAL
| FAMILY | FAMILY
| FILTER
| FIRST_P | FIRST_P
| FOLLOWING | FOLLOWING
| FORCE | FORCE
......
...@@ -44,7 +44,7 @@ typedef struct ...@@ -44,7 +44,7 @@ typedef struct
int sublevels_up; int sublevels_up;
} check_ungrouped_columns_context; } check_ungrouped_columns_context;
static int check_agg_arguments(ParseState *pstate, List *args); static int check_agg_arguments(ParseState *pstate, List *args, Expr *filter);
static bool check_agg_arguments_walker(Node *node, static bool check_agg_arguments_walker(Node *node,
check_agg_arguments_context *context); check_agg_arguments_context *context);
static void check_ungrouped_columns(Node *node, ParseState *pstate, Query *qry, static void check_ungrouped_columns(Node *node, ParseState *pstate, Query *qry,
...@@ -160,7 +160,7 @@ transformAggregateCall(ParseState *pstate, Aggref *agg, ...@@ -160,7 +160,7 @@ transformAggregateCall(ParseState *pstate, Aggref *agg,
* Check the arguments to compute the aggregate's level and detect * Check the arguments to compute the aggregate's level and detect
* improper nesting. * improper nesting.
*/ */
min_varlevel = check_agg_arguments(pstate, agg->args); min_varlevel = check_agg_arguments(pstate, agg->args, agg->aggfilter);
agg->agglevelsup = min_varlevel; agg->agglevelsup = min_varlevel;
/* Mark the correct pstate level as having aggregates */ /* Mark the correct pstate level as having aggregates */
...@@ -207,6 +207,9 @@ transformAggregateCall(ParseState *pstate, Aggref *agg, ...@@ -207,6 +207,9 @@ transformAggregateCall(ParseState *pstate, Aggref *agg,
case EXPR_KIND_HAVING: case EXPR_KIND_HAVING:
/* okay */ /* okay */
break; break;
case EXPR_KIND_FILTER:
errkind = true;
break;
case EXPR_KIND_WINDOW_PARTITION: case EXPR_KIND_WINDOW_PARTITION:
/* okay */ /* okay */
break; break;
...@@ -299,8 +302,8 @@ transformAggregateCall(ParseState *pstate, Aggref *agg, ...@@ -299,8 +302,8 @@ transformAggregateCall(ParseState *pstate, Aggref *agg,
* one is its parent, etc). * one is its parent, etc).
* *
* The aggregate's level is the same as the level of the lowest-level variable * The aggregate's level is the same as the level of the lowest-level variable
* or aggregate in its arguments; or if it contains no variables at all, we * or aggregate in its arguments or filter expression; or if it contains no
* presume it to be local. * variables at all, we presume it to be local.
* *
* We also take this opportunity to detect any aggregates or window functions * We also take this opportunity to detect any aggregates or window functions
* nested within the arguments. We can throw error immediately if we find * nested within the arguments. We can throw error immediately if we find
...@@ -309,7 +312,7 @@ transformAggregateCall(ParseState *pstate, Aggref *agg, ...@@ -309,7 +312,7 @@ transformAggregateCall(ParseState *pstate, Aggref *agg,
* which we can't know until we finish scanning the arguments. * which we can't know until we finish scanning the arguments.
*/ */
static int static int
check_agg_arguments(ParseState *pstate, List *args) check_agg_arguments(ParseState *pstate, List *args, Expr *filter)
{ {
int agglevel; int agglevel;
check_agg_arguments_context context; check_agg_arguments_context context;
...@@ -323,6 +326,10 @@ check_agg_arguments(ParseState *pstate, List *args) ...@@ -323,6 +326,10 @@ check_agg_arguments(ParseState *pstate, List *args)
check_agg_arguments_walker, check_agg_arguments_walker,
(void *) &context); (void *) &context);
(void) expression_tree_walker((Node *) filter,
check_agg_arguments_walker,
(void *) &context);
/* /*
* If we found no vars nor aggs at all, it's a level-zero aggregate; * If we found no vars nor aggs at all, it's a level-zero aggregate;
* otherwise, its level is the minimum of vars or aggs. * otherwise, its level is the minimum of vars or aggs.
...@@ -481,6 +488,9 @@ transformWindowFuncCall(ParseState *pstate, WindowFunc *wfunc, ...@@ -481,6 +488,9 @@ transformWindowFuncCall(ParseState *pstate, WindowFunc *wfunc,
case EXPR_KIND_HAVING: case EXPR_KIND_HAVING:
errkind = true; errkind = true;
break; break;
case EXPR_KIND_FILTER:
errkind = true;
break;
case EXPR_KIND_WINDOW_PARTITION: case EXPR_KIND_WINDOW_PARTITION:
case EXPR_KIND_WINDOW_ORDER: case EXPR_KIND_WINDOW_ORDER:
case EXPR_KIND_WINDOW_FRAME_RANGE: case EXPR_KIND_WINDOW_FRAME_RANGE:
...@@ -807,11 +817,10 @@ check_ungrouped_columns_walker(Node *node, ...@@ -807,11 +817,10 @@ check_ungrouped_columns_walker(Node *node,
/* /*
* If we find an aggregate call of the original level, do not recurse into * If we find an aggregate call of the original level, do not recurse into
* its arguments; ungrouped vars in the arguments are not an error. We can * its arguments or filter; ungrouped vars there are not an error. We can
* also skip looking at the arguments of aggregates of higher levels, * also skip looking at aggregates of higher levels, since they could not
* since they could not possibly contain Vars that are of concern to us * possibly contain Vars of concern to us (see transformAggregateCall).
* (see transformAggregateCall). We do need to look into the arguments of * We do need to look at aggregates of lower levels, however.
* aggregates of lower levels, however.
*/ */
if (IsA(node, Aggref) && if (IsA(node, Aggref) &&
(int) ((Aggref *) node)->agglevelsup >= context->sublevels_up) (int) ((Aggref *) node)->agglevelsup >= context->sublevels_up)
......
...@@ -575,6 +575,10 @@ assign_collations_walker(Node *node, assign_collations_context *context) ...@@ -575,6 +575,10 @@ assign_collations_walker(Node *node, assign_collations_context *context)
* the case above for T_TargetEntry will apply * the case above for T_TargetEntry will apply
* appropriate checks to agg ORDER BY items. * appropriate checks to agg ORDER BY items.
* *
* Likewise, we assign collations for the (bool)
* expression in aggfilter, independently of any
* other args.
*
* We need not recurse into the aggorder or * We need not recurse into the aggorder or
* aggdistinct lists, because those contain only * aggdistinct lists, because those contain only
* SortGroupClause nodes which we need not * SortGroupClause nodes which we need not
...@@ -595,6 +599,24 @@ assign_collations_walker(Node *node, assign_collations_context *context) ...@@ -595,6 +599,24 @@ assign_collations_walker(Node *node, assign_collations_context *context)
(void) assign_collations_walker((Node *) tle, (void) assign_collations_walker((Node *) tle,
&loccontext); &loccontext);
} }
assign_expr_collations(context->pstate,
(Node *) aggref->aggfilter);
}
break;
case T_WindowFunc:
{
/*
* WindowFunc requires special processing only for
* its aggfilter clause, as for aggregates.
*/
WindowFunc *wfunc = (WindowFunc *) node;
(void) assign_collations_walker((Node *) wfunc->args,
&loccontext);
assign_expr_collations(context->pstate,
(Node *) wfunc->aggfilter);
} }
break; break;
case T_CaseExpr: case T_CaseExpr:
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include "nodes/nodeFuncs.h" #include "nodes/nodeFuncs.h"
#include "optimizer/var.h" #include "optimizer/var.h"
#include "parser/analyze.h" #include "parser/analyze.h"
#include "parser/parse_clause.h"
#include "parser/parse_coerce.h" #include "parser/parse_coerce.h"
#include "parser/parse_collate.h" #include "parser/parse_collate.h"
#include "parser/parse_expr.h" #include "parser/parse_expr.h"
...@@ -462,7 +463,7 @@ transformIndirection(ParseState *pstate, Node *basenode, List *indirection) ...@@ -462,7 +463,7 @@ transformIndirection(ParseState *pstate, Node *basenode, List *indirection)
newresult = ParseFuncOrColumn(pstate, newresult = ParseFuncOrColumn(pstate,
list_make1(n), list_make1(n),
list_make1(result), list_make1(result),
NIL, false, false, false, NIL, NULL, false, false, false,
NULL, true, location); NULL, true, location);
if (newresult == NULL) if (newresult == NULL)
unknown_attribute(pstate, result, strVal(n), location); unknown_attribute(pstate, result, strVal(n), location);
...@@ -630,7 +631,7 @@ transformColumnRef(ParseState *pstate, ColumnRef *cref) ...@@ -630,7 +631,7 @@ transformColumnRef(ParseState *pstate, ColumnRef *cref)
node = ParseFuncOrColumn(pstate, node = ParseFuncOrColumn(pstate,
list_make1(makeString(colname)), list_make1(makeString(colname)),
list_make1(node), list_make1(node),
NIL, false, false, false, NIL, NULL, false, false, false,
NULL, true, cref->location); NULL, true, cref->location);
} }
break; break;
...@@ -675,7 +676,7 @@ transformColumnRef(ParseState *pstate, ColumnRef *cref) ...@@ -675,7 +676,7 @@ transformColumnRef(ParseState *pstate, ColumnRef *cref)
node = ParseFuncOrColumn(pstate, node = ParseFuncOrColumn(pstate,
list_make1(makeString(colname)), list_make1(makeString(colname)),
list_make1(node), list_make1(node),
NIL, false, false, false, NIL, NULL, false, false, false,
NULL, true, cref->location); NULL, true, cref->location);
} }
break; break;
...@@ -733,7 +734,7 @@ transformColumnRef(ParseState *pstate, ColumnRef *cref) ...@@ -733,7 +734,7 @@ transformColumnRef(ParseState *pstate, ColumnRef *cref)
node = ParseFuncOrColumn(pstate, node = ParseFuncOrColumn(pstate,
list_make1(makeString(colname)), list_make1(makeString(colname)),
list_make1(node), list_make1(node),
NIL, false, false, false, NIL, NULL, false, false, false,
NULL, true, cref->location); NULL, true, cref->location);
} }
break; break;
...@@ -1241,6 +1242,7 @@ transformFuncCall(ParseState *pstate, FuncCall *fn) ...@@ -1241,6 +1242,7 @@ transformFuncCall(ParseState *pstate, FuncCall *fn)
{ {
List *targs; List *targs;
ListCell *args; ListCell *args;
Expr *tagg_filter;
/* Transform the list of arguments ... */ /* Transform the list of arguments ... */
targs = NIL; targs = NIL;
...@@ -1250,11 +1252,22 @@ transformFuncCall(ParseState *pstate, FuncCall *fn) ...@@ -1250,11 +1252,22 @@ transformFuncCall(ParseState *pstate, FuncCall *fn)
(Node *) lfirst(args))); (Node *) lfirst(args)));
} }
/*
* Transform the aggregate filter using transformWhereClause(), to which
* FILTER is virtually identical...
*/
tagg_filter = NULL;
if (fn->agg_filter != NULL)
tagg_filter = (Expr *)
transformWhereClause(pstate, (Node *) fn->agg_filter,
EXPR_KIND_FILTER, "FILTER");
/* ... and hand off to ParseFuncOrColumn */ /* ... and hand off to ParseFuncOrColumn */
return ParseFuncOrColumn(pstate, return ParseFuncOrColumn(pstate,
fn->funcname, fn->funcname,
targs, targs,
fn->agg_order, fn->agg_order,
tagg_filter,
fn->agg_star, fn->agg_star,
fn->agg_distinct, fn->agg_distinct,
fn->func_variadic, fn->func_variadic,
...@@ -1430,6 +1443,7 @@ transformSubLink(ParseState *pstate, SubLink *sublink) ...@@ -1430,6 +1443,7 @@ transformSubLink(ParseState *pstate, SubLink *sublink)
case EXPR_KIND_FROM_FUNCTION: case EXPR_KIND_FROM_FUNCTION:
case EXPR_KIND_WHERE: case EXPR_KIND_WHERE:
case EXPR_KIND_HAVING: case EXPR_KIND_HAVING:
case EXPR_KIND_FILTER:
case EXPR_KIND_WINDOW_PARTITION: case EXPR_KIND_WINDOW_PARTITION:
case EXPR_KIND_WINDOW_ORDER: case EXPR_KIND_WINDOW_ORDER:
case EXPR_KIND_WINDOW_FRAME_RANGE: case EXPR_KIND_WINDOW_FRAME_RANGE:
...@@ -2579,6 +2593,8 @@ ParseExprKindName(ParseExprKind exprKind) ...@@ -2579,6 +2593,8 @@ ParseExprKindName(ParseExprKind exprKind)
return "WHERE"; return "WHERE";
case EXPR_KIND_HAVING: case EXPR_KIND_HAVING:
return "HAVING"; return "HAVING";
case EXPR_KIND_FILTER:
return "FILTER";
case EXPR_KIND_WINDOW_PARTITION: case EXPR_KIND_WINDOW_PARTITION:
return "window PARTITION BY"; return "window PARTITION BY";
case EXPR_KIND_WINDOW_ORDER: case EXPR_KIND_WINDOW_ORDER:
......
...@@ -56,13 +56,13 @@ static Node *ParseComplexProjection(ParseState *pstate, char *funcname, ...@@ -56,13 +56,13 @@ static Node *ParseComplexProjection(ParseState *pstate, char *funcname,
* Also, when is_column is true, we return NULL on failure rather than * Also, when is_column is true, we return NULL on failure rather than
* reporting a no-such-function error. * reporting a no-such-function error.
* *
* The argument expressions (in fargs) must have been transformed already. * The argument expressions (in fargs) and filter must have been transformed
* But the agg_order expressions, if any, have not been. * already. But the agg_order expressions, if any, have not been.
*/ */
Node * Node *
ParseFuncOrColumn(ParseState *pstate, List *funcname, List *fargs, ParseFuncOrColumn(ParseState *pstate, List *funcname, List *fargs,
List *agg_order, bool agg_star, bool agg_distinct, List *agg_order, Expr *agg_filter,
bool func_variadic, bool agg_star, bool agg_distinct, bool func_variadic,
WindowDef *over, bool is_column, int location) WindowDef *over, bool is_column, int location)
{ {
Oid rettype; Oid rettype;
...@@ -174,8 +174,8 @@ ParseFuncOrColumn(ParseState *pstate, List *funcname, List *fargs, ...@@ -174,8 +174,8 @@ ParseFuncOrColumn(ParseState *pstate, List *funcname, List *fargs,
* the "function call" could be a projection. We also check that there * the "function call" could be a projection. We also check that there
* wasn't any aggregate or variadic decoration, nor an argument name. * wasn't any aggregate or variadic decoration, nor an argument name.
*/ */
if (nargs == 1 && agg_order == NIL && !agg_star && !agg_distinct && if (nargs == 1 && agg_order == NIL && agg_filter == NULL && !agg_star &&
over == NULL && !func_variadic && argnames == NIL && !agg_distinct && over == NULL && !func_variadic && argnames == NIL &&
list_length(funcname) == 1) list_length(funcname) == 1)
{ {
Oid argtype = actual_arg_types[0]; Oid argtype = actual_arg_types[0];
...@@ -251,6 +251,12 @@ ParseFuncOrColumn(ParseState *pstate, List *funcname, List *fargs, ...@@ -251,6 +251,12 @@ ParseFuncOrColumn(ParseState *pstate, List *funcname, List *fargs,
errmsg("ORDER BY specified, but %s is not an aggregate function", errmsg("ORDER BY specified, but %s is not an aggregate function",
NameListToString(funcname)), NameListToString(funcname)),
parser_errposition(pstate, location))); parser_errposition(pstate, location)));
if (agg_filter)
ereport(ERROR,
(errcode(ERRCODE_WRONG_OBJECT_TYPE),
errmsg("FILTER specified, but %s is not an aggregate function",
NameListToString(funcname)),
parser_errposition(pstate, location)));
if (over) if (over)
ereport(ERROR, ereport(ERROR,
(errcode(ERRCODE_WRONG_OBJECT_TYPE), (errcode(ERRCODE_WRONG_OBJECT_TYPE),
...@@ -402,6 +408,7 @@ ParseFuncOrColumn(ParseState *pstate, List *funcname, List *fargs, ...@@ -402,6 +408,7 @@ ParseFuncOrColumn(ParseState *pstate, List *funcname, List *fargs,
/* aggcollid and inputcollid will be set by parse_collate.c */ /* aggcollid and inputcollid will be set by parse_collate.c */
/* args, aggorder, aggdistinct will be set by transformAggregateCall */ /* args, aggorder, aggdistinct will be set by transformAggregateCall */
aggref->aggstar = agg_star; aggref->aggstar = agg_star;
aggref->aggfilter = agg_filter;
/* agglevelsup will be set by transformAggregateCall */ /* agglevelsup will be set by transformAggregateCall */
aggref->location = location; aggref->location = location;
...@@ -460,6 +467,7 @@ ParseFuncOrColumn(ParseState *pstate, List *funcname, List *fargs, ...@@ -460,6 +467,7 @@ ParseFuncOrColumn(ParseState *pstate, List *funcname, List *fargs,
/* winref will be set by transformWindowFuncCall */ /* winref will be set by transformWindowFuncCall */
wfunc->winstar = agg_star; wfunc->winstar = agg_star;
wfunc->winagg = (fdresult == FUNCDETAIL_AGGREGATE); wfunc->winagg = (fdresult == FUNCDETAIL_AGGREGATE);
wfunc->aggfilter = agg_filter;
wfunc->location = location; wfunc->location = location;
/* /*
...@@ -482,6 +490,16 @@ ParseFuncOrColumn(ParseState *pstate, List *funcname, List *fargs, ...@@ -482,6 +490,16 @@ ParseFuncOrColumn(ParseState *pstate, List *funcname, List *fargs,
NameListToString(funcname)), NameListToString(funcname)),
parser_errposition(pstate, location))); parser_errposition(pstate, location)));
/*
* Reject window functions which are not aggregates in the case of
* FILTER.
*/
if (!wfunc->winagg && agg_filter)
ereport(ERROR,
(errcode(ERRCODE_WRONG_OBJECT_TYPE),
errmsg("FILTER is not implemented in non-aggregate window functions"),
parser_errposition(pstate, location)));
/* /*
* ordered aggs not allowed in windows yet * ordered aggs not allowed in windows yet
*/ */
......
...@@ -7424,6 +7424,13 @@ get_agg_expr(Aggref *aggref, deparse_context *context) ...@@ -7424,6 +7424,13 @@ get_agg_expr(Aggref *aggref, deparse_context *context)
appendStringInfoString(buf, " ORDER BY "); appendStringInfoString(buf, " ORDER BY ");
get_rule_orderby(aggref->aggorder, aggref->args, false, context); get_rule_orderby(aggref->aggorder, aggref->args, false, context);
} }
if (aggref->aggfilter != NULL)
{
appendStringInfoString(buf, ") FILTER (WHERE ");
get_rule_expr((Node *) aggref->aggfilter, context, false);
}
appendStringInfoChar(buf, ')'); appendStringInfoChar(buf, ')');
} }
...@@ -7461,6 +7468,13 @@ get_windowfunc_expr(WindowFunc *wfunc, deparse_context *context) ...@@ -7461,6 +7468,13 @@ get_windowfunc_expr(WindowFunc *wfunc, deparse_context *context)
appendStringInfoChar(buf, '*'); appendStringInfoChar(buf, '*');
else else
get_rule_expr((Node *) wfunc->args, context, true); get_rule_expr((Node *) wfunc->args, context, true);
if (wfunc->aggfilter != NULL)
{
appendStringInfoString(buf, ") FILTER (WHERE ");
get_rule_expr((Node *) wfunc->aggfilter, context, false);
}
appendStringInfoString(buf, ") OVER "); appendStringInfoString(buf, ") OVER ");
foreach(l, context->windowClause) foreach(l, context->windowClause)
......
...@@ -53,6 +53,6 @@ ...@@ -53,6 +53,6 @@
*/ */
/* yyyymmddN */ /* yyyymmddN */
#define CATALOG_VERSION_NO 201307051 #define CATALOG_VERSION_NO 201307161
#endif #endif
...@@ -584,6 +584,7 @@ typedef struct AggrefExprState ...@@ -584,6 +584,7 @@ typedef struct AggrefExprState
{ {
ExprState xprstate; ExprState xprstate;
List *args; /* states of argument expressions */ List *args; /* states of argument expressions */
ExprState *aggfilter; /* FILTER expression */
int aggno; /* ID number for agg within its plan node */ int aggno; /* ID number for agg within its plan node */
} AggrefExprState; } AggrefExprState;
...@@ -595,6 +596,7 @@ typedef struct WindowFuncExprState ...@@ -595,6 +596,7 @@ typedef struct WindowFuncExprState
{ {
ExprState xprstate; ExprState xprstate;
List *args; /* states of argument expressions */ List *args; /* states of argument expressions */
ExprState *aggfilter; /* FILTER expression */
int wfuncno; /* ID number for wfunc within its plan node */ int wfuncno; /* ID number for wfunc within its plan node */
} WindowFuncExprState; } WindowFuncExprState;
......
...@@ -283,8 +283,8 @@ typedef struct CollateClause ...@@ -283,8 +283,8 @@ typedef struct CollateClause
* agg_star indicates we saw a 'foo(*)' construct, while agg_distinct * agg_star indicates we saw a 'foo(*)' construct, while agg_distinct
* indicates we saw 'foo(DISTINCT ...)'. In any of these cases, the * indicates we saw 'foo(DISTINCT ...)'. In any of these cases, the
* construct *must* be an aggregate call. Otherwise, it might be either an * construct *must* be an aggregate call. Otherwise, it might be either an
* aggregate or some other kind of function. However, if OVER is present * aggregate or some other kind of function. However, if FILTER or OVER is
* it had better be an aggregate or window function. * present it had better be an aggregate or window function.
* *
* Normally, you'd initialize this via makeFuncCall() and then only * Normally, you'd initialize this via makeFuncCall() and then only
* change the parts of the struct its defaults don't match afterwards * change the parts of the struct its defaults don't match afterwards
...@@ -297,6 +297,7 @@ typedef struct FuncCall ...@@ -297,6 +297,7 @@ typedef struct FuncCall
List *funcname; /* qualified name of function */ List *funcname; /* qualified name of function */
List *args; /* the arguments (list of exprs) */ List *args; /* the arguments (list of exprs) */
List *agg_order; /* ORDER BY (list of SortBy) */ List *agg_order; /* ORDER BY (list of SortBy) */
Node *agg_filter; /* FILTER clause, if any */
bool agg_star; /* argument was really '*' */ bool agg_star; /* argument was really '*' */
bool agg_distinct; /* arguments were labeled DISTINCT */ bool agg_distinct; /* arguments were labeled DISTINCT */
bool func_variadic; /* last argument was labeled VARIADIC */ bool func_variadic; /* last argument was labeled VARIADIC */
......
...@@ -247,6 +247,7 @@ typedef struct Aggref ...@@ -247,6 +247,7 @@ typedef struct Aggref
List *args; /* arguments and sort expressions */ List *args; /* arguments and sort expressions */
List *aggorder; /* ORDER BY (list of SortGroupClause) */ List *aggorder; /* ORDER BY (list of SortGroupClause) */
List *aggdistinct; /* DISTINCT (list of SortGroupClause) */ List *aggdistinct; /* DISTINCT (list of SortGroupClause) */
Expr *aggfilter; /* FILTER expression */
bool aggstar; /* TRUE if argument list was really '*' */ bool aggstar; /* TRUE if argument list was really '*' */
Index agglevelsup; /* > 0 if agg belongs to outer query */ Index agglevelsup; /* > 0 if agg belongs to outer query */
int location; /* token location, or -1 if unknown */ int location; /* token location, or -1 if unknown */
...@@ -263,6 +264,7 @@ typedef struct WindowFunc ...@@ -263,6 +264,7 @@ typedef struct WindowFunc
Oid wincollid; /* OID of collation of result */ Oid wincollid; /* OID of collation of result */
Oid inputcollid; /* OID of collation that function should use */ Oid inputcollid; /* OID of collation that function should use */
List *args; /* arguments to the window function */ List *args; /* arguments to the window function */
Expr *aggfilter; /* FILTER expression */
Index winref; /* index of associated WindowClause */ Index winref; /* index of associated WindowClause */
bool winstar; /* TRUE if argument list was really '*' */ bool winstar; /* TRUE if argument list was really '*' */
bool winagg; /* is function a simple aggregate? */ bool winagg; /* is function a simple aggregate? */
......
...@@ -155,6 +155,7 @@ PG_KEYWORD("extract", EXTRACT, COL_NAME_KEYWORD) ...@@ -155,6 +155,7 @@ PG_KEYWORD("extract", EXTRACT, COL_NAME_KEYWORD)
PG_KEYWORD("false", FALSE_P, RESERVED_KEYWORD) PG_KEYWORD("false", FALSE_P, RESERVED_KEYWORD)
PG_KEYWORD("family", FAMILY, UNRESERVED_KEYWORD) PG_KEYWORD("family", FAMILY, UNRESERVED_KEYWORD)
PG_KEYWORD("fetch", FETCH, RESERVED_KEYWORD) PG_KEYWORD("fetch", FETCH, RESERVED_KEYWORD)
PG_KEYWORD("filter", FILTER, UNRESERVED_KEYWORD)
PG_KEYWORD("first", FIRST_P, UNRESERVED_KEYWORD) PG_KEYWORD("first", FIRST_P, UNRESERVED_KEYWORD)
PG_KEYWORD("float", FLOAT_P, COL_NAME_KEYWORD) PG_KEYWORD("float", FLOAT_P, COL_NAME_KEYWORD)
PG_KEYWORD("following", FOLLOWING, UNRESERVED_KEYWORD) PG_KEYWORD("following", FOLLOWING, UNRESERVED_KEYWORD)
......
...@@ -42,10 +42,9 @@ typedef enum ...@@ -42,10 +42,9 @@ typedef enum
} FuncDetailCode; } FuncDetailCode;
extern Node *ParseFuncOrColumn(ParseState *pstate, extern Node *ParseFuncOrColumn(ParseState *pstate, List *funcname, List *fargs,
List *funcname, List *fargs, List *agg_order, Expr *agg_filter,
List *agg_order, bool agg_star, bool agg_distinct, bool agg_star, bool agg_distinct, bool func_variadic,
bool func_variadic,
WindowDef *over, bool is_column, int location); WindowDef *over, bool is_column, int location);
extern FuncDetailCode func_get_detail(List *funcname, extern FuncDetailCode func_get_detail(List *funcname,
......
...@@ -39,6 +39,7 @@ typedef enum ParseExprKind ...@@ -39,6 +39,7 @@ typedef enum ParseExprKind
EXPR_KIND_FROM_FUNCTION, /* function in FROM clause */ EXPR_KIND_FROM_FUNCTION, /* function in FROM clause */
EXPR_KIND_WHERE, /* WHERE */ EXPR_KIND_WHERE, /* WHERE */
EXPR_KIND_HAVING, /* HAVING */ EXPR_KIND_HAVING, /* HAVING */
EXPR_KIND_FILTER, /* FILTER */
EXPR_KIND_WINDOW_PARTITION, /* window definition PARTITION BY */ EXPR_KIND_WINDOW_PARTITION, /* window definition PARTITION BY */
EXPR_KIND_WINDOW_ORDER, /* window definition ORDER BY */ EXPR_KIND_WINDOW_ORDER, /* window definition ORDER BY */
EXPR_KIND_WINDOW_FRAME_RANGE, /* window frame clause with RANGE */ EXPR_KIND_WINDOW_FRAME_RANGE, /* window frame clause with RANGE */
......
...@@ -1154,3 +1154,98 @@ select string_agg(v, decode('ee', 'hex')) from bytea_test_table; ...@@ -1154,3 +1154,98 @@ select string_agg(v, decode('ee', 'hex')) from bytea_test_table;
(1 row) (1 row)
drop table bytea_test_table; drop table bytea_test_table;
-- FILTER tests
select min(unique1) filter (where unique1 > 100) from tenk1;
min
-----
101
(1 row)
select ten, sum(distinct four) filter (where four::text ~ '123') from onek a
group by ten;
ten | sum
-----+-----
0 |
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
(10 rows)
select ten, sum(distinct four) filter (where four > 10) from onek a
group by ten
having exists (select 1 from onek b where sum(distinct a.four) = b.four);
ten | sum
-----+-----
0 |
2 |
4 |
6 |
8 |
(5 rows)
select max(foo COLLATE "C") filter (where (bar collate "POSIX") > '0')
from (values ('a', 'b')) AS v(foo,bar);
max
-----
a
(1 row)
-- outer reference in FILTER (PostgreSQL extension)
select (select count(*)
from (values (1)) t0(inner_c))
from (values (2),(3)) t1(outer_c); -- inner query is aggregation query
count
-------
1
1
(2 rows)
select (select count(*) filter (where outer_c <> 0)
from (values (1)) t0(inner_c))
from (values (2),(3)) t1(outer_c); -- outer query is aggregation query
count
-------
2
(1 row)
select (select count(inner_c) filter (where outer_c <> 0)
from (values (1)) t0(inner_c))
from (values (2),(3)) t1(outer_c); -- inner query is aggregation query
count
-------
1
1
(2 rows)
select
(select max((select i.unique2 from tenk1 i where i.unique1 = o.unique1))
filter (where o.unique1 < 10))
from tenk1 o; -- outer query is aggregation query
max
------
9998
(1 row)
-- subquery in FILTER clause (PostgreSQL extension)
select sum(unique1) FILTER (WHERE
unique1 IN (SELECT unique1 FROM onek where unique1 < 100)) FROM tenk1;
sum
------
4950
(1 row)
-- exercise lots of aggregate parts with FILTER
select aggfns(distinct a,b,c order by a,c using ~<~,b) filter (where a > 1)
from (values (1,3,'foo'),(0,null,null),(2,2,'bar'),(3,1,'baz')) v(a,b,c),
generate_series(1,2) i;
aggfns
---------------------------
{"(2,2,bar)","(3,1,baz)"}
(1 row)
...@@ -1020,5 +1020,18 @@ SELECT ntile(0) OVER (ORDER BY ten), ten, four FROM tenk1; ...@@ -1020,5 +1020,18 @@ SELECT ntile(0) OVER (ORDER BY ten), ten, four FROM tenk1;
ERROR: argument of ntile must be greater than zero ERROR: argument of ntile must be greater than zero
SELECT nth_value(four, 0) OVER (ORDER BY ten), ten, four FROM tenk1; SELECT nth_value(four, 0) OVER (ORDER BY ten), ten, four FROM tenk1;
ERROR: argument of nth_value must be greater than zero ERROR: argument of nth_value must be greater than zero
-- filter
SELECT sum(salary), row_number() OVER (ORDER BY depname), sum(
sum(salary) FILTER (WHERE enroll_date > '2007-01-01')
) FILTER (WHERE depname <> 'sales') OVER (ORDER BY depname DESC) AS "filtered_sum",
depname
FROM empsalary GROUP BY depname;
sum | row_number | filtered_sum | depname
-------+------------+--------------+-----------
14600 | 3 | | sales
7400 | 2 | 3500 | personnel
25100 | 1 | 22600 | develop
(3 rows)
-- cleanup -- cleanup
DROP TABLE empsalary; DROP TABLE empsalary;
...@@ -442,3 +442,41 @@ select string_agg(v, NULL) from bytea_test_table; ...@@ -442,3 +442,41 @@ select string_agg(v, NULL) from bytea_test_table;
select string_agg(v, decode('ee', 'hex')) from bytea_test_table; select string_agg(v, decode('ee', 'hex')) from bytea_test_table;
drop table bytea_test_table; drop table bytea_test_table;
-- FILTER tests
select min(unique1) filter (where unique1 > 100) from tenk1;
select ten, sum(distinct four) filter (where four::text ~ '123') from onek a
group by ten;
select ten, sum(distinct four) filter (where four > 10) from onek a
group by ten
having exists (select 1 from onek b where sum(distinct a.four) = b.four);
select max(foo COLLATE "C") filter (where (bar collate "POSIX") > '0')
from (values ('a', 'b')) AS v(foo,bar);
-- outer reference in FILTER (PostgreSQL extension)
select (select count(*)
from (values (1)) t0(inner_c))
from (values (2),(3)) t1(outer_c); -- inner query is aggregation query
select (select count(*) filter (where outer_c <> 0)
from (values (1)) t0(inner_c))
from (values (2),(3)) t1(outer_c); -- outer query is aggregation query
select (select count(inner_c) filter (where outer_c <> 0)
from (values (1)) t0(inner_c))
from (values (2),(3)) t1(outer_c); -- inner query is aggregation query
select
(select max((select i.unique2 from tenk1 i where i.unique1 = o.unique1))
filter (where o.unique1 < 10))
from tenk1 o; -- outer query is aggregation query
-- subquery in FILTER clause (PostgreSQL extension)
select sum(unique1) FILTER (WHERE
unique1 IN (SELECT unique1 FROM onek where unique1 < 100)) FROM tenk1;
-- exercise lots of aggregate parts with FILTER
select aggfns(distinct a,b,c order by a,c using ~<~,b) filter (where a > 1)
from (values (1,3,'foo'),(0,null,null),(2,2,'bar'),(3,1,'baz')) v(a,b,c),
generate_series(1,2) i;
...@@ -264,5 +264,13 @@ SELECT ntile(0) OVER (ORDER BY ten), ten, four FROM tenk1; ...@@ -264,5 +264,13 @@ SELECT ntile(0) OVER (ORDER BY ten), ten, four FROM tenk1;
SELECT nth_value(four, 0) OVER (ORDER BY ten), ten, four FROM tenk1; SELECT nth_value(four, 0) OVER (ORDER BY ten), ten, four FROM tenk1;
-- filter
SELECT sum(salary), row_number() OVER (ORDER BY depname), sum(
sum(salary) FILTER (WHERE enroll_date > '2007-01-01')
) FILTER (WHERE depname <> 'sales') OVER (ORDER BY depname DESC) AS "filtered_sum",
depname
FROM empsalary GROUP BY depname;
-- cleanup -- cleanup
DROP TABLE empsalary; DROP TABLE empsalary;
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment