Commit 05ca21b8 authored by Tom Lane's avatar Tom Lane

Fix type checking for support functions of parallel VARIADIC aggregates.

The impact of VARIADIC on the combine/serialize/deserialize support
functions of an aggregate wasn't thought through carefully.  There is
actually no impact, because variadicity isn't passed through to these
functions (and it doesn't seem like it would need to be).  However,
lookup_agg_function was mistakenly told to check things as though it were
passed through.  The net result was that it was impossible to declare an
aggregate that had both VARIADIC input and parallelism support functions.

In passing, fix a runtime check in nodeAgg.c for the combine function's
strictness to make its error message agree with the creation-time check.
The previous message was actually backwards, and it doesn't seem like
there's a good reason to have two versions of this message text anyway.

Back-patch to 9.6 where parallel aggregation was introduced.

Alexey Bashtanov; message fix by me

Discussion: https://postgr.es/m/f86dde87-fef4-71eb-0480-62754aaca01b@imap.cc
parent 185f4f84
...@@ -410,16 +410,17 @@ AggregateCreate(const char *aggName, ...@@ -410,16 +410,17 @@ AggregateCreate(const char *aggName,
Oid combineType; Oid combineType;
/* /*
* Combine function must have 2 argument, each of which is the trans * Combine function must have 2 arguments, each of which is the trans
* type * type. VARIADIC doesn't affect it.
*/ */
fnArgs[0] = aggTransType; fnArgs[0] = aggTransType;
fnArgs[1] = aggTransType; fnArgs[1] = aggTransType;
combinefn = lookup_agg_function(aggcombinefnName, 2, fnArgs, combinefn = lookup_agg_function(aggcombinefnName, 2,
variadicArgType, &combineType); fnArgs, InvalidOid,
&combineType);
/* Ensure the return type matches the aggregates trans type */ /* Ensure the return type matches the aggregate's trans type */
if (combineType != aggTransType) if (combineType != aggTransType)
ereport(ERROR, ereport(ERROR,
(errcode(ERRCODE_DATATYPE_MISMATCH), (errcode(ERRCODE_DATATYPE_MISMATCH),
...@@ -429,14 +430,14 @@ AggregateCreate(const char *aggName, ...@@ -429,14 +430,14 @@ AggregateCreate(const char *aggName,
/* /*
* A combine function to combine INTERNAL states must accept nulls and * A combine function to combine INTERNAL states must accept nulls and
* ensure that the returned state is in the correct memory context. * ensure that the returned state is in the correct memory context. We
* cannot directly check the latter, but we can check the former.
*/ */
if (aggTransType == INTERNALOID && func_strict(combinefn)) if (aggTransType == INTERNALOID && func_strict(combinefn))
ereport(ERROR, ereport(ERROR,
(errcode(ERRCODE_INVALID_FUNCTION_DEFINITION), (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
errmsg("combine function with transition type %s must not be declared STRICT", errmsg("combine function with transition type %s must not be declared STRICT",
format_type_be(aggTransType)))); format_type_be(aggTransType))));
} }
/* /*
...@@ -444,10 +445,11 @@ AggregateCreate(const char *aggName, ...@@ -444,10 +445,11 @@ AggregateCreate(const char *aggName,
*/ */
if (aggserialfnName) if (aggserialfnName)
{ {
/* signature is always serialize(internal) returns bytea */
fnArgs[0] = INTERNALOID; fnArgs[0] = INTERNALOID;
serialfn = lookup_agg_function(aggserialfnName, 1, serialfn = lookup_agg_function(aggserialfnName, 1,
fnArgs, variadicArgType, fnArgs, InvalidOid,
&rettype); &rettype);
if (rettype != BYTEAOID) if (rettype != BYTEAOID)
...@@ -463,11 +465,12 @@ AggregateCreate(const char *aggName, ...@@ -463,11 +465,12 @@ AggregateCreate(const char *aggName,
*/ */
if (aggdeserialfnName) if (aggdeserialfnName)
{ {
/* signature is always deserialize(bytea, internal) returns internal */
fnArgs[0] = BYTEAOID; fnArgs[0] = BYTEAOID;
fnArgs[1] = INTERNALOID; /* dummy argument for type safety */ fnArgs[1] = INTERNALOID; /* dummy argument for type safety */
deserialfn = lookup_agg_function(aggdeserialfnName, 2, deserialfn = lookup_agg_function(aggdeserialfnName, 2,
fnArgs, variadicArgType, fnArgs, InvalidOid,
&rettype); &rettype);
if (rettype != INTERNALOID) if (rettype != INTERNALOID)
...@@ -770,7 +773,11 @@ AggregateCreate(const char *aggName, ...@@ -770,7 +773,11 @@ AggregateCreate(const char *aggName,
/* /*
* lookup_agg_function * lookup_agg_function
* common code for finding transfn, invtransfn, finalfn, and combinefn * common code for finding aggregate support functions
*
* fnName: possibly-schema-qualified function name
* nargs, input_types: expected function argument types
* variadicArgType: type of variadic argument if any, else InvalidOid
* *
* Returns OID of function, and stores its return type into *rettype * Returns OID of function, and stores its return type into *rettype
* *
......
...@@ -2940,8 +2940,8 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans, ...@@ -2940,8 +2940,8 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans,
if (pertrans->transfn.fn_strict && aggtranstype == INTERNALOID) if (pertrans->transfn.fn_strict && aggtranstype == INTERNALOID)
ereport(ERROR, ereport(ERROR,
(errcode(ERRCODE_INVALID_FUNCTION_DEFINITION), (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
errmsg("combine function for aggregate %u must be declared as STRICT", errmsg("combine function with transition type %s must not be declared STRICT",
aggref->aggfnoid))); format_type_be(aggtranstype))));
} }
else else
{ {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment