Commit 5fe5a2ce authored by Robert Haas's avatar Robert Haas

Allow aggregate transition states to be serialized and deserialized.

This is necessary infrastructure for supporting parallel aggregation
for aggregates whose transition type is "internal".  Such values
can't be passed between cooperating processes, because they are
just pointers.

David Rowley, reviewed by Tomas Vondra and by me.
parent 7f0a2c85
......@@ -412,6 +412,18 @@
<entry><literal><link linkend="catalog-pg-proc"><structname>pg_proc</structname></link>.oid</literal></entry>
<entry>Combine function (zero if none)</entry>
</row>
<row>
<entry><structfield>aggserialfn</structfield></entry>
<entry><type>regproc</type></entry>
<entry><literal><link linkend="catalog-pg-proc"><structname>pg_proc</structname></link>.oid</literal></entry>
<entry>Serialization function (zero if none)</entry>
</row>
<row>
<entry><structfield>aggdeserialfn</structfield></entry>
<entry><type>regproc</type></entry>
<entry><literal><link linkend="catalog-pg-proc"><structname>pg_proc</structname></link>.oid</literal></entry>
<entry>Deserialization function (zero if none)</entry>
</row>
<row>
<entry><structfield>aggmtransfn</structfield></entry>
<entry><type>regproc</type></entry>
......@@ -454,6 +466,12 @@
<entry><literal><link linkend="catalog-pg-type"><structname>pg_type</structname></link>.oid</literal></entry>
<entry>Data type of the aggregate function's internal transition (state) data</entry>
</row>
<row>
<entry><structfield>aggserialtype</structfield></entry>
<entry><type>oid</type></entry>
<entry><literal><link linkend="catalog-pg-type"><structname>pg_type</structname></link>.oid</literal></entry>
<entry>Return data type of the aggregate function's serialization function (zero if none)</entry>
</row>
<row>
<entry><structfield>aggtransspace</structfield></entry>
<entry><type>int4</type></entry>
......
......@@ -28,6 +28,9 @@ CREATE AGGREGATE <replaceable class="parameter">name</replaceable> ( [ <replacea
[ , FINALFUNC = <replaceable class="PARAMETER">ffunc</replaceable> ]
[ , FINALFUNC_EXTRA ]
[ , COMBINEFUNC = <replaceable class="PARAMETER">combinefunc</replaceable> ]
[ , SERIALFUNC = <replaceable class="PARAMETER">serialfunc</replaceable> ]
[ , DESERIALFUNC = <replaceable class="PARAMETER">deserialfunc</replaceable> ]
[ , SERIALTYPE = <replaceable class="PARAMETER">serialtype</replaceable> ]
[ , INITCOND = <replaceable class="PARAMETER">initial_condition</replaceable> ]
[ , MSFUNC = <replaceable class="PARAMETER">msfunc</replaceable> ]
[ , MINVFUNC = <replaceable class="PARAMETER">minvfunc</replaceable> ]
......@@ -47,6 +50,9 @@ CREATE AGGREGATE <replaceable class="parameter">name</replaceable> ( [ [ <replac
[ , FINALFUNC = <replaceable class="PARAMETER">ffunc</replaceable> ]
[ , FINALFUNC_EXTRA ]
[ , COMBINEFUNC = <replaceable class="PARAMETER">combinefunc</replaceable> ]
[ , SERIALFUNC = <replaceable class="PARAMETER">serialfunc</replaceable> ]
[ , DESERIALFUNC = <replaceable class="PARAMETER">deserialfunc</replaceable> ]
[ , SERIALTYPE = <replaceable class="PARAMETER">serialtype</replaceable> ]
[ , INITCOND = <replaceable class="PARAMETER">initial_condition</replaceable> ]
[ , HYPOTHETICAL ]
)
......@@ -61,6 +67,9 @@ CREATE AGGREGATE <replaceable class="PARAMETER">name</replaceable> (
[ , FINALFUNC = <replaceable class="PARAMETER">ffunc</replaceable> ]
[ , FINALFUNC_EXTRA ]
[ , COMBINEFUNC = <replaceable class="PARAMETER">combinefunc</replaceable> ]
[ , SERIALFUNC = <replaceable class="PARAMETER">serialfunc</replaceable> ]
[ , DESERIALFUNC = <replaceable class="PARAMETER">deserialfunc</replaceable> ]
[ , SERIALTYPE = <replaceable class="PARAMETER">serialtype</replaceable> ]
[ , INITCOND = <replaceable class="PARAMETER">initial_condition</replaceable> ]
[ , MSFUNC = <replaceable class="PARAMETER">msfunc</replaceable> ]
[ , MINVFUNC = <replaceable class="PARAMETER">minvfunc</replaceable> ]
......@@ -436,6 +445,47 @@ SELECT col FROM tab ORDER BY col USING sortop LIMIT 1;
</listitem>
</varlistentry>
<varlistentry>
<term><replaceable class="PARAMETER">serialfunc</replaceable></term>
<listitem>
<para>
In order to allow aggregate functions with an <literal>INTERNAL</>
<replaceable class="PARAMETER">state_data_type</replaceable> to
participate in parallel aggregation, the aggregate must have a valid
<replaceable class="PARAMETER">serialfunc</replaceable>, which must
serialize the aggregate state into <replaceable class="PARAMETER">
serialtype</replaceable>. This function must take a single argument of
<replaceable class="PARAMETER">state_data_type</replaceable> and return
<replaceable class="PARAMETER">serialtype</replaceable>. A
corresponding <replaceable class="PARAMETER">deserialfunc</replaceable>
is also required.
</para>
</listitem>
</varlistentry>
<varlistentry>
<term><replaceable class="PARAMETER">deserialfunc</replaceable></term>
<listitem>
<para>
Deserializes a previously serialized aggregate state back into
<replaceable class="PARAMETER">state_data_type</replaceable>. This
function must take a single argument of <replaceable class="PARAMETER">
serialtype</replaceable> and return <replaceable class="PARAMETER">
state_data_type</replaceable>.
</para>
</listitem>
</varlistentry>
<varlistentry>
<term><replaceable class="PARAMETER">serialtype</replaceable></term>
<listitem>
<para>
The data type to into which an <literal>INTERNAL</literal> aggregate
state should be serialized.
</para>
</listitem>
</varlistentry>
<varlistentry>
<term><replaceable class="PARAMETER">initial_condition</replaceable></term>
<listitem>
......
......@@ -58,6 +58,8 @@ AggregateCreate(const char *aggName,
List *aggtransfnName,
List *aggfinalfnName,
List *aggcombinefnName,
List *aggserialfnName,
List *aggdeserialfnName,
List *aggmtransfnName,
List *aggminvtransfnName,
List *aggmfinalfnName,
......@@ -65,6 +67,7 @@ AggregateCreate(const char *aggName,
bool mfinalfnExtraArgs,
List *aggsortopName,
Oid aggTransType,
Oid aggSerialType,
int32 aggTransSpace,
Oid aggmTransType,
int32 aggmTransSpace,
......@@ -79,6 +82,8 @@ AggregateCreate(const char *aggName,
Oid transfn;
Oid finalfn = InvalidOid; /* can be omitted */
Oid combinefn = InvalidOid; /* can be omitted */
Oid serialfn = InvalidOid; /* can be omitted */
Oid deserialfn = InvalidOid; /* can be omitted */
Oid mtransfn = InvalidOid; /* can be omitted */
Oid minvtransfn = InvalidOid; /* can be omitted */
Oid mfinalfn = InvalidOid; /* can be omitted */
......@@ -420,6 +425,57 @@ AggregateCreate(const char *aggName,
errmsg("return type of combine function %s is not %s",
NameListToString(aggcombinefnName),
format_type_be(aggTransType))));
/*
* A combine function to combine INTERNAL states must accept nulls and
* ensure that the returned state is in the correct memory context.
*/
if (aggTransType == INTERNALOID && func_strict(combinefn))
ereport(ERROR,
(errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
errmsg("combine function with \"%s\" transition type must not be declared STRICT",
format_type_be(aggTransType))));
}
/*
* Validate the serialization function, if present. We must ensure that the
* return type of this function is the same as the specified serialType.
*/
if (aggserialfnName)
{
fnArgs[0] = aggTransType;
serialfn = lookup_agg_function(aggserialfnName, 1,
fnArgs, variadicArgType,
&rettype);
if (rettype != aggSerialType)
ereport(ERROR,
(errcode(ERRCODE_DATATYPE_MISMATCH),
errmsg("return type of serialization function %s is not %s",
NameListToString(aggserialfnName),
format_type_be(aggSerialType))));
}
/*
* Validate the deserialization function, if present. We must ensure that
* the return type of this function is the same as the transType.
*/
if (aggdeserialfnName)
{
fnArgs[0] = aggSerialType;
deserialfn = lookup_agg_function(aggdeserialfnName, 1,
fnArgs, variadicArgType,
&rettype);
if (rettype != aggTransType)
ereport(ERROR,
(errcode(ERRCODE_DATATYPE_MISMATCH),
errmsg("return type of deserialization function %s is not %s",
NameListToString(aggdeserialfnName),
format_type_be(aggTransType))));
}
/*
......@@ -594,6 +650,8 @@ AggregateCreate(const char *aggName,
values[Anum_pg_aggregate_aggtransfn - 1] = ObjectIdGetDatum(transfn);
values[Anum_pg_aggregate_aggfinalfn - 1] = ObjectIdGetDatum(finalfn);
values[Anum_pg_aggregate_aggcombinefn - 1] = ObjectIdGetDatum(combinefn);
values[Anum_pg_aggregate_aggserialfn - 1] = ObjectIdGetDatum(serialfn);
values[Anum_pg_aggregate_aggdeserialfn - 1] = ObjectIdGetDatum(deserialfn);
values[Anum_pg_aggregate_aggmtransfn - 1] = ObjectIdGetDatum(mtransfn);
values[Anum_pg_aggregate_aggminvtransfn - 1] = ObjectIdGetDatum(minvtransfn);
values[Anum_pg_aggregate_aggmfinalfn - 1] = ObjectIdGetDatum(mfinalfn);
......@@ -601,6 +659,7 @@ AggregateCreate(const char *aggName,
values[Anum_pg_aggregate_aggmfinalextra - 1] = BoolGetDatum(mfinalfnExtraArgs);
values[Anum_pg_aggregate_aggsortop - 1] = ObjectIdGetDatum(sortop);
values[Anum_pg_aggregate_aggtranstype - 1] = ObjectIdGetDatum(aggTransType);
values[Anum_pg_aggregate_aggserialtype - 1] = ObjectIdGetDatum(aggSerialType);
values[Anum_pg_aggregate_aggtransspace - 1] = Int32GetDatum(aggTransSpace);
values[Anum_pg_aggregate_aggmtranstype - 1] = ObjectIdGetDatum(aggmTransType);
values[Anum_pg_aggregate_aggmtransspace - 1] = Int32GetDatum(aggmTransSpace);
......@@ -627,7 +686,8 @@ AggregateCreate(const char *aggName,
* Create dependencies for the aggregate (above and beyond those already
* made by ProcedureCreate). Note: we don't need an explicit dependency
* on aggTransType since we depend on it indirectly through transfn.
* Likewise for aggmTransType if any.
* Likewise for aggmTransType using the mtransfunc, and also for
* aggSerialType using the serialfn, if they exist.
*/
/* Depends on transition function */
......@@ -654,6 +714,24 @@ AggregateCreate(const char *aggName,
recordDependencyOn(&myself, &referenced, DEPENDENCY_NORMAL);
}
/* Depends on serialization function, if any */
if (OidIsValid(serialfn))
{
referenced.classId = ProcedureRelationId;
referenced.objectId = serialfn;
referenced.objectSubId = 0;
recordDependencyOn(&myself, &referenced, DEPENDENCY_NORMAL);
}
/* Depends on deserialization function, if any */
if (OidIsValid(deserialfn))
{
referenced.classId = ProcedureRelationId;
referenced.objectId = deserialfn;
referenced.objectSubId = 0;
recordDependencyOn(&myself, &referenced, DEPENDENCY_NORMAL);
}
/* Depends on forward transition function, if any */
if (OidIsValid(mtransfn))
{
......
......@@ -62,6 +62,8 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters,
List *transfuncName = NIL;
List *finalfuncName = NIL;
List *combinefuncName = NIL;
List *serialfuncName = NIL;
List *deserialfuncName = NIL;
List *mtransfuncName = NIL;
List *minvtransfuncName = NIL;
List *mfinalfuncName = NIL;
......@@ -70,6 +72,7 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters,
List *sortoperatorName = NIL;
TypeName *baseType = NULL;
TypeName *transType = NULL;
TypeName *serialType = NULL;
TypeName *mtransType = NULL;
int32 transSpace = 0;
int32 mtransSpace = 0;
......@@ -84,6 +87,7 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters,
List *parameterDefaults;
Oid variadicArgType;
Oid transTypeId;
Oid serialTypeId = InvalidOid;
Oid mtransTypeId = InvalidOid;
char transTypeType;
char mtransTypeType = 0;
......@@ -127,6 +131,10 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters,
finalfuncName = defGetQualifiedName(defel);
else if (pg_strcasecmp(defel->defname, "combinefunc") == 0)
combinefuncName = defGetQualifiedName(defel);
else if (pg_strcasecmp(defel->defname, "serialfunc") == 0)
serialfuncName = defGetQualifiedName(defel);
else if (pg_strcasecmp(defel->defname, "deserialfunc") == 0)
deserialfuncName = defGetQualifiedName(defel);
else if (pg_strcasecmp(defel->defname, "msfunc") == 0)
mtransfuncName = defGetQualifiedName(defel);
else if (pg_strcasecmp(defel->defname, "minvfunc") == 0)
......@@ -154,6 +162,8 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters,
}
else if (pg_strcasecmp(defel->defname, "stype") == 0)
transType = defGetTypeName(defel);
else if (pg_strcasecmp(defel->defname, "serialtype") == 0)
serialType = defGetTypeName(defel);
else if (pg_strcasecmp(defel->defname, "stype1") == 0)
transType = defGetTypeName(defel);
else if (pg_strcasecmp(defel->defname, "sspace") == 0)
......@@ -319,6 +329,75 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters,
format_type_be(transTypeId))));
}
if (serialType)
{
/*
* There's little point in having a serialization/deserialization
* function on aggregates that don't have an internal state, so let's
* just disallow this as it may help clear up any confusion or needless
* authoring of these functions.
*/
if (transTypeId != INTERNALOID)
ereport(ERROR,
(errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
errmsg("a serialization type must only be specified when the aggregate transition data type is \"%s\"",
format_type_be(INTERNALOID))));
serialTypeId = typenameTypeId(NULL, serialType);
if (get_typtype(mtransTypeId) == TYPTYPE_PSEUDO &&
!IsPolymorphicType(serialTypeId))
ereport(ERROR,
(errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
errmsg("aggregate serialization data type cannot be %s",
format_type_be(serialTypeId))));
/*
* We disallow INTERNAL serialType as the whole point of the
* serialized types is to allow the aggregate state to be output,
* and we cannot output INTERNAL. This check, combined with the one
* above ensures that the trans type and serialization type are not the
* same.
*/
if (serialTypeId == INTERNALOID)
ereport(ERROR,
(errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
errmsg("aggregate serialization type cannot be \"%s\"",
format_type_be(serialTypeId))));
/*
* If serialType is specified then serialfuncName and deserialfuncName
* must be present; if not, then none of the serialization options
* should have been specified.
*/
if (serialfuncName == NIL)
ereport(ERROR,
(errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
errmsg("aggregate serialization function must be specified when serialization type is specified")));
if (deserialfuncName == NIL)
ereport(ERROR,
(errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
errmsg("aggregate deserialization function must be specified when serialization type is specified")));
}
else
{
/*
* If serialization type was not specified then there shouldn't be a
* serialization function.
*/
if (serialfuncName != NIL)
ereport(ERROR,
(errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
errmsg("must specify serialization type when specifying serialization function")));
/* likewise for the deserialization function */
if (deserialfuncName != NIL)
ereport(ERROR,
(errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
errmsg("must specify serialization type when specifying deserialization function")));
}
/*
* If a moving-aggregate transtype is specified, look that up. Same
* restrictions as for transtype.
......@@ -387,6 +466,8 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters,
transfuncName, /* step function name */
finalfuncName, /* final function name */
combinefuncName, /* combine function name */
serialfuncName, /* serial function name */
deserialfuncName, /* deserial function name */
mtransfuncName, /* fwd trans function name */
minvtransfuncName, /* inv trans function name */
mfinalfuncName, /* final function name */
......@@ -394,6 +475,7 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters,
mfinalfuncExtraArgs,
sortoperatorName, /* sort operator name */
transTypeId, /* transition data type */
serialTypeId, /* serialization data type */
transSpace, /* transition space */
mtransTypeId, /* transition data type */
mtransSpace, /* transition space */
......
This diff is collapsed.
......@@ -871,6 +871,7 @@ _copyAgg(const Agg *from)
COPY_SCALAR_FIELD(aggstrategy);
COPY_SCALAR_FIELD(combineStates);
COPY_SCALAR_FIELD(finalizeAggs);
COPY_SCALAR_FIELD(serialStates);
COPY_SCALAR_FIELD(numCols);
if (from->numCols > 0)
{
......
......@@ -708,6 +708,7 @@ _outAgg(StringInfo str, const Agg *node)
WRITE_ENUM_FIELD(aggstrategy, AggStrategy);
WRITE_BOOL_FIELD(combineStates);
WRITE_BOOL_FIELD(finalizeAggs);
WRITE_BOOL_FIELD(serialStates);
WRITE_INT_FIELD(numCols);
appendStringInfoString(str, " :grpColIdx");
......
......@@ -1993,6 +1993,7 @@ _readAgg(void)
READ_ENUM_FIELD(aggstrategy, AggStrategy);
READ_BOOL_FIELD(combineStates);
READ_BOOL_FIELD(finalizeAggs);
READ_BOOL_FIELD(serialStates);
READ_INT_FIELD(numCols);
READ_ATTRNUMBER_ARRAY(grpColIdx, local_node->numCols);
READ_OID_ARRAY(grpOperators, local_node->numCols);
......
......@@ -1279,6 +1279,7 @@ create_unique_plan(PlannerInfo *root, UniquePath *best_path, int flags)
AGG_HASHED,
false,
true,
false,
numGroupCols,
groupColIdx,
groupOperators,
......@@ -1578,6 +1579,7 @@ create_agg_plan(PlannerInfo *root, AggPath *best_path)
best_path->aggstrategy,
best_path->combineStates,
best_path->finalizeAggs,
best_path->serialStates,
list_length(best_path->groupClause),
extract_grouping_cols(best_path->groupClause,
subplan->targetlist),
......@@ -1732,6 +1734,7 @@ create_groupingsets_plan(PlannerInfo *root, GroupingSetsPath *best_path)
AGG_SORTED,
false,
true,
false,
list_length((List *) linitial(gsets)),
new_grpColIdx,
extract_grouping_ops(groupClause),
......@@ -1768,6 +1771,7 @@ create_groupingsets_plan(PlannerInfo *root, GroupingSetsPath *best_path)
(numGroupCols > 0) ? AGG_SORTED : AGG_PLAIN,
false,
true,
false,
numGroupCols,
top_grpColIdx,
extract_grouping_ops(groupClause),
......@@ -5636,7 +5640,7 @@ materialize_finished_plan(Plan *subplan)
Agg *
make_agg(List *tlist, List *qual,
AggStrategy aggstrategy,
bool combineStates, bool finalizeAggs,
bool combineStates, bool finalizeAggs, bool serialStates,
int numGroupCols, AttrNumber *grpColIdx, Oid *grpOperators,
List *groupingSets, List *chain,
double dNumGroups, Plan *lefttree)
......@@ -5651,6 +5655,7 @@ make_agg(List *tlist, List *qual,
node->aggstrategy = aggstrategy;
node->combineStates = combineStates;
node->finalizeAggs = finalizeAggs;
node->serialStates = serialStates;
node->numCols = numGroupCols;
node->grpColIdx = grpColIdx;
node->grpOperators = grpOperators;
......
......@@ -3455,7 +3455,8 @@ create_grouping_paths(PlannerInfo *root,
&agg_costs,
dNumPartialGroups,
false,
false));
false,
true));
else
add_partial_path(grouped_rel, (Path *)
create_group_path(root,
......@@ -3496,7 +3497,8 @@ create_grouping_paths(PlannerInfo *root,
&agg_costs,
dNumPartialGroups,
false,
false));
false,
true));
}
}
}
......@@ -3560,7 +3562,8 @@ create_grouping_paths(PlannerInfo *root,
&agg_costs,
dNumGroups,
false,
true));
true,
false));
}
else if (parse->groupClause)
{
......@@ -3626,6 +3629,7 @@ create_grouping_paths(PlannerInfo *root,
&agg_costs,
dNumGroups,
true,
true,
true));
else
add_path(grouped_rel, (Path *)
......@@ -3668,7 +3672,8 @@ create_grouping_paths(PlannerInfo *root,
&agg_costs,
dNumGroups,
false,
true));
true,
false));
}
/*
......@@ -3706,6 +3711,7 @@ create_grouping_paths(PlannerInfo *root,
&agg_costs,
dNumGroups,
true,
true,
true));
}
}
......@@ -4039,7 +4045,8 @@ create_distinct_paths(PlannerInfo *root,
NULL,
numDistinctRows,
false,
true));
true,
false));
}
/* Give a helpful error if we failed to find any implementation */
......
......@@ -2057,10 +2057,10 @@ search_indexed_tlist_for_sortgroupref(Node *node,
* search_indexed_tlist_for_partial_aggref - find an Aggref in an indexed tlist
*
* Aggrefs for partial aggregates have their aggoutputtype adjusted to set it
* to the aggregate state's type. This means that a standard equal() comparison
* won't match when comparing an Aggref which is in partial mode with an Aggref
* which is not. Here we manually compare all of the fields apart from
* aggoutputtype.
* to the aggregate state's type, or serialization type. This means that a
* standard equal() comparison won't match when comparing an Aggref which is
* in partial mode with an Aggref which is not. Here we manually compare all of
* the fields apart from aggoutputtype.
*/
static Var *
search_indexed_tlist_for_partial_aggref(Aggref *aggref, indexed_tlist *itlist,
......
......@@ -861,7 +861,8 @@ make_union_unique(SetOperationStmt *op, Path *path, List *tlist,
NULL,
dNumGroups,
false,
true);
true,
false);
}
else
{
......
......@@ -464,11 +464,15 @@ aggregates_allow_partial_walker(Node *node, partial_agg_context *context)
}
/*
* If we find any aggs with an internal transtype then we must ensure
* that pointers to aggregate states are not passed to other processes;
* therefore, we set the maximum allowed type to PAT_INTERNAL_ONLY.
* If we find any aggs with an internal transtype then we must check
* that these have a serialization type, serialization func and
* deserialization func; otherwise, we set the maximum allowed type to
* PAT_INTERNAL_ONLY.
*/
if (aggform->aggtranstype == INTERNALOID)
if (aggform->aggtranstype == INTERNALOID &&
(!OidIsValid(aggform->aggserialtype) ||
!OidIsValid(aggform->aggserialfn) ||
!OidIsValid(aggform->aggdeserialfn)))
context->allowedtype = PAT_INTERNAL_ONLY;
ReleaseSysCache(aggTuple);
......
......@@ -2433,7 +2433,8 @@ create_agg_path(PlannerInfo *root,
const AggClauseCosts *aggcosts,
double numGroups,
bool combineStates,
bool finalizeAggs)
bool finalizeAggs,
bool serialStates)
{
AggPath *pathnode = makeNode(AggPath);
......@@ -2458,6 +2459,7 @@ create_agg_path(PlannerInfo *root,
pathnode->qual = qual;
pathnode->finalizeAggs = finalizeAggs;
pathnode->combineStates = combineStates;
pathnode->serialStates = serialStates;
cost_agg(&pathnode->path, root,
aggstrategy, aggcosts,
......
......@@ -756,8 +756,8 @@ apply_pathtarget_labeling_to_tlist(List *tlist, PathTarget *target)
* apply_partialaggref_adjustment
* Convert PathTarget to be suitable for a partial aggregate node. We simply
* adjust any Aggref nodes found in the target and set the aggoutputtype to
* the aggtranstype. This allows exprType() to return the actual type that
* will be produced.
* the aggtranstype or aggserialtype. This allows exprType() to return the
* actual type that will be produced.
*
* Note: We expect 'target' to be a flat target list and not have Aggrefs burried
* within other expressions.
......@@ -785,7 +785,12 @@ apply_partialaggref_adjustment(PathTarget *target)
aggform = (Form_pg_aggregate) GETSTRUCT(aggTuple);
newaggref = (Aggref *) copyObject(aggref);
newaggref->aggoutputtype = aggform->aggtranstype;
/* use the serialization type, if one exists */
if (OidIsValid(aggform->aggserialtype))
newaggref->aggoutputtype = aggform->aggserialtype;
else
newaggref->aggoutputtype = aggform->aggtranstype;
lfirst(lc) = newaggref;
......
......@@ -1964,6 +1964,45 @@ build_aggregate_combinefn_expr(Oid agg_state_type,
*combinefnexpr = (Expr *) fexpr;
}
/*
* Like build_aggregate_transfn_expr, but creates an expression tree for the
* serialization or deserialization function of an aggregate, rather than the
* transition function. This may be used for either the serialization or
* deserialization function by swapping the first two parameters over.
*/
void
build_aggregate_serialfn_expr(Oid agg_input_type,
Oid agg_output_type,
Oid agg_input_collation,
Oid serialfn_oid,
Expr **serialfnexpr)
{
Param *argp;
List *args;
FuncExpr *fexpr;
/* Build arg list to use in the FuncExpr node. */
argp = makeNode(Param);
argp->paramkind = PARAM_EXEC;
argp->paramid = -1;
argp->paramtype = agg_input_type;
argp->paramtypmod = -1;
argp->paramcollid = agg_input_collation;
argp->location = -1;
/* takes a single arg of the agg_input_type */
args = list_make1(argp);
fexpr = makeFuncExpr(serialfn_oid,
agg_output_type,
args,
InvalidOid,
agg_input_collation,
COERCE_EXPLICIT_CALL);
fexpr->funcvariadic = false;
*serialfnexpr = (Expr *) fexpr;
}
/*
* Like build_aggregate_transfn_expr, but creates an expression tree for the
* final function of an aggregate, rather than the transition function.
......
......@@ -12557,6 +12557,8 @@ dumpAgg(Archive *fout, AggInfo *agginfo)
int i_aggtransfn;
int i_aggfinalfn;
int i_aggcombinefn;
int i_aggserialfn;
int i_aggdeserialfn;
int i_aggmtransfn;
int i_aggminvtransfn;
int i_aggmfinalfn;
......@@ -12565,6 +12567,7 @@ dumpAgg(Archive *fout, AggInfo *agginfo)
int i_aggsortop;
int i_hypothetical;
int i_aggtranstype;
int i_aggserialtype;
int i_aggtransspace;
int i_aggmtranstype;
int i_aggmtransspace;
......@@ -12574,6 +12577,8 @@ dumpAgg(Archive *fout, AggInfo *agginfo)
const char *aggtransfn;
const char *aggfinalfn;
const char *aggcombinefn;
const char *aggserialfn;
const char *aggdeserialfn;
const char *aggmtransfn;
const char *aggminvtransfn;
const char *aggmfinalfn;
......@@ -12583,6 +12588,7 @@ dumpAgg(Archive *fout, AggInfo *agginfo)
char *aggsortconvop;
bool hypothetical;
const char *aggtranstype;
const char *aggserialtype;
const char *aggtransspace;
const char *aggmtranstype;
const char *aggmtransspace;
......@@ -12608,10 +12614,11 @@ dumpAgg(Archive *fout, AggInfo *agginfo)
{
appendPQExpBuffer(query, "SELECT aggtransfn, "
"aggfinalfn, aggtranstype::pg_catalog.regtype, "
"aggcombinefn, aggmtransfn, "
"aggcombinefn, aggserialfn, aggdeserialfn, aggmtransfn, "
"aggminvtransfn, aggmfinalfn, aggmtranstype::pg_catalog.regtype, "
"aggfinalextra, aggmfinalextra, "
"aggsortop::pg_catalog.regoperator, "
"aggserialtype::pg_catalog.regtype, "
"(aggkind = 'h') AS hypothetical, "
"aggtransspace, agginitval, "
"aggmtransspace, aggminitval, "
......@@ -12627,10 +12634,12 @@ dumpAgg(Archive *fout, AggInfo *agginfo)
{
appendPQExpBuffer(query, "SELECT aggtransfn, "
"aggfinalfn, aggtranstype::pg_catalog.regtype, "
"'-' AS aggcombinefn, aggmtransfn, aggminvtransfn, "
"'-' AS aggcombinefn, '-' AS aggserialfn, "
"'-' AS aggdeserialfn, aggmtransfn, aggminvtransfn, "
"aggmfinalfn, aggmtranstype::pg_catalog.regtype, "
"aggfinalextra, aggmfinalextra, "
"aggsortop::pg_catalog.regoperator, "
"0 AS aggserialtype, "
"(aggkind = 'h') AS hypothetical, "
"aggtransspace, agginitval, "
"aggmtransspace, aggminitval, "
......@@ -12646,11 +12655,13 @@ dumpAgg(Archive *fout, AggInfo *agginfo)
{
appendPQExpBuffer(query, "SELECT aggtransfn, "
"aggfinalfn, aggtranstype::pg_catalog.regtype, "
"'-' AS aggcombinefn, '-' AS aggmtransfn, "
"'-' AS aggcombinefn, '-' AS aggserialfn, "
"'-' AS aggdeserialfn, '-' AS aggmtransfn, "
"'-' AS aggminvtransfn, '-' AS aggmfinalfn, "
"0 AS aggmtranstype, false AS aggfinalextra, "
"false AS aggmfinalextra, "
"aggsortop::pg_catalog.regoperator, "
"0 AS aggserialtype, "
"false AS hypothetical, "
"0 AS aggtransspace, agginitval, "
"0 AS aggmtransspace, NULL AS aggminitval, "
......@@ -12666,11 +12677,13 @@ dumpAgg(Archive *fout, AggInfo *agginfo)
{
appendPQExpBuffer(query, "SELECT aggtransfn, "
"aggfinalfn, aggtranstype::pg_catalog.regtype, "
"'-' AS aggcombinefn, '-' AS aggmtransfn, "
"'-' AS aggcombinefn, '-' AS aggserialfn, "
"'-' AS aggdeserialfn, '-' AS aggmtransfn, "
"'-' AS aggminvtransfn, '-' AS aggmfinalfn, "
"0 AS aggmtranstype, false AS aggfinalextra, "
"false AS aggmfinalextra, "
"aggsortop::pg_catalog.regoperator, "
"0 AS aggserialtype, "
"false AS hypothetical, "
"0 AS aggtransspace, agginitval, "
"0 AS aggmtransspace, NULL AS aggminitval, "
......@@ -12684,10 +12697,12 @@ dumpAgg(Archive *fout, AggInfo *agginfo)
{
appendPQExpBuffer(query, "SELECT aggtransfn, "
"aggfinalfn, aggtranstype::pg_catalog.regtype, "
"'-' AS aggcombinefn, '-' AS aggmtransfn, "
"'-' AS aggcombinefn, '-' AS aggserialfn, "
"'-' AS aggdeserialfn, '-' AS aggmtransfn, "
"'-' AS aggminvtransfn, '-' AS aggmfinalfn, "
"0 AS aggmtranstype, false AS aggfinalextra, "
"false AS aggmfinalextra, 0 AS aggsortop, "
"0 AS aggserialtype, "
"false AS hypothetical, "
"0 AS aggtransspace, agginitval, "
"0 AS aggmtransspace, NULL AS aggminitval, "
......@@ -12701,10 +12716,12 @@ dumpAgg(Archive *fout, AggInfo *agginfo)
{
appendPQExpBuffer(query, "SELECT aggtransfn, aggfinalfn, "
"format_type(aggtranstype, NULL) AS aggtranstype, "
"'-' AS aggcombinefn, '-' AS aggmtransfn, "
"'-' AS aggcombinefn, '-' AS aggserialfn, "
"'-' AS aggdeserialfn, '-' AS aggmtransfn, "
"'-' AS aggminvtransfn, '-' AS aggmfinalfn, "
"0 AS aggmtranstype, false AS aggfinalextra, "
"false AS aggmfinalextra, 0 AS aggsortop, "
"0 AS aggserialtype, "
"false AS hypothetical, "
"0 AS aggtransspace, agginitval, "
"0 AS aggmtransspace, NULL AS aggminitval, "
......@@ -12718,10 +12735,12 @@ dumpAgg(Archive *fout, AggInfo *agginfo)
appendPQExpBuffer(query, "SELECT aggtransfn1 AS aggtransfn, "
"aggfinalfn, "
"(SELECT typname FROM pg_type WHERE oid = aggtranstype1) AS aggtranstype, "
"'-' AS aggcombinefn, '-' AS aggmtransfn, "
"'-' AS aggcombinefn, '-' AS aggserialfn, "
"'-' AS aggdeserialfn, '-' AS aggmtransfn, "
"'-' AS aggminvtransfn, '-' AS aggmfinalfn, "
"0 AS aggmtranstype, false AS aggfinalextra, "
"false AS aggmfinalextra, 0 AS aggsortop, "
"0 AS aggserialtype, "
"false AS hypothetical, "
"0 AS aggtransspace, agginitval1 AS agginitval, "
"0 AS aggmtransspace, NULL AS aggminitval, "
......@@ -12736,12 +12755,15 @@ dumpAgg(Archive *fout, AggInfo *agginfo)
i_aggtransfn = PQfnumber(res, "aggtransfn");
i_aggfinalfn = PQfnumber(res, "aggfinalfn");
i_aggcombinefn = PQfnumber(res, "aggcombinefn");
i_aggserialfn = PQfnumber(res, "aggserialfn");
i_aggdeserialfn = PQfnumber(res, "aggdeserialfn");
i_aggmtransfn = PQfnumber(res, "aggmtransfn");
i_aggminvtransfn = PQfnumber(res, "aggminvtransfn");
i_aggmfinalfn = PQfnumber(res, "aggmfinalfn");
i_aggfinalextra = PQfnumber(res, "aggfinalextra");
i_aggmfinalextra = PQfnumber(res, "aggmfinalextra");
i_aggsortop = PQfnumber(res, "aggsortop");
i_aggserialtype = PQfnumber(res, "aggserialtype");
i_hypothetical = PQfnumber(res, "hypothetical");
i_aggtranstype = PQfnumber(res, "aggtranstype");
i_aggtransspace = PQfnumber(res, "aggtransspace");
......@@ -12754,6 +12776,8 @@ dumpAgg(Archive *fout, AggInfo *agginfo)
aggtransfn = PQgetvalue(res, 0, i_aggtransfn);
aggfinalfn = PQgetvalue(res, 0, i_aggfinalfn);
aggcombinefn = PQgetvalue(res, 0, i_aggcombinefn);
aggserialfn = PQgetvalue(res, 0, i_aggserialfn);
aggdeserialfn = PQgetvalue(res, 0, i_aggdeserialfn);
aggmtransfn = PQgetvalue(res, 0, i_aggmtransfn);
aggminvtransfn = PQgetvalue(res, 0, i_aggminvtransfn);
aggmfinalfn = PQgetvalue(res, 0, i_aggmfinalfn);
......@@ -12762,6 +12786,7 @@ dumpAgg(Archive *fout, AggInfo *agginfo)
aggsortop = PQgetvalue(res, 0, i_aggsortop);
hypothetical = (PQgetvalue(res, 0, i_hypothetical)[0] == 't');
aggtranstype = PQgetvalue(res, 0, i_aggtranstype);
aggserialtype = PQgetvalue(res, 0, i_aggserialtype);
aggtransspace = PQgetvalue(res, 0, i_aggtransspace);
aggmtranstype = PQgetvalue(res, 0, i_aggmtranstype);
aggmtransspace = PQgetvalue(res, 0, i_aggmtransspace);
......@@ -12847,6 +12872,17 @@ dumpAgg(Archive *fout, AggInfo *agginfo)
appendPQExpBuffer(details, ",\n COMBINEFUNC = %s", aggcombinefn);
}
/*
* CREATE AGGREGATE should ensure we either have all of these, or none of
* them.
*/
if (strcmp(aggserialfn, "-") != 0)
{
appendPQExpBuffer(details, ",\n SERIALFUNC = %s", aggserialfn);
appendPQExpBuffer(details, ",\n DESERIALFUNC = %s", aggdeserialfn);
appendPQExpBuffer(details, ",\n SERIALTYPE = %s", aggserialtype);
}
if (strcmp(aggmtransfn, "-") != 0)
{
appendPQExpBuffer(details, ",\n MSFUNC = %s,\n MINVFUNC = %s,\n MSTYPE = %s",
......
......@@ -53,6 +53,6 @@
*/
/* yyyymmddN */
#define CATALOG_VERSION_NO 201603231
#define CATALOG_VERSION_NO 201603291
#endif
This diff is collapsed.
......@@ -1836,6 +1836,7 @@ typedef struct AggState
bool agg_done; /* indicates completion of Agg scan */
bool combineStates; /* input tuples contain transition states */
bool finalizeAggs; /* should we call the finalfn on agg states? */
bool serialStates; /* should agg states be (de)serialized? */
int projected_set; /* The last projected grouping set */
int current_set; /* The current grouping set being evaluated */
Bitmapset *grouped_cols; /* grouped cols in current projection */
......
......@@ -712,6 +712,7 @@ typedef struct Agg
AggStrategy aggstrategy; /* basic strategy, see nodes.h */
bool combineStates; /* input tuples contain transition states */
bool finalizeAggs; /* should we call the finalfn on agg states? */
bool serialStates; /* should agg states be (de)serialized? */
int numCols; /* number of grouping columns */
AttrNumber *grpColIdx; /* their indexes in the target list */
Oid *grpOperators; /* equality operators to compare with */
......
......@@ -1296,6 +1296,7 @@ typedef struct AggPath
List *qual; /* quals (HAVING quals), if any */
bool combineStates; /* input is partially aggregated agg states */
bool finalizeAggs; /* should the executor call the finalfn? */
bool serialStates; /* should agg states be (de)serialized? */
} AggPath;
/*
......
......@@ -171,7 +171,8 @@ extern AggPath *create_agg_path(PlannerInfo *root,
const AggClauseCosts *aggcosts,
double numGroups,
bool combineStates,
bool finalizeAggs);
bool finalizeAggs,
bool serialStates);
extern GroupingSetsPath *create_groupingsets_path(PlannerInfo *root,
RelOptInfo *rel,
Path *subpath,
......
......@@ -58,7 +58,7 @@ extern bool is_projection_capable_plan(Plan *plan);
/* External use of these functions is deprecated: */
extern Sort *make_sort_from_sortclauses(List *sortcls, Plan *lefttree);
extern Agg *make_agg(List *tlist, List *qual, AggStrategy aggstrategy,
bool combineStates, bool finalizeAggs,
bool combineStates, bool finalizeAggs, bool serialStates,
int numGroupCols, AttrNumber *grpColIdx, Oid *grpOperators,
List *groupingSets, List *chain,
double dNumGroups, Plan *lefttree);
......
......@@ -51,6 +51,12 @@ extern void build_aggregate_combinefn_expr(Oid agg_state_type,
Oid combinefn_oid,
Expr **combinefnexpr);
extern void build_aggregate_serialfn_expr(Oid agg_state_type,
Oid agg_serial_type,
Oid agg_input_collation,
Oid serialfn_oid,
Expr **serialfnexpr);
extern void build_aggregate_finalfn_expr(Oid *agg_input_types,
int num_finalfn_inputs,
Oid agg_state_type,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment