Commit f1464c53 authored by Simon Riggs's avatar Simon Riggs

Improve parse representation for MERGE

Separation of parser data structures from executor, as
requested by Tom Lane. Further improvements possible.

While there, implement error for multiple VALUES clauses via parser
to allow line number of error, as requested by Andres Freund.

Author: Pavan Deolasee

Discussion: https://www.postgresql.org/message-id/CABOikdPpqjectFchg0FyTOpsGXyPoqwgC==OLKWuxgBOsrDDZw@mail.gmail.com
parent 3b0b4f31
...@@ -2136,6 +2136,20 @@ _copyOnConflictExpr(const OnConflictExpr *from) ...@@ -2136,6 +2136,20 @@ _copyOnConflictExpr(const OnConflictExpr *from)
return newnode; return newnode;
} }
static MergeAction *
_copyMergeAction(const MergeAction *from)
{
MergeAction *newnode = makeNode(MergeAction);
COPY_SCALAR_FIELD(matched);
COPY_SCALAR_FIELD(commandType);
COPY_SCALAR_FIELD(override);
COPY_NODE_FIELD(qual);
COPY_NODE_FIELD(targetList);
return newnode;
}
/* **************************************************************** /* ****************************************************************
* relation.h copy functions * relation.h copy functions
* *
...@@ -3054,24 +3068,24 @@ _copyMergeStmt(const MergeStmt *from) ...@@ -3054,24 +3068,24 @@ _copyMergeStmt(const MergeStmt *from)
COPY_NODE_FIELD(relation); COPY_NODE_FIELD(relation);
COPY_NODE_FIELD(source_relation); COPY_NODE_FIELD(source_relation);
COPY_NODE_FIELD(join_condition); COPY_NODE_FIELD(join_condition);
COPY_NODE_FIELD(mergeActionList); COPY_NODE_FIELD(mergeWhenClauses);
COPY_NODE_FIELD(withClause); COPY_NODE_FIELD(withClause);
return newnode; return newnode;
} }
static MergeAction * static MergeWhenClause *
_copyMergeAction(const MergeAction *from) _copyMergeWhenClause(const MergeWhenClause *from)
{ {
MergeAction *newnode = makeNode(MergeAction); MergeWhenClause *newnode = makeNode(MergeWhenClause);
COPY_SCALAR_FIELD(matched); COPY_SCALAR_FIELD(matched);
COPY_SCALAR_FIELD(commandType); COPY_SCALAR_FIELD(commandType);
COPY_NODE_FIELD(condition); COPY_NODE_FIELD(condition);
COPY_NODE_FIELD(qual);
COPY_NODE_FIELD(stmt);
COPY_NODE_FIELD(targetList); COPY_NODE_FIELD(targetList);
COPY_NODE_FIELD(cols);
COPY_NODE_FIELD(values);
COPY_SCALAR_FIELD(override);
return newnode; return newnode;
} }
...@@ -5059,6 +5073,9 @@ copyObjectImpl(const void *from) ...@@ -5059,6 +5073,9 @@ copyObjectImpl(const void *from)
case T_OnConflictExpr: case T_OnConflictExpr:
retval = _copyOnConflictExpr(from); retval = _copyOnConflictExpr(from);
break; break;
case T_MergeAction:
retval = _copyMergeAction(from);
break;
/* /*
* RELATION NODES * RELATION NODES
...@@ -5140,8 +5157,8 @@ copyObjectImpl(const void *from) ...@@ -5140,8 +5157,8 @@ copyObjectImpl(const void *from)
case T_MergeStmt: case T_MergeStmt:
retval = _copyMergeStmt(from); retval = _copyMergeStmt(from);
break; break;
case T_MergeAction: case T_MergeWhenClause:
retval = _copyMergeAction(from); retval = _copyMergeWhenClause(from);
break; break;
case T_SelectStmt: case T_SelectStmt:
retval = _copySelectStmt(from); retval = _copySelectStmt(from);
......
...@@ -812,6 +812,18 @@ _equalOnConflictExpr(const OnConflictExpr *a, const OnConflictExpr *b) ...@@ -812,6 +812,18 @@ _equalOnConflictExpr(const OnConflictExpr *a, const OnConflictExpr *b)
return true; return true;
} }
static bool
_equalMergeAction(const MergeAction *a, const MergeAction *b)
{
COMPARE_SCALAR_FIELD(matched);
COMPARE_SCALAR_FIELD(commandType);
COMPARE_SCALAR_FIELD(override);
COMPARE_NODE_FIELD(qual);
COMPARE_NODE_FIELD(targetList);
return true;
}
/* /*
* Stuff from relation.h * Stuff from relation.h
*/ */
...@@ -1050,21 +1062,22 @@ _equalMergeStmt(const MergeStmt *a, const MergeStmt *b) ...@@ -1050,21 +1062,22 @@ _equalMergeStmt(const MergeStmt *a, const MergeStmt *b)
COMPARE_NODE_FIELD(relation); COMPARE_NODE_FIELD(relation);
COMPARE_NODE_FIELD(source_relation); COMPARE_NODE_FIELD(source_relation);
COMPARE_NODE_FIELD(join_condition); COMPARE_NODE_FIELD(join_condition);
COMPARE_NODE_FIELD(mergeActionList); COMPARE_NODE_FIELD(mergeWhenClauses);
COMPARE_NODE_FIELD(withClause); COMPARE_NODE_FIELD(withClause);
return true; return true;
} }
static bool static bool
_equalMergeAction(const MergeAction *a, const MergeAction *b) _equalMergeWhenClause(const MergeWhenClause *a, const MergeWhenClause *b)
{ {
COMPARE_SCALAR_FIELD(matched); COMPARE_SCALAR_FIELD(matched);
COMPARE_SCALAR_FIELD(commandType); COMPARE_SCALAR_FIELD(commandType);
COMPARE_NODE_FIELD(condition); COMPARE_NODE_FIELD(condition);
COMPARE_NODE_FIELD(qual);
COMPARE_NODE_FIELD(stmt);
COMPARE_NODE_FIELD(targetList); COMPARE_NODE_FIELD(targetList);
COMPARE_NODE_FIELD(cols);
COMPARE_NODE_FIELD(values);
COMPARE_SCALAR_FIELD(override);
return true; return true;
} }
...@@ -3192,6 +3205,9 @@ equal(const void *a, const void *b) ...@@ -3192,6 +3205,9 @@ equal(const void *a, const void *b)
case T_OnConflictExpr: case T_OnConflictExpr:
retval = _equalOnConflictExpr(a, b); retval = _equalOnConflictExpr(a, b);
break; break;
case T_MergeAction:
retval = _equalMergeAction(a, b);
break;
case T_JoinExpr: case T_JoinExpr:
retval = _equalJoinExpr(a, b); retval = _equalJoinExpr(a, b);
break; break;
...@@ -3263,8 +3279,8 @@ equal(const void *a, const void *b) ...@@ -3263,8 +3279,8 @@ equal(const void *a, const void *b)
case T_MergeStmt: case T_MergeStmt:
retval = _equalMergeStmt(a, b); retval = _equalMergeStmt(a, b);
break; break;
case T_MergeAction: case T_MergeWhenClause:
retval = _equalMergeAction(a, b); retval = _equalMergeWhenClause(a, b);
break; break;
case T_SelectStmt: case T_SelectStmt:
retval = _equalSelectStmt(a, b); retval = _equalSelectStmt(a, b);
......
...@@ -3444,19 +3444,23 @@ raw_expression_tree_walker(Node *node, ...@@ -3444,19 +3444,23 @@ raw_expression_tree_walker(Node *node,
return true; return true;
if (walker(stmt->join_condition, context)) if (walker(stmt->join_condition, context))
return true; return true;
if (walker(stmt->mergeActionList, context)) if (walker(stmt->mergeWhenClauses, context))
return true; return true;
if (walker(stmt->withClause, context)) if (walker(stmt->withClause, context))
return true; return true;
} }
break; break;
case T_MergeAction: case T_MergeWhenClause:
{ {
MergeAction *action = (MergeAction *) node; MergeWhenClause *mergeWhenClause = (MergeWhenClause *) node;
if (walker(action->targetList, context)) if (walker(mergeWhenClause->condition, context))
return true; return true;
if (walker(action->qual, context)) if (walker(mergeWhenClause->targetList, context))
return true;
if (walker(mergeWhenClause->cols, context))
return true;
if (walker(mergeWhenClause->values, context))
return true; return true;
} }
break; break;
......
...@@ -396,16 +396,17 @@ _outModifyTable(StringInfo str, const ModifyTable *node) ...@@ -396,16 +396,17 @@ _outModifyTable(StringInfo str, const ModifyTable *node)
} }
static void static void
_outMergeAction(StringInfo str, const MergeAction *node) _outMergeWhenClause(StringInfo str, const MergeWhenClause *node)
{ {
WRITE_NODE_TYPE("MERGEACTION"); WRITE_NODE_TYPE("MERGEWHENCLAUSE");
WRITE_BOOL_FIELD(matched); WRITE_BOOL_FIELD(matched);
WRITE_ENUM_FIELD(commandType, CmdType); WRITE_ENUM_FIELD(commandType, CmdType);
WRITE_NODE_FIELD(condition); WRITE_NODE_FIELD(condition);
WRITE_NODE_FIELD(qual);
WRITE_NODE_FIELD(stmt);
WRITE_NODE_FIELD(targetList); WRITE_NODE_FIELD(targetList);
WRITE_NODE_FIELD(cols);
WRITE_NODE_FIELD(values);
WRITE_ENUM_FIELD(override, OverridingKind);
} }
static void static void
...@@ -1724,6 +1725,17 @@ _outOnConflictExpr(StringInfo str, const OnConflictExpr *node) ...@@ -1724,6 +1725,17 @@ _outOnConflictExpr(StringInfo str, const OnConflictExpr *node)
WRITE_NODE_FIELD(exclRelTlist); WRITE_NODE_FIELD(exclRelTlist);
} }
static void
_outMergeAction(StringInfo str, const MergeAction *node)
{
WRITE_NODE_TYPE("MERGEACTION");
WRITE_BOOL_FIELD(matched);
WRITE_ENUM_FIELD(commandType, CmdType);
WRITE_NODE_FIELD(qual);
WRITE_NODE_FIELD(targetList);
}
/***************************************************************************** /*****************************************************************************
* *
* Stuff from relation.h. * Stuff from relation.h.
...@@ -3679,8 +3691,8 @@ outNode(StringInfo str, const void *obj) ...@@ -3679,8 +3691,8 @@ outNode(StringInfo str, const void *obj)
case T_ModifyTable: case T_ModifyTable:
_outModifyTable(str, obj); _outModifyTable(str, obj);
break; break;
case T_MergeAction: case T_MergeWhenClause:
_outMergeAction(str, obj); _outMergeWhenClause(str, obj);
break; break;
case T_Append: case T_Append:
_outAppend(str, obj); _outAppend(str, obj);
...@@ -3958,6 +3970,9 @@ outNode(StringInfo str, const void *obj) ...@@ -3958,6 +3970,9 @@ outNode(StringInfo str, const void *obj)
case T_OnConflictExpr: case T_OnConflictExpr:
_outOnConflictExpr(str, obj); _outOnConflictExpr(str, obj);
break; break;
case T_MergeAction:
_outMergeAction(str, obj);
break;
case T_Path: case T_Path:
_outPath(str, obj); _outPath(str, obj);
break; break;
......
...@@ -1331,6 +1331,22 @@ _readOnConflictExpr(void) ...@@ -1331,6 +1331,22 @@ _readOnConflictExpr(void)
READ_DONE(); READ_DONE();
} }
/*
* _readMergeAction
*/
static MergeAction *
_readMergeAction(void)
{
READ_LOCALS(MergeAction);
READ_BOOL_FIELD(matched);
READ_ENUM_FIELD(commandType, CmdType);
READ_NODE_FIELD(qual);
READ_NODE_FIELD(targetList);
READ_DONE();
}
/* /*
* Stuff from parsenodes.h. * Stuff from parsenodes.h.
*/ */
...@@ -1602,19 +1618,20 @@ _readModifyTable(void) ...@@ -1602,19 +1618,20 @@ _readModifyTable(void)
} }
/* /*
* _readMergeAction * _readMergeWhenClause
*/ */
static MergeAction * static MergeWhenClause *
_readMergeAction(void) _readMergeWhenClause(void)
{ {
READ_LOCALS(MergeAction); READ_LOCALS(MergeWhenClause);
READ_BOOL_FIELD(matched); READ_BOOL_FIELD(matched);
READ_ENUM_FIELD(commandType, CmdType); READ_ENUM_FIELD(commandType, CmdType);
READ_NODE_FIELD(condition); READ_NODE_FIELD(condition);
READ_NODE_FIELD(qual);
READ_NODE_FIELD(stmt);
READ_NODE_FIELD(targetList); READ_NODE_FIELD(targetList);
READ_NODE_FIELD(cols);
READ_NODE_FIELD(values);
READ_ENUM_FIELD(override, OverridingKind);
READ_DONE(); READ_DONE();
} }
...@@ -2596,6 +2613,8 @@ parseNodeString(void) ...@@ -2596,6 +2613,8 @@ parseNodeString(void)
return_value = _readFromExpr(); return_value = _readFromExpr();
else if (MATCH("ONCONFLICTEXPR", 14)) else if (MATCH("ONCONFLICTEXPR", 14))
return_value = _readOnConflictExpr(); return_value = _readOnConflictExpr();
else if (MATCH("MERGEACTION", 11))
return_value = _readMergeAction();
else if (MATCH("RTE", 3)) else if (MATCH("RTE", 3))
return_value = _readRangeTblEntry(); return_value = _readRangeTblEntry();
else if (MATCH("RANGETBLFUNCTION", 16)) else if (MATCH("RANGETBLFUNCTION", 16))
...@@ -2618,8 +2637,8 @@ parseNodeString(void) ...@@ -2618,8 +2637,8 @@ parseNodeString(void)
return_value = _readProjectSet(); return_value = _readProjectSet();
else if (MATCH("MODIFYTABLE", 11)) else if (MATCH("MODIFYTABLE", 11))
return_value = _readModifyTable(); return_value = _readModifyTable();
else if (MATCH("MERGEACTION", 11)) else if (MATCH("MERGEWHENCLAUSE", 15))
return_value = _readMergeAction(); return_value = _readMergeWhenClause();
else if (MATCH("APPEND", 6)) else if (MATCH("APPEND", 6))
return_value = _readAppend(); return_value = _readAppend();
else if (MATCH("MERGEAPPEND", 11)) else if (MATCH("MERGEAPPEND", 11))
......
...@@ -241,6 +241,7 @@ static Node *makeRecursiveViewSelect(char *relname, List *aliases, Node *query); ...@@ -241,6 +241,7 @@ static Node *makeRecursiveViewSelect(char *relname, List *aliases, Node *query);
PartitionSpec *partspec; PartitionSpec *partspec;
PartitionBoundSpec *partboundspec; PartitionBoundSpec *partboundspec;
RoleSpec *rolespec; RoleSpec *rolespec;
MergeWhenClause *mergewhen;
} }
%type <node> stmt schema_stmt %type <node> stmt schema_stmt
...@@ -400,6 +401,7 @@ static Node *makeRecursiveViewSelect(char *relname, List *aliases, Node *query); ...@@ -400,6 +401,7 @@ static Node *makeRecursiveViewSelect(char *relname, List *aliases, Node *query);
TriggerTransitions TriggerReferencing TriggerTransitions TriggerReferencing
publication_name_list publication_name_list
vacuum_relation_list opt_vacuum_relation_list vacuum_relation_list opt_vacuum_relation_list
merge_values_clause
%type <list> group_by_list %type <list> group_by_list
%type <node> group_by_item empty_grouping_set rollup_clause cube_clause %type <node> group_by_item empty_grouping_set rollup_clause cube_clause
...@@ -460,6 +462,7 @@ static Node *makeRecursiveViewSelect(char *relname, List *aliases, Node *query); ...@@ -460,6 +462,7 @@ static Node *makeRecursiveViewSelect(char *relname, List *aliases, Node *query);
%type <istmt> insert_rest %type <istmt> insert_rest
%type <infer> opt_conf_expr %type <infer> opt_conf_expr
%type <onconflict> opt_on_conflict %type <onconflict> opt_on_conflict
%type <mergewhen> merge_insert merge_update merge_delete
%type <vsetstmt> generic_set set_rest set_rest_more generic_reset reset_rest %type <vsetstmt> generic_set set_rest set_rest_more generic_reset reset_rest
SetResetClause FunctionSetResetClause SetResetClause FunctionSetResetClause
...@@ -587,7 +590,6 @@ static Node *makeRecursiveViewSelect(char *relname, List *aliases, Node *query); ...@@ -587,7 +590,6 @@ static Node *makeRecursiveViewSelect(char *relname, List *aliases, Node *query);
%type <node> merge_when_clause opt_merge_when_and_condition %type <node> merge_when_clause opt_merge_when_and_condition
%type <list> merge_when_list %type <list> merge_when_list
%type <node> merge_update merge_delete merge_insert
/* /*
* Non-keyword token types. These are hard-wired into the "flex" lexer. * Non-keyword token types. These are hard-wired into the "flex" lexer.
...@@ -11116,7 +11118,7 @@ MergeStmt: ...@@ -11116,7 +11118,7 @@ MergeStmt:
m->relation = $4; m->relation = $4;
m->source_relation = $6; m->source_relation = $6;
m->join_condition = $8; m->join_condition = $8;
m->mergeActionList = $9; m->mergeWhenClauses = $9;
$$ = (Node *)m; $$ = (Node *)m;
} }
...@@ -11131,45 +11133,37 @@ merge_when_list: ...@@ -11131,45 +11133,37 @@ merge_when_list:
merge_when_clause: merge_when_clause:
WHEN MATCHED opt_merge_when_and_condition THEN merge_update WHEN MATCHED opt_merge_when_and_condition THEN merge_update
{ {
MergeAction *m = makeNode(MergeAction); $5->matched = true;
$5->commandType = CMD_UPDATE;
$5->condition = $3;
m->matched = true; $$ = (Node *) $5;
m->commandType = CMD_UPDATE;
m->condition = $3;
m->stmt = $5;
$$ = (Node *)m;
} }
| WHEN MATCHED opt_merge_when_and_condition THEN merge_delete | WHEN MATCHED opt_merge_when_and_condition THEN merge_delete
{ {
MergeAction *m = makeNode(MergeAction); MergeWhenClause *m = makeNode(MergeWhenClause);
m->matched = true; m->matched = true;
m->commandType = CMD_DELETE; m->commandType = CMD_DELETE;
m->condition = $3; m->condition = $3;
m->stmt = $5;
$$ = (Node *)m; $$ = (Node *)m;
} }
| WHEN NOT MATCHED opt_merge_when_and_condition THEN merge_insert | WHEN NOT MATCHED opt_merge_when_and_condition THEN merge_insert
{ {
MergeAction *m = makeNode(MergeAction); $6->matched = false;
$6->commandType = CMD_INSERT;
$6->condition = $4;
m->matched = false; $$ = (Node *) $6;
m->commandType = CMD_INSERT;
m->condition = $4;
m->stmt = $6;
$$ = (Node *)m;
} }
| WHEN NOT MATCHED opt_merge_when_and_condition THEN DO NOTHING | WHEN NOT MATCHED opt_merge_when_and_condition THEN DO NOTHING
{ {
MergeAction *m = makeNode(MergeAction); MergeWhenClause *m = makeNode(MergeWhenClause);
m->matched = false; m->matched = false;
m->commandType = CMD_NOTHING; m->commandType = CMD_NOTHING;
m->condition = $4; m->condition = $4;
m->stmt = NULL;
$$ = (Node *)m; $$ = (Node *)m;
} }
...@@ -11181,65 +11175,63 @@ opt_merge_when_and_condition: ...@@ -11181,65 +11175,63 @@ opt_merge_when_and_condition:
; ;
merge_delete: merge_delete:
DELETE_P DELETE_P { $$ = NULL; }
{
DeleteStmt *n = makeNode(DeleteStmt);
$$ = (Node *)n;
}
; ;
merge_update: merge_update:
UPDATE SET set_clause_list UPDATE SET set_clause_list
{ {
UpdateStmt *n = makeNode(UpdateStmt); MergeWhenClause *n = makeNode(MergeWhenClause);
n->targetList = $3; n->targetList = $3;
$$ = (Node *)n; $$ = n;
} }
; ;
merge_insert: merge_insert:
INSERT values_clause INSERT merge_values_clause
{ {
InsertStmt *n = makeNode(InsertStmt); MergeWhenClause *n = makeNode(MergeWhenClause);
n->cols = NIL; n->cols = NIL;
n->selectStmt = $2; n->values = $2;
$$ = n;
$$ = (Node *)n;
} }
| INSERT OVERRIDING override_kind VALUE_P values_clause | INSERT OVERRIDING override_kind VALUE_P merge_values_clause
{ {
InsertStmt *n = makeNode(InsertStmt); MergeWhenClause *n = makeNode(MergeWhenClause);
n->cols = NIL; n->cols = NIL;
n->override = $3; n->override = $3;
n->selectStmt = $5; n->values = $5;
$$ = n;
$$ = (Node *)n;
} }
| INSERT '(' insert_column_list ')' values_clause | INSERT '(' insert_column_list ')' merge_values_clause
{ {
InsertStmt *n = makeNode(InsertStmt); MergeWhenClause *n = makeNode(MergeWhenClause);
n->cols = $3; n->cols = $3;
n->selectStmt = $5; n->values = $5;
$$ = n;
$$ = (Node *)n;
} }
| INSERT '(' insert_column_list ')' OVERRIDING override_kind VALUE_P values_clause | INSERT '(' insert_column_list ')' OVERRIDING override_kind VALUE_P merge_values_clause
{ {
InsertStmt *n = makeNode(InsertStmt); MergeWhenClause *n = makeNode(MergeWhenClause);
n->cols = $3; n->cols = $3;
n->override = $6; n->override = $6;
n->selectStmt = $8; n->values = $8;
$$ = n;
$$ = (Node *)n;
} }
| INSERT DEFAULT VALUES | INSERT DEFAULT VALUES
{ {
InsertStmt *n = makeNode(InsertStmt); MergeWhenClause *n = makeNode(MergeWhenClause);
n->cols = NIL; n->cols = NIL;
n->selectStmt = NULL; n->values = NIL;
$$ = n;
}
;
$$ = (Node *)n; merge_values_clause:
VALUES '(' expr_list ')'
{
$$ = $3;
} }
; ;
......
...@@ -33,8 +33,8 @@ ...@@ -33,8 +33,8 @@
static int transformMergeJoinClause(ParseState *pstate, Node *merge, static int transformMergeJoinClause(ParseState *pstate, Node *merge,
List **mergeSourceTargetList); List **mergeSourceTargetList);
static void setNamespaceForMergeAction(ParseState *pstate, static void setNamespaceForMergeWhen(ParseState *pstate,
MergeAction *action); MergeWhenClause *mergeWhenClause);
static void setNamespaceVisibilityForRTE(List *namespace, RangeTblEntry *rte, static void setNamespaceVisibilityForRTE(List *namespace, RangeTblEntry *rte,
bool rel_visible, bool rel_visible,
bool cols_visible); bool cols_visible);
...@@ -138,7 +138,7 @@ transformMergeJoinClause(ParseState *pstate, Node *merge, ...@@ -138,7 +138,7 @@ transformMergeJoinClause(ParseState *pstate, Node *merge,
* that columns can be referenced unqualified from these relations. * that columns can be referenced unqualified from these relations.
*/ */
static void static void
setNamespaceForMergeAction(ParseState *pstate, MergeAction *action) setNamespaceForMergeWhen(ParseState *pstate, MergeWhenClause *mergeWhenClause)
{ {
RangeTblEntry *targetRelRTE, RangeTblEntry *targetRelRTE,
*sourceRelRTE; *sourceRelRTE;
...@@ -152,7 +152,7 @@ setNamespaceForMergeAction(ParseState *pstate, MergeAction *action) ...@@ -152,7 +152,7 @@ setNamespaceForMergeAction(ParseState *pstate, MergeAction *action)
*/ */
sourceRelRTE = rt_fetch(list_length(pstate->p_rtable) - 1, pstate->p_rtable); sourceRelRTE = rt_fetch(list_length(pstate->p_rtable) - 1, pstate->p_rtable);
switch (action->commandType) switch (mergeWhenClause->commandType)
{ {
case CMD_INSERT: case CMD_INSERT:
...@@ -198,6 +198,7 @@ transformMergeStmt(ParseState *pstate, MergeStmt *stmt) ...@@ -198,6 +198,7 @@ transformMergeStmt(ParseState *pstate, MergeStmt *stmt)
bool is_terminal[2]; bool is_terminal[2];
JoinExpr *joinexpr; JoinExpr *joinexpr;
RangeTblEntry *resultRelRTE, *mergeRelRTE; RangeTblEntry *resultRelRTE, *mergeRelRTE;
List *mergeActionList;
/* There can't be any outer WITH to worry about */ /* There can't be any outer WITH to worry about */
Assert(pstate->p_ctenamespace == NIL); Assert(pstate->p_ctenamespace == NIL);
...@@ -222,43 +223,18 @@ transformMergeStmt(ParseState *pstate, MergeStmt *stmt) ...@@ -222,43 +223,18 @@ transformMergeStmt(ParseState *pstate, MergeStmt *stmt)
*/ */
is_terminal[0] = false; is_terminal[0] = false;
is_terminal[1] = false; is_terminal[1] = false;
foreach(l, stmt->mergeActionList) foreach(l, stmt->mergeWhenClauses)
{ {
MergeAction *action = (MergeAction *) lfirst(l); MergeWhenClause *mergeWhenClause = (MergeWhenClause *) lfirst(l);
int when_type = (action->matched ? 0 : 1); int when_type = (mergeWhenClause->matched ? 0 : 1);
/* /*
* Collect action types so we can check Target permissions * Collect action types so we can check Target permissions
*/ */
switch (action->commandType) switch (mergeWhenClause->commandType)
{ {
case CMD_INSERT: case CMD_INSERT:
{ targetPerms |= ACL_INSERT;
InsertStmt *istmt = (InsertStmt *) action->stmt;
SelectStmt *selectStmt = (SelectStmt *) istmt->selectStmt;
/*
* The grammar allows attaching ORDER BY, LIMIT, FOR
* UPDATE, or WITH to a VALUES clause and also multiple
* VALUES clauses. If we have any of those, ERROR.
*/
if (selectStmt && (selectStmt->valuesLists == NIL ||
selectStmt->sortClause != NIL ||
selectStmt->limitOffset != NULL ||
selectStmt->limitCount != NULL ||
selectStmt->lockingClause != NIL ||
selectStmt->withClause != NULL))
ereport(ERROR,
(errcode(ERRCODE_SYNTAX_ERROR),
errmsg("SELECT not allowed in MERGE INSERT statement")));
if (selectStmt && list_length(selectStmt->valuesLists) > 1)
ereport(ERROR,
(errcode(ERRCODE_SYNTAX_ERROR),
errmsg("Multiple VALUES clauses not allowed in MERGE INSERT statement")));
targetPerms |= ACL_INSERT;
}
break; break;
case CMD_UPDATE: case CMD_UPDATE:
targetPerms |= ACL_UPDATE; targetPerms |= ACL_UPDATE;
...@@ -275,7 +251,7 @@ transformMergeStmt(ParseState *pstate, MergeStmt *stmt) ...@@ -275,7 +251,7 @@ transformMergeStmt(ParseState *pstate, MergeStmt *stmt)
/* /*
* Check for unreachable WHEN clauses * Check for unreachable WHEN clauses
*/ */
if (action->condition == NULL) if (mergeWhenClause->condition == NULL)
is_terminal[when_type] = true; is_terminal[when_type] = true;
else if (is_terminal[when_type]) else if (is_terminal[when_type])
ereport(ERROR, ereport(ERROR,
...@@ -461,15 +437,20 @@ transformMergeStmt(ParseState *pstate, MergeStmt *stmt) ...@@ -461,15 +437,20 @@ transformMergeStmt(ParseState *pstate, MergeStmt *stmt)
* both of those already have RTEs. There is nothing like the EXCLUDED * both of those already have RTEs. There is nothing like the EXCLUDED
* pseudo-relation for INSERT ON CONFLICT. * pseudo-relation for INSERT ON CONFLICT.
*/ */
foreach(l, stmt->mergeActionList) mergeActionList = NIL;
foreach(l, stmt->mergeWhenClauses)
{ {
MergeAction *action = (MergeAction *) lfirst(l); MergeWhenClause *mergeWhenClause = (MergeWhenClause *) lfirst(l);
MergeAction *action = makeNode(MergeAction);
action->commandType = mergeWhenClause->commandType;
action->matched = mergeWhenClause->matched;
/* /*
* Set namespace for the specific action. This must be done before * Set namespace for the specific action. This must be done before
* analyzing the WHEN quals and the action targetlisst. * analyzing the WHEN quals and the action targetlisst.
*/ */
setNamespaceForMergeAction(pstate, action); setNamespaceForMergeWhen(pstate, mergeWhenClause);
/* /*
* Transform the when condition. * Transform the when condition.
...@@ -478,7 +459,7 @@ transformMergeStmt(ParseState *pstate, MergeStmt *stmt) ...@@ -478,7 +459,7 @@ transformMergeStmt(ParseState *pstate, MergeStmt *stmt)
* are evaluated separately during execution to decide which of the * are evaluated separately during execution to decide which of the
* WHEN MATCHED or WHEN NOT MATCHED actions to execute. * WHEN MATCHED or WHEN NOT MATCHED actions to execute.
*/ */
action->qual = transformWhereClause(pstate, action->condition, action->qual = transformWhereClause(pstate, mergeWhenClause->condition,
EXPR_KIND_MERGE_WHEN_AND, "WHEN"); EXPR_KIND_MERGE_WHEN_AND, "WHEN");
/* /*
...@@ -488,8 +469,6 @@ transformMergeStmt(ParseState *pstate, MergeStmt *stmt) ...@@ -488,8 +469,6 @@ transformMergeStmt(ParseState *pstate, MergeStmt *stmt)
{ {
case CMD_INSERT: case CMD_INSERT:
{ {
InsertStmt *istmt = (InsertStmt *) action->stmt;
SelectStmt *selectStmt = (SelectStmt *) istmt->selectStmt;
List *exprList = NIL; List *exprList = NIL;
ListCell *lc; ListCell *lc;
RangeTblEntry *rte; RangeTblEntry *rte;
...@@ -500,13 +479,17 @@ transformMergeStmt(ParseState *pstate, MergeStmt *stmt) ...@@ -500,13 +479,17 @@ transformMergeStmt(ParseState *pstate, MergeStmt *stmt)
pstate->p_is_insert = true; pstate->p_is_insert = true;
icolumns = checkInsertTargets(pstate, istmt->cols, &attrnos); icolumns = checkInsertTargets(pstate,
mergeWhenClause->cols,
&attrnos);
Assert(list_length(icolumns) == list_length(attrnos)); Assert(list_length(icolumns) == list_length(attrnos));
action->override = mergeWhenClause->override;
/* /*
* Handle INSERT much like in transformInsertStmt * Handle INSERT much like in transformInsertStmt
*/ */
if (selectStmt == NULL) if (mergeWhenClause->values == NIL)
{ {
/* /*
* We have INSERT ... DEFAULT VALUES. We can handle * We have INSERT ... DEFAULT VALUES. We can handle
...@@ -525,23 +508,19 @@ transformMergeStmt(ParseState *pstate, MergeStmt *stmt) ...@@ -525,23 +508,19 @@ transformMergeStmt(ParseState *pstate, MergeStmt *stmt)
* as the Query's targetlist, with no VALUES RTE. So * as the Query's targetlist, with no VALUES RTE. So
* it works just like a SELECT without any FROM. * it works just like a SELECT without any FROM.
*/ */
List *valuesLists = selectStmt->valuesLists;
Assert(list_length(valuesLists) == 1);
Assert(selectStmt->intoClause == NULL);
/* /*
* Do basic expression transformation (same as a ROW() * Do basic expression transformation (same as a ROW()
* expr, but allow SetToDefault at top level) * expr, but allow SetToDefault at top level)
*/ */
exprList = transformExpressionList(pstate, exprList = transformExpressionList(pstate,
(List *) linitial(valuesLists), mergeWhenClause->values,
EXPR_KIND_VALUES_SINGLE, EXPR_KIND_VALUES_SINGLE,
true); true);
/* Prepare row for assignment to target table */ /* Prepare row for assignment to target table */
exprList = transformInsertRow(pstate, exprList, exprList = transformInsertRow(pstate, exprList,
istmt->cols, mergeWhenClause->cols,
icolumns, attrnos, icolumns, attrnos,
false); false);
} }
...@@ -580,10 +559,9 @@ transformMergeStmt(ParseState *pstate, MergeStmt *stmt) ...@@ -580,10 +559,9 @@ transformMergeStmt(ParseState *pstate, MergeStmt *stmt)
break; break;
case CMD_UPDATE: case CMD_UPDATE:
{ {
UpdateStmt *ustmt = (UpdateStmt *) action->stmt;
pstate->p_is_insert = false; pstate->p_is_insert = false;
action->targetList = transformUpdateTargetList(pstate, ustmt->targetList); action->targetList = transformUpdateTargetList(pstate,
mergeWhenClause->targetList);
} }
break; break;
case CMD_DELETE: case CMD_DELETE:
...@@ -595,9 +573,11 @@ transformMergeStmt(ParseState *pstate, MergeStmt *stmt) ...@@ -595,9 +573,11 @@ transformMergeStmt(ParseState *pstate, MergeStmt *stmt)
default: default:
elog(ERROR, "unknown action in MERGE WHEN clause"); elog(ERROR, "unknown action in MERGE WHEN clause");
} }
mergeActionList = lappend(mergeActionList, action);
} }
qry->mergeActionList = stmt->mergeActionList; qry->mergeActionList = mergeActionList;
/* XXX maybe later */ /* XXX maybe later */
qry->returningList = NULL; qry->returningList = NULL;
......
...@@ -3417,12 +3417,10 @@ RewriteQuery(Query *parsetree, List *rewrite_events) ...@@ -3417,12 +3417,10 @@ RewriteQuery(Query *parsetree, List *rewrite_events)
break; break;
case CMD_INSERT: case CMD_INSERT:
{ {
InsertStmt *istmt = (InsertStmt *) action->stmt;
action->targetList = action->targetList =
rewriteTargetListIU(action->targetList, rewriteTargetListIU(action->targetList,
action->commandType, action->commandType,
istmt->override, action->override,
rt_entry_relation, rt_entry_relation,
parsetree->resultRelation, parsetree->resultRelation,
NULL); NULL);
......
...@@ -269,6 +269,7 @@ typedef enum NodeTag ...@@ -269,6 +269,7 @@ typedef enum NodeTag
T_RollupData, T_RollupData,
T_GroupingSetData, T_GroupingSetData,
T_StatisticExtInfo, T_StatisticExtInfo,
T_MergeAction,
/* /*
* TAGS FOR MEMORY NODES (memnodes.h) * TAGS FOR MEMORY NODES (memnodes.h)
...@@ -310,7 +311,6 @@ typedef enum NodeTag ...@@ -310,7 +311,6 @@ typedef enum NodeTag
T_DeleteStmt, T_DeleteStmt,
T_UpdateStmt, T_UpdateStmt,
T_MergeStmt, T_MergeStmt,
T_MergeAction,
T_SelectStmt, T_SelectStmt,
T_AlterTableStmt, T_AlterTableStmt,
T_AlterTableCmd, T_AlterTableCmd,
...@@ -475,6 +475,7 @@ typedef enum NodeTag ...@@ -475,6 +475,7 @@ typedef enum NodeTag
T_PartitionRangeDatum, T_PartitionRangeDatum,
T_PartitionCmd, T_PartitionCmd,
T_VacuumRelation, T_VacuumRelation,
T_MergeWhenClause,
/* /*
* TAGS FOR REPLICATION GRAMMAR PARSE NODES (replnodes.h) * TAGS FOR REPLICATION GRAMMAR PARSE NODES (replnodes.h)
......
...@@ -1518,19 +1518,34 @@ typedef struct MergeStmt ...@@ -1518,19 +1518,34 @@ typedef struct MergeStmt
RangeVar *relation; /* target relation to merge into */ RangeVar *relation; /* target relation to merge into */
Node *source_relation; /* source relation */ Node *source_relation; /* source relation */
Node *join_condition; /* join condition between source and target */ Node *join_condition; /* join condition between source and target */
List *mergeActionList; /* list of MergeAction(s) */ List *mergeWhenClauses; /* list of MergeWhenClause(es) */
WithClause *withClause; /* WITH clause */ WithClause *withClause; /* WITH clause */
} MergeStmt; } MergeStmt;
typedef struct MergeAction typedef struct MergeWhenClause
{ {
NodeTag type; NodeTag type;
bool matched; /* true=MATCHED, false=NOT MATCHED */ bool matched; /* true=MATCHED, false=NOT MATCHED */
Node *condition; /* WHEN AND conditions (raw parser) */
Node *qual; /* transformed WHEN AND conditions */
CmdType commandType; /* INSERT/UPDATE/DELETE/DO NOTHING */ CmdType commandType; /* INSERT/UPDATE/DELETE/DO NOTHING */
Node *stmt; /* T_UpdateStmt etc */ Node *condition; /* WHEN AND conditions (raw parser) */
List *targetList; /* the target list (of ResTarget) */ List *targetList; /* INSERT/UPDATE targetlist */
/* the following members are only useful for INSERT action */
List *cols; /* optional: names of the target columns */
List *values; /* VALUES to INSERT, or NULL */
OverridingKind override; /* OVERRIDING clause */
} MergeWhenClause;
/*
* WHEN [NOT] MATCHED THEN action info
*/
typedef struct MergeAction
{
NodeTag type;
bool matched; /* true=MATCHED, false=NOT MATCHED */
OverridingKind override; /* OVERRIDING clause */
Node *qual; /* transformed WHEN AND conditions */
CmdType commandType; /* INSERT/UPDATE/DELETE/DO NOTHING */
List *targetList; /* the target list (of ResTarget) */
} MergeAction; } MergeAction;
/* ---------------------- /* ----------------------
......
...@@ -90,7 +90,9 @@ USING source AS s ...@@ -90,7 +90,9 @@ USING source AS s
ON t.tid = s.sid ON t.tid = s.sid
WHEN NOT MATCHED THEN WHEN NOT MATCHED THEN
INSERT VALUES (1,1), (2,2); INSERT VALUES (1,1), (2,2);
ERROR: Multiple VALUES clauses not allowed in MERGE INSERT statement ERROR: syntax error at or near ","
LINE 5: INSERT VALUES (1,1), (2,2);
^
; ;
-- SELECT query for INSERT -- SELECT query for INSERT
MERGE INTO target t MERGE INTO target t
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment