Commit 41ea0c23 authored by Robert Haas's avatar Robert Haas

Fix parallel-safety code for parallel aggregation.

has_parallel_hazard() was ignoring the proparallel markings for
aggregates, which is no good.  Fix that.  There was no way to mark
an aggregate as actually being parallel-safe, either, so add a
PARALLEL option to CREATE AGGREGATE.

Patch by me, reviewed by David Rowley.
parent 09adc9a8
...@@ -40,6 +40,7 @@ CREATE AGGREGATE <replaceable class="parameter">name</replaceable> ( [ <replacea ...@@ -40,6 +40,7 @@ CREATE AGGREGATE <replaceable class="parameter">name</replaceable> ( [ <replacea
[ , MFINALFUNC_EXTRA ] [ , MFINALFUNC_EXTRA ]
[ , MINITCOND = <replaceable class="PARAMETER">minitial_condition</replaceable> ] [ , MINITCOND = <replaceable class="PARAMETER">minitial_condition</replaceable> ]
[ , SORTOP = <replaceable class="PARAMETER">sort_operator</replaceable> ] [ , SORTOP = <replaceable class="PARAMETER">sort_operator</replaceable> ]
[ , PARALLEL = { SAFE | RESTRICTED | UNSAFE } ]
) )
CREATE AGGREGATE <replaceable class="parameter">name</replaceable> ( [ [ <replaceable class="parameter">argmode</replaceable> ] [ <replaceable class="parameter">argname</replaceable> ] <replaceable class="parameter">arg_data_type</replaceable> [ , ... ] ] CREATE AGGREGATE <replaceable class="parameter">name</replaceable> ( [ [ <replaceable class="parameter">argmode</replaceable> ] [ <replaceable class="parameter">argname</replaceable> ] <replaceable class="parameter">arg_data_type</replaceable> [ , ... ] ]
...@@ -55,6 +56,8 @@ CREATE AGGREGATE <replaceable class="parameter">name</replaceable> ( [ [ <replac ...@@ -55,6 +56,8 @@ CREATE AGGREGATE <replaceable class="parameter">name</replaceable> ( [ [ <replac
[ , SERIALTYPE = <replaceable class="PARAMETER">serialtype</replaceable> ] [ , SERIALTYPE = <replaceable class="PARAMETER">serialtype</replaceable> ]
[ , INITCOND = <replaceable class="PARAMETER">initial_condition</replaceable> ] [ , INITCOND = <replaceable class="PARAMETER">initial_condition</replaceable> ]
[ , HYPOTHETICAL ] [ , HYPOTHETICAL ]
[ , PARALLEL = { SAFE | RESTRICTED | UNSAFE } ]
) )
<phrase>or the old syntax</phrase> <phrase>or the old syntax</phrase>
...@@ -684,6 +687,12 @@ SELECT col FROM tab ORDER BY col USING sortop LIMIT 1; ...@@ -684,6 +687,12 @@ SELECT col FROM tab ORDER BY col USING sortop LIMIT 1;
Currently, ordered-set aggregates do not need to support Currently, ordered-set aggregates do not need to support
moving-aggregate mode, since they cannot be used as window functions. moving-aggregate mode, since they cannot be used as window functions.
</para> </para>
<para>
The meaning of <literal>PARALLEL SAFE</>, <literal>PARALLEL RESTRICTED</>,
and <literal>PARALLEL UNSAFE</> is the same as for
<xref linkend="sql-createfunction">.
</para>
</refsect1> </refsect1>
<refsect1> <refsect1>
......
...@@ -72,7 +72,8 @@ AggregateCreate(const char *aggName, ...@@ -72,7 +72,8 @@ AggregateCreate(const char *aggName,
Oid aggmTransType, Oid aggmTransType,
int32 aggmTransSpace, int32 aggmTransSpace,
const char *agginitval, const char *agginitval,
const char *aggminitval) const char *aggminitval,
char proparallel)
{ {
Relation aggdesc; Relation aggdesc;
HeapTuple tup; HeapTuple tup;
...@@ -622,7 +623,7 @@ AggregateCreate(const char *aggName, ...@@ -622,7 +623,7 @@ AggregateCreate(const char *aggName,
false, /* isStrict (not needed for agg) */ false, /* isStrict (not needed for agg) */
PROVOLATILE_IMMUTABLE, /* volatility (not PROVOLATILE_IMMUTABLE, /* volatility (not
* needed for agg) */ * needed for agg) */
PROPARALLEL_UNSAFE, proparallel,
parameterTypes, /* paramTypes */ parameterTypes, /* paramTypes */
allParameterTypes, /* allParamTypes */ allParameterTypes, /* allParamTypes */
parameterModes, /* parameterModes */ parameterModes, /* parameterModes */
......
...@@ -78,6 +78,7 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters, ...@@ -78,6 +78,7 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters,
int32 mtransSpace = 0; int32 mtransSpace = 0;
char *initval = NULL; char *initval = NULL;
char *minitval = NULL; char *minitval = NULL;
char *parallel = NULL;
int numArgs; int numArgs;
int numDirectArgs = 0; int numDirectArgs = 0;
oidvector *parameterTypes; oidvector *parameterTypes;
...@@ -91,6 +92,7 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters, ...@@ -91,6 +92,7 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters,
Oid mtransTypeId = InvalidOid; Oid mtransTypeId = InvalidOid;
char transTypeType; char transTypeType;
char mtransTypeType = 0; char mtransTypeType = 0;
char proparallel = PROPARALLEL_UNSAFE;
ListCell *pl; ListCell *pl;
/* Convert list of names to a name and namespace */ /* Convert list of names to a name and namespace */
...@@ -178,6 +180,8 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters, ...@@ -178,6 +180,8 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters,
initval = defGetString(defel); initval = defGetString(defel);
else if (pg_strcasecmp(defel->defname, "minitcond") == 0) else if (pg_strcasecmp(defel->defname, "minitcond") == 0)
minitval = defGetString(defel); minitval = defGetString(defel);
else if (pg_strcasecmp(defel->defname, "parallel") == 0)
parallel = defGetString(defel);
else else
ereport(WARNING, ereport(WARNING,
(errcode(ERRCODE_SYNTAX_ERROR), (errcode(ERRCODE_SYNTAX_ERROR),
...@@ -449,6 +453,20 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters, ...@@ -449,6 +453,20 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters,
(void) OidInputFunctionCall(typinput, minitval, typioparam, -1); (void) OidInputFunctionCall(typinput, minitval, typioparam, -1);
} }
if (parallel)
{
if (pg_strcasecmp(parallel, "safe") == 0)
proparallel = PROPARALLEL_SAFE;
else if (pg_strcasecmp(parallel, "restricted") == 0)
proparallel = PROPARALLEL_RESTRICTED;
else if (pg_strcasecmp(parallel, "unsafe") == 0)
proparallel = PROPARALLEL_UNSAFE;
else
ereport(ERROR,
(errcode(ERRCODE_SYNTAX_ERROR),
errmsg("parameter \"parallel\" must be SAFE, RESTRICTED, or UNSAFE")));
}
/* /*
* Most of the argument-checking is done inside of AggregateCreate * Most of the argument-checking is done inside of AggregateCreate
*/ */
...@@ -480,5 +498,6 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters, ...@@ -480,5 +498,6 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters,
mtransTypeId, /* transition data type */ mtransTypeId, /* transition data type */
mtransSpace, /* transition space */ mtransSpace, /* transition space */
initval, /* initial condition */ initval, /* initial condition */
minitval); /* initial condition */ minitval, /* initial condition */
proparallel); /* parallel safe? */
} }
...@@ -566,9 +566,8 @@ interpret_func_parallel(DefElem *defel) ...@@ -566,9 +566,8 @@ interpret_func_parallel(DefElem *defel)
else else
{ {
ereport(ERROR, ereport(ERROR,
(errcode(ERRCODE_INVALID_PARAMETER_VALUE), (errcode(ERRCODE_SYNTAX_ERROR),
errmsg("parallel option \"%s\" not recognized", errmsg("parameter \"parallel\" must be SAFE, RESTRICTED, or UNSAFE")));
str)));
return PROPARALLEL_UNSAFE; /* keep compiler quiet */ return PROPARALLEL_UNSAFE; /* keep compiler quiet */
} }
} }
......
...@@ -1419,6 +1419,13 @@ has_parallel_hazard_walker(Node *node, has_parallel_hazard_arg *context) ...@@ -1419,6 +1419,13 @@ has_parallel_hazard_walker(Node *node, has_parallel_hazard_arg *context)
if (parallel_too_dangerous(func_parallel(expr->funcid), context)) if (parallel_too_dangerous(func_parallel(expr->funcid), context))
return true; return true;
} }
else if (IsA(node, Aggref))
{
Aggref *aggref = (Aggref *) node;
if (parallel_too_dangerous(func_parallel(aggref->aggfnoid), context))
return true;
}
else if (IsA(node, OpExpr)) else if (IsA(node, OpExpr))
{ {
OpExpr *expr = (OpExpr *) node; OpExpr *expr = (OpExpr *) node;
......
...@@ -349,6 +349,7 @@ extern ObjectAddress AggregateCreate(const char *aggName, ...@@ -349,6 +349,7 @@ extern ObjectAddress AggregateCreate(const char *aggName,
Oid aggmTransType, Oid aggmTransType,
int32 aggmTransSpace, int32 aggmTransSpace,
const char *agginitval, const char *agginitval,
const char *aggminitval); const char *aggminitval,
char proparallel);
#endif /* PG_AGGREGATE_H */ #endif /* PG_AGGREGATE_H */
...@@ -20,9 +20,9 @@ CREATE AGGREGATE newsum ( ...@@ -20,9 +20,9 @@ CREATE AGGREGATE newsum (
-- zero-argument aggregate -- zero-argument aggregate
CREATE AGGREGATE newcnt (*) ( CREATE AGGREGATE newcnt (*) (
sfunc = int8inc, stype = int8, sfunc = int8inc, stype = int8,
initcond = '0' initcond = '0', parallel = safe
); );
-- old-style spelling of same -- old-style spelling of same (except without parallel-safe; that's too new)
CREATE AGGREGATE oldcnt ( CREATE AGGREGATE oldcnt (
sfunc = int8inc, basetype = 'ANY', stype = int8, sfunc = int8inc, basetype = 'ANY', stype = int8,
initcond = '0' initcond = '0'
...@@ -188,6 +188,14 @@ WHERE aggfnoid = 'myavg'::REGPROC; ...@@ -188,6 +188,14 @@ WHERE aggfnoid = 'myavg'::REGPROC;
(1 row) (1 row)
DROP AGGREGATE myavg (numeric); DROP AGGREGATE myavg (numeric);
-- invalid: bad parallel-safety marking
CREATE AGGREGATE mysum (int)
(
stype = int,
sfunc = int4pl,
parallel = pear
);
ERROR: parameter "parallel" must be SAFE, RESTRICTED, or UNSAFE
-- invalid: nonstrict inverse with strict forward function -- invalid: nonstrict inverse with strict forward function
CREATE FUNCTION float8mi_n(float8, float8) RETURNS float8 AS CREATE FUNCTION float8mi_n(float8, float8) RETURNS float8 AS
$$ SELECT $1 - $2; $$ $$ SELECT $1 - $2; $$
......
...@@ -23,10 +23,10 @@ CREATE AGGREGATE newsum ( ...@@ -23,10 +23,10 @@ CREATE AGGREGATE newsum (
-- zero-argument aggregate -- zero-argument aggregate
CREATE AGGREGATE newcnt (*) ( CREATE AGGREGATE newcnt (*) (
sfunc = int8inc, stype = int8, sfunc = int8inc, stype = int8,
initcond = '0' initcond = '0', parallel = safe
); );
-- old-style spelling of same -- old-style spelling of same (except without parallel-safe; that's too new)
CREATE AGGREGATE oldcnt ( CREATE AGGREGATE oldcnt (
sfunc = int8inc, basetype = 'ANY', stype = int8, sfunc = int8inc, basetype = 'ANY', stype = int8,
initcond = '0' initcond = '0'
...@@ -201,6 +201,14 @@ WHERE aggfnoid = 'myavg'::REGPROC; ...@@ -201,6 +201,14 @@ WHERE aggfnoid = 'myavg'::REGPROC;
DROP AGGREGATE myavg (numeric); DROP AGGREGATE myavg (numeric);
-- invalid: bad parallel-safety marking
CREATE AGGREGATE mysum (int)
(
stype = int,
sfunc = int4pl,
parallel = pear
);
-- invalid: nonstrict inverse with strict forward function -- invalid: nonstrict inverse with strict forward function
CREATE FUNCTION float8mi_n(float8, float8) RETURNS float8 AS CREATE FUNCTION float8mi_n(float8, float8) RETURNS float8 AS
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment