8000 Allow aggregate transition states to be serialized and deserialized. · postgres/postgres@5fe5a2c · GitHub
[go: up one dir, main page]

Skip to content

Commit 5fe5a2c

Browse files
committed
Allow aggregate transition states to be serialized and deserialized.
This is necessary infrastructure for supporting parallel aggregation for aggregates whose transition type is "internal". Such values can't be passed between cooperating processes, because they are just pointers. David Rowley, reviewed by Tomas Vondra and by me.
1 parent 7f0a2c8 commit 5fe5a2c

File tree

25 files changed

+794
-193
lines changed
  • optimizer
  • parser
  • bin/pg_dump
  • include
  • 25 files changed

    +794
    -193
    lines changed

    doc/src/sgml/catalogs.sgml

    Lines changed: 18 additions & 0 deletions
    Original file line numberDiff line numberDiff line change
    @@ -412,6 +412,18 @@
    412412
    <entry><literal><link linkend="catalog-pg-proc"><structname>pg_proc</structname></link>.oid</literal></entry>
    413413
    <entry>Combine function (zero if none)</entry>
    414414
    </row>
    415+
    <row>
    416+
    <entry><structfield>aggserialfn</structfield></entry>
    417+
    <entry><type>regproc</type></entry>
    418+
    <entry><literal><link linkend="catalog-pg-proc"><structname>pg_proc</structname></link>.oid</literal></entry>
    419+
    <entry>Serialization function (zero if none)</entry>
    420+
    </row>
    421+
    <row>
    422+
    <entry><structfield>aggdeserialfn</structfield></entry>
    423+
    <entry><type>regproc</type></entry>
    424+
    <entry><literal><link linkend="catalog-pg-proc"><structname>pg_proc</structname></link>.oid</literal></entry>
    425+
    <entry>Deserialization function (zero if none)</entry>
    426+
    </row>
    415427
    <row>
    416428
    <entry><structfield>aggmtransfn</structfield></entry>
    417429
    <entry><type>regproc</type></entry>
    @@ -454,6 +466,12 @@
    454466
    <entry><literal><link linkend="catalog-pg-type"><structname>pg_type</structname></link>.oid</literal></entry>
    455467
    <entry>Data type of the aggregate function's internal transition (state) data</entry>
    456468
    </row>
    469+
    <row>
    470+
    <entry><structfield>aggserialtype</structfield></entry>
    471+
    <entry><type>oid</type></entry>
    472+
    <entry><literal><link linkend="catalog-pg-type"><structname>pg_type</structname></link>.oid</literal></entry>
    473+
    <entry>Return data type of the aggregate function's serialization function (zero if none)</entry>
    474+
    </row>
    457475
    <row>
    458476
    <entry><structfield>aggtransspace</structfield></entry>
    459477
    <entry><type>int4</type></entry>

    doc/src/sgml/ref/create_aggregate.sgml

    Lines changed: 50 additions & 0 deletions
    Original file line numberDiff line numberDiff line change
    @@ -28,6 +28,9 @@ CREATE AGGREGATE <replaceable class="parameter">name</replaceable> ( [ <replacea
    2828
    [ , FINALFUNC = <replaceable class="PARAMETER">ffunc</replaceable> ]
    2929
    [ , FINALFUNC_EXTRA ]
    3030
    [ , COMBINEFUNC = <replaceable class="PARAMETER">combinefunc</replaceable> ]
    31+
    [ , SERIALFUNC = <replaceable class="PARAMETER">serialfunc</replaceable> ]
    32+
    [ , DESERIALFUNC = <replaceable class="PARAMETER">deserialfunc</replaceable> ]
    33+
    [ , SERIALTYPE = <replaceable class="PARAMETER">serialtype</replaceable> ]
    3134
    [ , INITCOND = <replaceable class="PARAMETER">initial_condition</replaceable> ]
    3235
    [ , MSFUNC = <replaceable class="PARAMETER">msfunc</replaceable> ]
    3336
    [ , MINVFUNC = <replaceable class="PARAMETER">minvfunc</replaceable> ]
    @@ -47,6 +50,9 @@ CREATE AGGREGATE <replaceable class="parameter">name</replaceable> ( [ [ <replac
    4750
    [ , FINALFUNC = <replaceable class="PARAMETER">ffunc</replaceable> ]
    4851
    [ , FINALFUNC_EXTRA ]
    4952
    [ , COMBINEFUNC = <replaceable class="PARAMETER">combinefunc</replaceable> ]
    53+
    [ , SERIALFUNC = <replaceable class="PARAMETER">serialfunc</replaceable> ]
    54+
    [ , DESERIALFUNC = <replaceable class="PARAMETER">deserialfunc</replaceable> ]
    55+
    [ , SERIALTYPE = <replaceable class="PARAMETER">serialtype</replaceable> ]
    5056
    [ , INITCOND = <replaceable class="PARAMETER">initial_condition</replaceable> ]
    5157
    [ , HYPOTHETICAL ]
    5258
    )
    @@ -61,6 +67,9 @@ CREATE AGGREGATE <replaceable class="PARAMETER">name</replaceable> (
    6167
    [ , FINALFUNC = <replaceable class="PARAMETER">ffunc</replaceable> ]
    6268
    [ , FINALFUNC_EXTRA ]
    6369
    [ , COMBINEFUNC = <replaceable class="PARAMETER">combinefunc</replaceable> ]
    70+
    [ , SERIALFUNC = <replaceable class="PARAMETER">serialfunc</replaceable> ]
    71+
    [ , DESERIALFUNC = <replaceable class="PARAMETER">deserialfunc</replaceable> ]
    72+
    [ , SERIALTYPE = <replaceable class="PARAMETER">serialtype</replaceable> ]
    6473
    [ , INITCOND = <replaceable class="PARAMETER">initial_condition</replaceable> ]
    6574
    [ , MSFUNC = <replaceable class="PARAMETER">msfunc</replaceable> ]
    6675
    [ , MINVFUNC = <replaceable class="PARAMETER">minvfunc</replaceable> ]
    @@ -436,6 +445,47 @@ SELECT col FROM tab ORDER BY col USING sortop LIMIT 1;
    436445
    </listitem>
    437446
    </varlistentry>
    438447

    448+
    <varlistentry>
    449+
    <term><replaceable class="PARAMETER">serialfunc</replaceable></term>
    450+
    <listitem>
    451+
    <para>
    452+
    In order to allow aggregate functions with an <literal>INTERNAL</>
    453+
    <replaceable class="PARAMETER">state_data_type</replaceable> to
    454+
    participate in parallel aggregation, the aggregate must have a valid
    455+
    <replaceable class="PARAMETER">serialfunc</replaceable>, which must
    456+
    serialize the aggregate state into <replaceable class="PARAMETER">
    457+
    serialtype</replaceable>. This function must take a single argument of
    458+
    <replaceable class="PARAMETER">state_data_type</replaceable> and return
    459+
    <replaceable class="PARAMETER">serialtype</replaceable>. A
    460+
    corresponding <replaceable class="PARAMETER">deserialfunc</replaceable>
    461+
    is also required.
    462+
    </para>
    463+
    </listitem>
    464+
    </varlistentry>
    465+
    466+
    <varlistentry>
    467+
    <term><replaceable class="PARAMETER">deserialfunc</replaceable></term>
    468+
    <listitem>
    469+
    <para>
    470+
    Deserializes a previously serialized aggregate state back into
    471+
    <replaceable class="PARAMETER">state_data_type</replaceable>. This
    472+
    function must take a single argument of <replaceable class="PARAMETER">
    473+
    serialtype</replaceable> and return <replaceable class="PARAMETER">
    474+
    state_data_type</replaceable>.
    475+
    </para>
    476+
    </listitem>
    477+
    </varlistentry>
    478+
    479+
    <varlistentry>
    480+
    <term><replaceable class="PARAMETER">serialtype</replaceable></term>
    481+
    <listitem>
    482+
    <para>
    483+
    The data type to into which an <literal>INTERNAL</literal> aggregate
    484+
    state should be serialized.
    485+
    </para>
    486+
    </listitem>
    487+
    </varlistentry>
    488+
    439489
    <varlistentry>
    440490
    <term><replaceable class="PARAMETER">initial_condition</replaceable></term>
    441491
    <listitem>

    src/backend/catalog/pg_aggregate.c

    Lines changed: 79 additions & 1 deletion
    Original file line numberDiff line numberDiff line change
    @@ -58,13 +58,16 @@ AggregateCreate(const char *aggName,
    5858
    List *aggtransfnName,
    5959
    List *aggfinalfnName,
    6060
    List *aggcombinefnName,
    61+
    List *aggserialfnName,
    62+
    List *aggdeserialfnName,
    6163
    List *aggmtransfnName,
    6264
    List *aggminvtransfnName,
    6365
    List *aggmfinalfnName,
    6466
    bool finalfnExtraArgs,
    6567
    bool mfinalfnExtraArgs,
    6668
    List *aggsortopName,
    6769
    Oid aggTransType,
    70+
    Oid aggSerialType,
    6871
    int32 aggTransSpace,
    6972
    Oid aggmTransType,
    7073
    int32 aggmTransSpace,
    @@ -79,6 +82,8 @@ AggregateCreate(const char *aggName,
    7982
    Oid transfn;
    8083
    Oid finalfn = InvalidOid; /* can be omitted */
    8184
    Oid combinefn = InvalidOid; /* can be omitted */
    85+
    Oid serialfn = InvalidOid; /* can be omitted */
    86+
    Oid deserialfn = InvalidOid; /* can be omitted */
    8287
    Oid mtransfn = InvalidOid; /* can be omitted */
    8388
    Oid minvtransfn = InvalidOid; /* can be omitted */
    8489
    Oid mfinalfn = InvalidOid; /* can be omitted */
    @@ -420,6 +425,57 @@ AggregateCreate(const char *aggName,
    420425
    errmsg("return type of combine function %s is not %s",
    421426
    NameListToString(aggcombinefnName),
    422427
    format_type_be(aggTransType))));
    428+
    429+
    /*
    430+
    * A combine function to combine INTERNAL states must accept nulls and
    431+
    * ensure that the returned state is in the correct memory context.
    432+
    */
    433+
    if (aggTransType == INTERNALOID && func_strict(combinefn))
    434+
    ereport(ERROR,
    435+
    (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
    436+
    errmsg("combine function with \"%s\" transition type must not be declared STRICT",
    437+
    format_type_be(aggTransType))));
    438+
    439+
    }
    440+
    441+
    /*
    442+
    * Validate the serialization function, if present. We must ensure that the
    443+
    * return type of this function is the same as the specified serialType.
    444+
    */
    445+
    if (aggserialfnName)
    446+
    {
    447+
    fnArgs[0] = aggTransType;
    448+
    449+
    serialfn = lookup_agg_function(aggserialfnName, 1,
    450+
    fnArgs, variadicArgType,
    451+
    &rettype);
    452+
    453+
    if (rettype != aggSerialType)
    454+
    ereport(ERROR,
    455+
    (errcode(ERRCODE_DATATYPE_MISMATCH),
    456+
    errmsg("return type of serialization function %s is not %s",
    457+
    NameListToString(aggserialfnName),
    458+
    format_type_be(aggSerialType))));
    459+
    }
    460+
    461+
    /*
    462+
    * Validate the deserialization function, if present. We must ensure that
    463+
    * the return type of this function is the same as the transType.
    464+
    */
    465+
    if (aggdeserialfnName)
    466+
    {
    467+
    fnArgs[0] = aggSerialType;
    468+
    469+
    deserialfn = lookup_agg_function(aggdeserialfnName, 1,
    470+
    fnArgs, variadicArgType,
    471+
    &rettype);
    472+
    473+
    if (rettype != aggTransType)
    474+
    ereport(ERROR,
    475+
    (errcode(ERRCODE_DATATYPE_MISMATCH),
    476+
    errmsg("return type of deserialization function %s is not %s",
    477+
    NameListToString(aggdeserialfnName),
    478+
    format_type_be(aggTransType))));
    423479
    }
    424480

    425481
    /*
    @@ -594,13 +650,16 @@ AggregateCreate(const char *aggName,
    594650
    values[Anum_pg_aggregate_aggtransfn - 1] = ObjectIdGetDatum(transfn);
    595651
    values[Anum_pg_aggregate_aggfinalfn - 1] = ObjectIdGetDatum(finalfn);
    596652
    values[Anum_pg_aggregate_aggcombinefn - 1] = ObjectIdGetDatum(combinefn);
    653+
    values[Anum_pg_aggregate_aggserialfn - 1] = ObjectIdGetDatum(serialfn);
    654+
    values[Anum_pg_aggregate_aggdeserialfn - 1] = ObjectIdGetDatum(deserialfn);
    597655
    values[Anum_pg_aggregate_aggmtransfn - 1] = ObjectIdGetDatum(mtransfn);
    598656
    values[Anum_pg_aggregate_aggminvtransfn - 1] = ObjectIdGetDatum(minvtransfn);
    599657
    values[Anum_pg_aggregate_aggmfinalfn - 1] = ObjectIdGetDatum(mfinalfn);
    600658
    values[Anum_pg_aggregate_aggfinalextra - 1] = BoolGetDatum(finalfnExtraArgs);
    601659
    values[Anum_pg_aggregate_aggmfinalextra - 1] = BoolGetDatum(mfinalfnExtraArgs);
    602660
    values[Anum_pg_aggregate_aggsortop - 1] = ObjectIdGetDatum(sortop);
    603661
    values[Anum_pg_aggregate_aggtranstype - 1] = ObjectIdGetDatum(aggTransType);
    662+
    values[Anum_pg_aggregate_aggserialtype - 1] = ObjectIdGetDatum(aggSerialType);
    604663
    values[Anum_pg_aggregate_aggtransspace - 1] = Int32GetDatum(aggTransSpace);
    605664
    values[Anum_pg_aggregate_aggmtranstype - 1] = ObjectIdGetDatum(aggmTransType);
    606665
    values[Anum_pg_aggregate_aggmtransspace - 1] = Int32GetDatum(aggmTransSpace);
    @@ -627,7 +686,8 @@ AggregateCreate(const char *aggName,
    627686
    * Create dependencies for the aggregate (above and beyond those already
    628687
    * made by ProcedureCreate). Note: we don't need an explicit dependency
    629688
    * on aggTransType since we depend on it indirectly through transfn.
    630-
    * Likewise for aggmTransType if any.
    689+
    * Likewise for aggmTransType using the mtransfunc, and also for
    690+
    * aggSerialType using the serialfn, if they exist.
    631691
    */
    632692

    633693
    /* Depends on transition function */
    @@ -654,6 +714,24 @@ AggregateCreate(const char *aggName,
    654714
    recordDependencyOn(&myself, &referenced, DEPENDENCY_NORMAL);
    655715
    }
    656716

    717+
    /* Depends on serialization function, if any */
    718+
    if (OidIsValid(serialfn))
    719+
    {
    720+
    referenced.classId = ProcedureRelationId;
    721+
    referenced.objectId = serialfn;
    722+
    referenced.objectSubId = 0;
    723+
    recordDependencyOn(&myself, &referenced, DEPENDENCY_NORMAL);
    724+
    }
    725+
    726+
    /* Depends on deserialization function, if any */
    727+
    if (OidIsValid(deserialfn))
    728+
    {
    729+
    referenced.classId = ProcedureRelationId;
    730+
    referenced.objectId = deserialfn;
    731+
    referenced.objectSubId = 0;
    732+
    recordDependencyOn(&myself, &referenced, DEPENDENCY_NORMAL);
    733+
    }
    734+
    657735
    /* Depends on forward transition function, if any */
    658736
    if (OidIsValid(mtransfn))
    659737
    {

    src/backend/commands/aggregatecmds.c

    Lines changed: 82 additions & 0 deletions
    Original file line numberDiff line numberDiff line change
    @@ -62,6 +62,8 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters,
    6262
    List *transfuncName = NIL;
    6363
    List *finalfuncName = NIL;
    6464
    List *combinefuncName = NIL;
    65+
    List *serialfuncName = NIL;
    66+
    List *deserialfuncName = NIL;
    6567
    List *mtransfuncName = NIL;
    6668
    List *minvtransfuncName = NIL;
    6769
    List *mfinalfuncName = NIL;
    @@ -70,6 +72,7 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters,
    7072
    List *sortoperatorName = NIL;
    7173
    TypeName *baseType = NULL;
    7274
    TypeName *transType = NULL;
    75+
    TypeName *serialType = NULL;
    7376
    TypeName *mtransType = NULL;
    7477
    int32 transSpace = 0;
    7578
    int32 mtransSpace = 0;
    @@ -84,6 +87,7 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters,
    8487
    List *parameterDefaults;
    8588
    Oid variadicArgType;
    8689
    Oid transTypeId;
    90+
    Oid serialTypeId = InvalidOid;
    8791
    Oid mtransTypeId = InvalidOid;
    8892
    char transTypeType;
    8993
    char mtransTypeType = 0;
    @@ -127,6 +131,10 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters,
    127131
    finalfuncName = defGetQualifiedName(defel);
    128132
    else if (pg_strcasecmp(defel->defname, "combinefunc") == 0)
    129133
    combinefuncName = defGetQualifiedName(defel);
    134+
    else if (pg_strcasecmp(defel->defname, "serialfunc") == 0)
    135+
    serialfuncName = defGetQualifiedName(defel);
    136+
    else if (pg_strcasecmp(defel->defname, "deserialfunc") == 0)
    137+
    deserialfuncName = defGetQualifiedName(defel);
    130138
    else if (pg_strcasecmp(defel->defname, "msfunc") == 0)
    131139
    mtransfuncName = defGetQualifiedName(defel);
    132140
    else if (pg_strcasecmp(defel->defname, "minvfunc") == 0)
    @@ -154,6 +162,8 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters,
    154162
    }
    155163
    else if (pg_strcasecmp(defel->defname, "stype") == 0)
    156164
    transType = defGetTypeName(defel);
    165+
    else if (pg_strcasecmp(defel->defname, "serialtype") == 0)
    166+
    serialType = defGetTypeName(defel);
    157167
    else if (pg_strcasecmp(defel->defname, "stype1") == 0)
    158168
    transType = defGetTypeName(defel);
    159169
    else if (pg_strcasecmp(defel->defname, "sspace") == 0)
    @@ -319,6 +329,75 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters,
    319329
    format_type_be(transTypeId))));
    320330
    }
    321331

    332+
    if (serialType)
    333+
    {
    334+
    /*
    335+
    * There's little point in having a serialization/deserialization
    336+
    * function on aggregates that don't have an internal state, so let's
    337+
    * just disallow this as it may help clear up any confusion or needless
    338+
    * authoring of these functions.
    339+
    */
    340+
    if (transTypeId != INTERNALOID)
    341+
    ereport(ERROR,
    342+
    (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
    343+
    errmsg("a serialization type must only be specified when the aggregate transition data type is \"%s\"",
    344+
    format_type_be(INTERNALOID))));
    345+
    346+
    serialTypeId = typenameTypeId(NULL, serialType);
    347+
    348+
    if (get_typtype(mtransTypeId) == TYPTYPE_PSEUDO &&
    349+
    !IsPolymorphicType(serialTypeId))
    350+
    ereport(ERROR,
    351+
    (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
    352+
    errmsg("aggregate serialization data type cannot be %s",
    353+
    format_type_be(serialTypeId))));
    354+
    355+
    /*
    356+
    * We disallow INTERNAL serialType as the whole point of the
    357+
    * serialized types is to allow the aggregate state to be output,
    358+
    * and we cannot output INTERNAL. This check, combined with the one
    359+
    * above ensures that the trans type and serialization type are not the
    360+
    * same.
    361+
    */
    362+
    if (serialTypeId == INTERNALOID)
    363+
    ereport(ERROR,
    364+
    (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
    365+
    errmsg("aggregate serialization type cannot be \"%s\"",
    366+
    format_type_be(serialTypeId))));
    367+
    368+
    /*
    369+
    * If serialType is specified then serialfuncName and deserialfuncName
    370+
    * must be present; if not, then none of the serialization options
    371+
    * should have been specified.
    372+
    */
    373+
    if (serialfuncName == NIL)
    374+
    ereport(ERROR,
    375+
    (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
    376+
    errmsg("aggregate serialization function must be specified when serialization type is specified")));
    377+
    378+
    if (deserialfuncName == NIL)
    379+
    ereport(ERROR,
    380+
    (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
    381+
    errmsg("aggregate deserialization function must be specified when serialization type is specified")));
    382+
    }
    383+
    else
    384+
    {
    385+
    /*
    386+
    * If serialization type was not specified then there shouldn't be a
    387+
    * serialization function.
    388+
    */
    389+
    if (serialfuncName != NIL)
    390+
    ereport(ERROR,
    391+
    (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
    392+
    errmsg("must specify serialization type when specifying serialization function")));
    393+
    394+
    /* likewise for the deserialization function */
    395+
    if (deserialfuncName != NIL)
    396+
    ereport(ERROR,
    397+
    (errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
    398+
    errmsg("must specify serialization type when specifying deserialization function")));
    399+
    }
    400+
    322401
    /*
    323402
    * If a moving-aggregate transtype is specified, look that up. Same
    324403
    * restrictions as for transtype.
    @@ -387,13 +466,16 @@ DefineAggregate(List *name, List *args, bool oldstyle, List *parameters,
    387466
    transfuncName, /* step function name */
    388467
    finalfuncName, /* final function name */
    389468
    combinefuncName, /* combine function name */
    469+
    serialfuncName, /* serial function name */
    470+
    deserialfuncName, /* deserial function name */
    390471
    mtransfuncName, /* fwd trans function name */
    391472
    minvtransfuncName, /* inv trans function name */
    392473
    mfinalfuncName, /* final function name */
    393474
    finalfuncExtraArgs,
    394475
    mfinalfuncExtraArgs,
    395476
    sortoperatorName, /* sort operator name */
    396477
    transTypeId, /* transition data type */
    478+
    serialTypeId, /* serialization data type */
    397479
    transSpace, /* transition space */
    398480
    mtransTypeId, /* transition data type */
    399481
    mtransSpace, /* transition space */

    0 commit comments

    Comments
     (0)
    0