Commit c533c147 authored by Robert Haas's avatar Robert Haas

Add a missing_ok argument to get_object_address().

This lays the groundwork for an upcoming patch to streamline the
handling of DROP commands.

KaiGai Kohei
parent e1cd66f7
...@@ -72,15 +72,19 @@ ...@@ -72,15 +72,19 @@
#include "utils/tqual.h" #include "utils/tqual.h"
static ObjectAddress get_object_address_unqualified(ObjectType objtype, static ObjectAddress get_object_address_unqualified(ObjectType objtype,
List *qualname); List *qualname, bool missing_ok);
static Relation get_relation_by_qualified_name(ObjectType objtype, static ObjectAddress get_relation_by_qualified_name(ObjectType objtype,
List *objname, LOCKMODE lockmode); List *objname, Relation *relp,
LOCKMODE lockmode, bool missing_ok);
static ObjectAddress get_object_address_relobject(ObjectType objtype, static ObjectAddress get_object_address_relobject(ObjectType objtype,
List *objname, Relation *relp); List *objname, Relation *relp, bool missing_ok);
static ObjectAddress get_object_address_attribute(ObjectType objtype, static ObjectAddress get_object_address_attribute(ObjectType objtype,
List *objname, Relation *relp, LOCKMODE lockmode); List *objname, Relation *relp,
LOCKMODE lockmode, bool missing_ok);
static ObjectAddress get_object_address_type(ObjectType objtype,
List *objname, bool missing_ok);
static ObjectAddress get_object_address_opcf(ObjectType objtype, List *objname, static ObjectAddress get_object_address_opcf(ObjectType objtype, List *objname,
List *objargs); List *objargs, bool missing_ok);
static bool object_exists(ObjectAddress address); static bool object_exists(ObjectAddress address);
...@@ -106,7 +110,7 @@ static bool object_exists(ObjectAddress address); ...@@ -106,7 +110,7 @@ static bool object_exists(ObjectAddress address);
*/ */
ObjectAddress ObjectAddress
get_object_address(ObjectType objtype, List *objname, List *objargs, get_object_address(ObjectType objtype, List *objname, List *objargs,
Relation *relp, LOCKMODE lockmode) Relation *relp, LOCKMODE lockmode, bool missing_ok)
{ {
ObjectAddress address; ObjectAddress address;
Relation relation = NULL; Relation relation = NULL;
...@@ -121,21 +125,22 @@ get_object_address(ObjectType objtype, List *objname, List *objargs, ...@@ -121,21 +125,22 @@ get_object_address(ObjectType objtype, List *objname, List *objargs,
case OBJECT_TABLE: case OBJECT_TABLE:
case OBJECT_VIEW: case OBJECT_VIEW:
case OBJECT_FOREIGN_TABLE: case OBJECT_FOREIGN_TABLE:
relation = address =
get_relation_by_qualified_name(objtype, objname, lockmode); get_relation_by_qualified_name(objtype, objname,
address.classId = RelationRelationId; &relation, lockmode,
address.objectId = RelationGetRelid(relation); missing_ok);
address.objectSubId = 0;
break; break;
case OBJECT_COLUMN: case OBJECT_COLUMN:
address = address =
get_object_address_attribute(objtype, objname, &relation, get_object_address_attribute(objtype, objname,
lockmode); &relation, lockmode,
missing_ok);
break; break;
case OBJECT_RULE: case OBJECT_RULE:
case OBJECT_TRIGGER: case OBJECT_TRIGGER:
case OBJECT_CONSTRAINT: case OBJECT_CONSTRAINT:
address = get_object_address_relobject(objtype, objname, &relation); address = get_object_address_relobject(objtype, objname,
&relation, missing_ok);
break; break;
case OBJECT_DATABASE: case OBJECT_DATABASE:
case OBJECT_EXTENSION: case OBJECT_EXTENSION:
...@@ -145,23 +150,23 @@ get_object_address(ObjectType objtype, List *objname, List *objargs, ...@@ -145,23 +150,23 @@ get_object_address(ObjectType objtype, List *objname, List *objargs,
case OBJECT_LANGUAGE: case OBJECT_LANGUAGE:
case OBJECT_FDW: case OBJECT_FDW:
case OBJECT_FOREIGN_SERVER: case OBJECT_FOREIGN_SERVER:
address = get_object_address_unqualified(objtype, objname); address = get_object_address_unqualified(objtype,
objname, missing_ok);
break; break;
case OBJECT_TYPE: case OBJECT_TYPE:
case OBJECT_DOMAIN: case OBJECT_DOMAIN:
address.classId = TypeRelationId; address = get_object_address_type(objtype, objname, missing_ok);
address.objectId =
typenameTypeId(NULL, makeTypeNameFromNameList(objname));
address.objectSubId = 0;
break; break;
case OBJECT_AGGREGATE: case OBJECT_AGGREGATE:
address.classId = ProcedureRelationId; address.classId = ProcedureRelationId;
address.objectId = LookupAggNameTypeNames(objname, objargs, false); address.objectId =
LookupAggNameTypeNames(objname, objargs, missing_ok);
address.objectSubId = 0; address.objectSubId = 0;
break; break;
case OBJECT_FUNCTION: case OBJECT_FUNCTION:
address.classId = ProcedureRelationId; address.classId = ProcedureRelationId;
address.objectId = LookupFuncNameTypeNames(objname, objargs, false); address.objectId =
LookupFuncNameTypeNames(objname, objargs, missing_ok);
address.objectSubId = 0; address.objectSubId = 0;
break; break;
case OBJECT_OPERATOR: case OBJECT_OPERATOR:
...@@ -171,22 +176,23 @@ get_object_address(ObjectType objtype, List *objname, List *objargs, ...@@ -171,22 +176,23 @@ get_object_address(ObjectType objtype, List *objname, List *objargs,
LookupOperNameTypeNames(NULL, objname, LookupOperNameTypeNames(NULL, objname,
(TypeName *) linitial(objargs), (TypeName *) linitial(objargs),
(TypeName *) lsecond(objargs), (TypeName *) lsecond(objargs),
false, -1); missing_ok, -1);
address.objectSubId = 0; address.objectSubId = 0;
break; break;
case OBJECT_COLLATION: case OBJECT_COLLATION:
address.classId = CollationRelationId; address.classId = CollationRelationId;
address.objectId = get_collation_oid(objname, false); address.objectId = get_collation_oid(objname, missing_ok);
address.objectSubId = 0; address.objectSubId = 0;
break; break;
case OBJECT_CONVERSION: case OBJECT_CONVERSION:
address.classId = ConversionRelationId; address.classId = ConversionRelationId;
address.objectId = get_conversion_oid(objname, false); address.objectId = get_conversion_oid(objname, missing_ok);
address.objectSubId = 0; address.objectSubId = 0;
break; break;
case OBJECT_OPCLASS: case OBJECT_OPCLASS:
case OBJECT_OPFAMILY: case OBJECT_OPFAMILY:
address = get_object_address_opcf(objtype, objname, objargs); address = get_object_address_opcf(objtype,
objname, objargs, missing_ok);
break; break;
case OBJECT_LARGEOBJECT: case OBJECT_LARGEOBJECT:
Assert(list_length(objname) == 1); Assert(list_length(objname) == 1);
...@@ -194,10 +200,13 @@ get_object_address(ObjectType objtype, List *objname, List *objargs, ...@@ -194,10 +200,13 @@ get_object_address(ObjectType objtype, List *objname, List *objargs,
address.objectId = oidparse(linitial(objname)); address.objectId = oidparse(linitial(objname));
address.objectSubId = 0; address.objectSubId = 0;
if (!LargeObjectExists(address.objectId)) if (!LargeObjectExists(address.objectId))
{
if (!missing_ok)
ereport(ERROR, ereport(ERROR,
(errcode(ERRCODE_UNDEFINED_OBJECT), (errcode(ERRCODE_UNDEFINED_OBJECT),
errmsg("large object %u does not exist", errmsg("large object %u does not exist",
address.objectId))); address.objectId)));
}
break; break;
case OBJECT_CAST: case OBJECT_CAST:
{ {
...@@ -208,28 +217,28 @@ get_object_address(ObjectType objtype, List *objname, List *objargs, ...@@ -208,28 +217,28 @@ get_object_address(ObjectType objtype, List *objname, List *objargs,
address.classId = CastRelationId; address.classId = CastRelationId;
address.objectId = address.objectId =
get_cast_oid(sourcetypeid, targettypeid, false); get_cast_oid(sourcetypeid, targettypeid, missing_ok);
address.objectSubId = 0; address.objectSubId = 0;
} }
break; break;
case OBJECT_TSPARSER: case OBJECT_TSPARSER:
address.classId = TSParserRelationId; address.classId = TSParserRelationId;
address.objectId = get_ts_parser_oid(objname, false); address.objectId = get_ts_parser_oid(objname, missing_ok);
address.objectSubId = 0; address.objectSubId = 0;
break; break;
case OBJECT_TSDICTIONARY: case OBJECT_TSDICTIONARY:
address.classId = TSDictionaryRelationId; address.classId = TSDictionaryRelationId;
address.objectId = get_ts_dict_oid(objname, false); address.objectId = get_ts_dict_oid(objname, missing_ok);
address.objectSubId = 0; address.objectSubId = 0;
break; break;
case OBJECT_TSTEMPLATE: case OBJECT_TSTEMPLATE:
address.classId = TSTemplateRelationId; address.classId = TSTemplateRelationId;
address.objectId = get_ts_template_oid(objname, false); address.objectId = get_ts_template_oid(objname, missing_ok);
address.objectSubId = 0; address.objectSubId = 0;
break; break;
case OBJECT_TSCONFIGURATION: case OBJECT_TSCONFIGURATION:
address.classId = TSConfigRelationId; address.classId = TSConfigRelationId;
address.objectId = get_ts_config_oid(objname, false); address.objectId = get_ts_config_oid(objname, missing_ok);
address.objectSubId = 0; address.objectSubId = 0;
break; break;
default: default:
...@@ -240,6 +249,15 @@ get_object_address(ObjectType objtype, List *objname, List *objargs, ...@@ -240,6 +249,15 @@ get_object_address(ObjectType objtype, List *objname, List *objargs,
address.objectSubId = 0; address.objectSubId = 0;
} }
/*
* If we could not find the supplied object, return without locking.
*/
if (!OidIsValid(address.objectId))
{
Assert(missing_ok);
return address;
}
/* /*
* If we're dealing with a relation or attribute, then the relation is * If we're dealing with a relation or attribute, then the relation is
* already locked. If we're dealing with any other type of object, we * already locked. If we're dealing with any other type of object, we
...@@ -267,7 +285,8 @@ get_object_address(ObjectType objtype, List *objname, List *objargs, ...@@ -267,7 +285,8 @@ get_object_address(ObjectType objtype, List *objname, List *objargs,
* unqualified name. * unqualified name.
*/ */
static ObjectAddress static ObjectAddress
get_object_address_unqualified(ObjectType objtype, List *qualname) get_object_address_unqualified(ObjectType objtype,
List *qualname, bool missing_ok)
{ {
const char *name; const char *name;
ObjectAddress address; ObjectAddress address;
...@@ -323,42 +342,42 @@ get_object_address_unqualified(ObjectType objtype, List *qualname) ...@@ -323,42 +342,42 @@ get_object_address_unqualified(ObjectType objtype, List *qualname)
{ {
case OBJECT_DATABASE: case OBJECT_DATABASE:
address.classId = DatabaseRelationId; address.classId = DatabaseRelationId;
address.objectId = get_database_oid(name, false); address.objectId = get_database_oid(name, missing_ok);
address.objectSubId = 0; address.objectSubId = 0;
break; break;
case OBJECT_EXTENSION: case OBJECT_EXTENSION:
address.classId = ExtensionRelationId; address.classId = ExtensionRelationId;
address.objectId = get_extension_oid(name, false); address.objectId = get_extension_oid(name, missing_ok);
address.objectSubId = 0; address.objectSubId = 0;
break; break;
case OBJECT_TABLESPACE: case OBJECT_TABLESPACE:
address.classId = TableSpaceRelationId; address.classId = TableSpaceRelationId;
address.objectId = get_tablespace_oid(name, false); address.objectId = get_tablespace_oid(name, missing_ok);
address.objectSubId = 0; address.objectSubId = 0;
break; break;
case OBJECT_ROLE: case OBJECT_ROLE:
address.classId = AuthIdRelationId; address.classId = AuthIdRelationId;
address.objectId = get_role_oid(name, false); address.objectId = get_role_oid(name, missing_ok);
address.objectSubId = 0; address.objectSubId = 0;
break; break;
case OBJECT_SCHEMA: case OBJECT_SCHEMA:
address.classId = NamespaceRelationId; address.classId = NamespaceRelationId;
address.objectId = get_namespace_oid(name, false); address.objectId = get_namespace_oid(name, missing_ok);
address.objectSubId = 0; address.objectSubId = 0;
break; break;
case OBJECT_LANGUAGE: case OBJECT_LANGUAGE:
address.classId = LanguageRelationId; address.classId = LanguageRelationId;
address.objectId = get_language_oid(name, false); address.objectId = get_language_oid(name, missing_ok);
address.objectSubId = 0; address.objectSubId = 0;
break; break;
case OBJECT_FDW: case OBJECT_FDW:
address.classId = ForeignDataWrapperRelationId; address.classId = ForeignDataWrapperRelationId;
address.objectId = get_foreign_data_wrapper_oid(name, false); address.objectId = get_foreign_data_wrapper_oid(name, missing_ok);
address.objectSubId = 0; address.objectSubId = 0;
break; break;
case OBJECT_FOREIGN_SERVER: case OBJECT_FOREIGN_SERVER:
address.classId = ForeignServerRelationId; address.classId = ForeignServerRelationId;
address.objectId = get_foreign_server_oid(name, false); address.objectId = get_foreign_server_oid(name, missing_ok);
address.objectSubId = 0; address.objectSubId = 0;
break; break;
default: default:
...@@ -375,13 +394,23 @@ get_object_address_unqualified(ObjectType objtype, List *qualname) ...@@ -375,13 +394,23 @@ get_object_address_unqualified(ObjectType objtype, List *qualname)
/* /*
* Locate a relation by qualified name. * Locate a relation by qualified name.
*/ */
static Relation static ObjectAddress
get_relation_by_qualified_name(ObjectType objtype, List *objname, get_relation_by_qualified_name(ObjectType objtype, List *objname,
LOCKMODE lockmode) Relation *relp, LOCKMODE lockmode,
bool missing_ok)
{ {
Relation relation; Relation relation;
ObjectAddress address;
address.classId = RelationRelationId;
address.objectId = InvalidOid;
address.objectSubId = 0;
relation = relation_openrv_extended(makeRangeVarFromNameList(objname),
lockmode, missing_ok);
if (!relation)
return address;
relation = relation_openrv(makeRangeVarFromNameList(objname), lockmode);
switch (objtype) switch (objtype)
{ {
case OBJECT_INDEX: case OBJECT_INDEX:
...@@ -424,7 +453,11 @@ get_relation_by_qualified_name(ObjectType objtype, List *objname, ...@@ -424,7 +453,11 @@ get_relation_by_qualified_name(ObjectType objtype, List *objname,
break; break;
} }
return relation; /* Done */
address.objectId = RelationGetRelid(relation);
*relp = relation;
return address;
} }
/* /*
...@@ -435,7 +468,8 @@ get_relation_by_qualified_name(ObjectType objtype, List *objname, ...@@ -435,7 +468,8 @@ get_relation_by_qualified_name(ObjectType objtype, List *objname,
* mode for the object itself, not the relation to which it is attached. * mode for the object itself, not the relation to which it is attached.
*/ */
static ObjectAddress static ObjectAddress
get_object_address_relobject(ObjectType objtype, List *objname, Relation *relp) get_object_address_relobject(ObjectType objtype, List *objname,
Relation *relp, bool missing_ok)
{ {
ObjectAddress address; ObjectAddress address;
Relation relation = NULL; Relation relation = NULL;
...@@ -461,9 +495,9 @@ get_object_address_relobject(ObjectType objtype, List *objname, Relation *relp) ...@@ -461,9 +495,9 @@ get_object_address_relobject(ObjectType objtype, List *objname, Relation *relp)
if (objtype != OBJECT_RULE) if (objtype != OBJECT_RULE)
elog(ERROR, "must specify relation and object name"); elog(ERROR, "must specify relation and object name");
address.classId = RewriteRelationId; address.classId = RewriteRelationId;
address.objectId = get_rewrite_oid_without_relid(depname, &reloid); address.objectId =
get_rewrite_oid_without_relid(depname, &reloid, missing_ok);
address.objectSubId = 0; address.objectSubId = 0;
relation = heap_open(reloid, AccessShareLock);
} }
else else
{ {
...@@ -480,17 +514,18 @@ get_object_address_relobject(ObjectType objtype, List *objname, Relation *relp) ...@@ -480,17 +514,18 @@ get_object_address_relobject(ObjectType objtype, List *objname, Relation *relp)
{ {
case OBJECT_RULE: case OBJECT_RULE:
address.classId = RewriteRelationId; address.classId = RewriteRelationId;
address.objectId = get_rewrite_oid(reloid, depname, false); address.objectId = get_rewrite_oid(reloid, depname, missing_ok);
address.objectSubId = 0; address.objectSubId = 0;
break; break;
case OBJECT_TRIGGER: case OBJECT_TRIGGER:
address.classId = TriggerRelationId; address.classId = TriggerRelationId;
address.objectId = get_trigger_oid(reloid, depname, false); address.objectId = get_trigger_oid(reloid, depname, missing_ok);
address.objectSubId = 0; address.objectSubId = 0;
break; break;
case OBJECT_CONSTRAINT: case OBJECT_CONSTRAINT:
address.classId = ConstraintRelationId; address.classId = ConstraintRelationId;
address.objectId = get_constraint_oid(reloid, depname, false); address.objectId =
get_constraint_oid(reloid, depname, missing_ok);
address.objectSubId = 0; address.objectSubId = 0;
break; break;
default: default:
...@@ -512,13 +547,15 @@ get_object_address_relobject(ObjectType objtype, List *objname, Relation *relp) ...@@ -512,13 +547,15 @@ get_object_address_relobject(ObjectType objtype, List *objname, Relation *relp)
*/ */
static ObjectAddress static ObjectAddress
get_object_address_attribute(ObjectType objtype, List *objname, get_object_address_attribute(ObjectType objtype, List *objname,
Relation *relp, LOCKMODE lockmode) Relation *relp, LOCKMODE lockmode,
bool missing_ok)
{ {
ObjectAddress address; ObjectAddress address;
List *relname; List *relname;
Oid reloid; Oid reloid;
Relation relation; Relation relation;
const char *attname; const char *attname;
AttrNumber attnum;
/* Extract relation name and open relation. */ /* Extract relation name and open relation. */
attname = strVal(lfirst(list_tail(objname))); attname = strVal(lfirst(list_tail(objname)));
...@@ -527,24 +564,77 @@ get_object_address_attribute(ObjectType objtype, List *objname, ...@@ -527,24 +564,77 @@ get_object_address_attribute(ObjectType objtype, List *objname,
reloid = RelationGetRelid(relation); reloid = RelationGetRelid(relation);
/* Look up attribute and construct return value. */ /* Look up attribute and construct return value. */
attnum = get_attnum(reloid, attname);
if (attnum == InvalidAttrNumber)
{
if (!missing_ok)
ereport(ERROR,
(errcode(ERRCODE_UNDEFINED_COLUMN),
errmsg("column \"%s\" of relation \"%s\" does not exist",
attname, NameListToString(relname))));
address.classId = RelationRelationId;
address.objectId = InvalidOid;
address.objectSubId = InvalidAttrNumber;
return address;
}
address.classId = RelationRelationId; address.classId = RelationRelationId;
address.objectId = reloid; address.objectId = reloid;
address.objectSubId = get_attnum(reloid, attname); address.objectSubId = attnum;
if (address.objectSubId == InvalidAttrNumber)
ereport(ERROR,
(errcode(ERRCODE_UNDEFINED_COLUMN),
errmsg("column \"%s\" of relation \"%s\" does not exist",
attname, RelationGetRelationName(relation))));
*relp = relation; *relp = relation;
return address; return address;
} }
/*
* Find the ObjectAddress for a type or domain
*/
static ObjectAddress
get_object_address_type(ObjectType objtype,
List *objname, bool missing_ok)
{
ObjectAddress address;
TypeName *typename;
Type tup;
typename = makeTypeNameFromNameList(objname);
address.classId = TypeRelationId;
address.objectId = InvalidOid;
address.objectSubId = 0;
tup = LookupTypeName(NULL, typename, NULL);
if (!HeapTupleIsValid(tup))
{
if (!missing_ok)
ereport(ERROR,
(errcode(ERRCODE_UNDEFINED_OBJECT),
errmsg("type \"%s\" does not exist",
TypeNameToString(typename))));
return address;
}
address.objectId = typeTypeId(tup);
if (objtype == OBJECT_DOMAIN)
{
if (((Form_pg_type) GETSTRUCT(tup))->typtype != TYPTYPE_DOMAIN)
ereport(ERROR,
(errcode(ERRCODE_WRONG_OBJECT_TYPE),
errmsg("\"%s\" is not a domain",
TypeNameToString(typename))));
}
ReleaseSysCache(tup);
return address;
}
/* /*
* Find the ObjectAddress for an opclass or opfamily. * Find the ObjectAddress for an opclass or opfamily.
*/ */
static ObjectAddress static ObjectAddress
get_object_address_opcf(ObjectType objtype, List *objname, List *objargs) get_object_address_opcf(ObjectType objtype,
List *objname, List *objargs, bool missing_ok)
{ {
Oid amoid; Oid amoid;
ObjectAddress address; ObjectAddress address;
...@@ -556,12 +646,12 @@ get_object_address_opcf(ObjectType objtype, List *objname, List *objargs) ...@@ -556,12 +646,12 @@ get_object_address_opcf(ObjectType objtype, List *objname, List *objargs)
{ {
case OBJECT_OPCLASS: case OBJECT_OPCLASS:
address.classId = OperatorClassRelationId; address.classId = OperatorClassRelationId;
address.objectId = get_opclass_oid(amoid, objname, false); address.objectId = get_opclass_oid(amoid, objname, missing_ok);
address.objectSubId = 0; address.objectSubId = 0;
break; break;
case OBJECT_OPFAMILY: case OBJECT_OPFAMILY:
address.classId = OperatorFamilyRelationId; address.classId = OperatorFamilyRelationId;
address.objectId = get_opfamily_oid(amoid, objname, false); address.objectId = get_opfamily_oid(amoid, objname, missing_ok);
address.objectSubId = 0; address.objectSubId = 0;
break; break;
default: default:
......
...@@ -69,7 +69,7 @@ CommentObject(CommentStmt *stmt) ...@@ -69,7 +69,7 @@ CommentObject(CommentStmt *stmt)
* against concurrent DROP operations. * against concurrent DROP operations.
*/ */
address = get_object_address(stmt->objtype, stmt->objname, stmt->objargs, address = get_object_address(stmt->objtype, stmt->objname, stmt->objargs,
&relation, ShareUpdateExclusiveLock); &relation, ShareUpdateExclusiveLock, false);
/* Require ownership of the target object. */ /* Require ownership of the target object. */
check_object_ownership(GetUserId(), stmt->objtype, address, check_object_ownership(GetUserId(), stmt->objtype, address,
......
...@@ -2703,7 +2703,7 @@ ExecAlterExtensionContentsStmt(AlterExtensionContentsStmt *stmt) ...@@ -2703,7 +2703,7 @@ ExecAlterExtensionContentsStmt(AlterExtensionContentsStmt *stmt)
* against concurrent DROP and ALTER EXTENSION ADD/DROP operations. * against concurrent DROP and ALTER EXTENSION ADD/DROP operations.
*/ */
object = get_object_address(stmt->objtype, stmt->objname, stmt->objargs, object = get_object_address(stmt->objtype, stmt->objname, stmt->objargs,
&relation, ShareUpdateExclusiveLock); &relation, ShareUpdateExclusiveLock, false);
/* Permission check: must own target object, too */ /* Permission check: must own target object, too */
check_object_ownership(GetUserId(), stmt->objtype, object, check_object_ownership(GetUserId(), stmt->objtype, object,
......
...@@ -88,7 +88,7 @@ ExecSecLabelStmt(SecLabelStmt *stmt) ...@@ -88,7 +88,7 @@ ExecSecLabelStmt(SecLabelStmt *stmt)
* guard against concurrent modifications. * guard against concurrent modifications.
*/ */
address = get_object_address(stmt->objtype, stmt->objname, stmt->objargs, address = get_object_address(stmt->objtype, stmt->objname, stmt->objargs,
&relation, ShareUpdateExclusiveLock); &relation, ShareUpdateExclusiveLock, false);
/* Require ownership of the target object. */ /* Require ownership of the target object. */
check_object_ownership(GetUserId(), stmt->objtype, address, check_object_ownership(GetUserId(), stmt->objtype, address,
......
...@@ -132,7 +132,8 @@ get_rewrite_oid(Oid relid, const char *rulename, bool missing_ok) ...@@ -132,7 +132,8 @@ get_rewrite_oid(Oid relid, const char *rulename, bool missing_ok)
* were unique across the entire database. * were unique across the entire database.
*/ */
Oid Oid
get_rewrite_oid_without_relid(const char *rulename, Oid *reloid) get_rewrite_oid_without_relid(const char *rulename,
Oid *reloid, bool missing_ok)
{ {
Relation RewriteRelation; Relation RewriteRelation;
HeapScanDesc scanDesc; HeapScanDesc scanDesc;
...@@ -151,20 +152,26 @@ get_rewrite_oid_without_relid(const char *rulename, Oid *reloid) ...@@ -151,20 +152,26 @@ get_rewrite_oid_without_relid(const char *rulename, Oid *reloid)
htup = heap_getnext(scanDesc, ForwardScanDirection); htup = heap_getnext(scanDesc, ForwardScanDirection);
if (!HeapTupleIsValid(htup)) if (!HeapTupleIsValid(htup))
ereport(ERROR, {
(errcode(ERRCODE_UNDEFINED_OBJECT), if (!missing_ok)
errmsg("rule \"%s\" does not exist", rulename))); ereport(ERROR,
(errcode(ERRCODE_UNDEFINED_OBJECT),
ruleoid = HeapTupleGetOid(htup); errmsg("rule \"%s\" does not exist", rulename)));
if (reloid != NULL) ruleoid = InvalidOid;
*reloid = ((Form_pg_rewrite) GETSTRUCT(htup))->ev_class; }
else
if (HeapTupleIsValid(htup = heap_getnext(scanDesc, ForwardScanDirection))) {
ereport(ERROR, ruleoid = HeapTupleGetOid(htup);
(errcode(ERRCODE_DUPLICATE_OBJECT), if (reloid != NULL)
errmsg("there are multiple rules named \"%s\"", rulename), *reloid = ((Form_pg_rewrite) GETSTRUCT(htup))->ev_class;
errhint("Specify a relation name as well as a rule name.")));
htup = heap_getnext(scanDesc, ForwardScanDirection);
if (HeapTupleIsValid(htup))
ereport(ERROR,
(errcode(ERRCODE_DUPLICATE_OBJECT),
errmsg("there are multiple rules named \"%s\"", rulename),
errhint("Specify a relation name as well as a rule name.")));
}
heap_endscan(scanDesc); heap_endscan(scanDesc);
heap_close(RewriteRelation, AccessShareLock); heap_close(RewriteRelation, AccessShareLock);
......
...@@ -28,7 +28,8 @@ typedef struct ObjectAddress ...@@ -28,7 +28,8 @@ typedef struct ObjectAddress
} ObjectAddress; } ObjectAddress;
extern ObjectAddress get_object_address(ObjectType objtype, List *objname, extern ObjectAddress get_object_address(ObjectType objtype, List *objname,
List *objargs, Relation *relp, LOCKMODE lockmode); List *objargs, Relation *relp,
LOCKMODE lockmode, bool missing_ok);
extern void check_object_ownership(Oid roleid, extern void check_object_ownership(Oid roleid,
ObjectType objtype, ObjectAddress address, ObjectType objtype, ObjectAddress address,
......
...@@ -23,6 +23,7 @@ extern void SetRelationRuleStatus(Oid relationId, bool relHasRules, ...@@ -23,6 +23,7 @@ extern void SetRelationRuleStatus(Oid relationId, bool relHasRules,
bool relIsBecomingView); bool relIsBecomingView);
extern Oid get_rewrite_oid(Oid relid, const char *rulename, bool missing_ok); extern Oid get_rewrite_oid(Oid relid, const char *rulename, bool missing_ok);
extern Oid get_rewrite_oid_without_relid(const char *rulename, Oid *relid); extern Oid get_rewrite_oid_without_relid(const char *rulename,
Oid *relid, bool missing_ok);
#endif /* REWRITESUPPORT_H */ #endif /* REWRITESUPPORT_H */
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment