/*
 * the plpy module
 *
 * src/pl/plpython/plpy_plpymodule.c
 */

#include "postgres.h"

#include "mb/pg_wchar.h"
#include "utils/builtins.h"

#include "plpython.h"

#include "plpy_plpymodule.h"

#include "plpy_cursorobject.h"
#include "plpy_elog.h"
#include "plpy_planobject.h"
#include "plpy_resultobject.h"
#include "plpy_spi.h"
#include "plpy_subxactobject.h"


/*
 * Hash table of the generated SPI exception classes.  The key is the integer
 * SQLSTATE code and the entries are PLyExceptionEntry structs.  The table is
 * created in PLy_add_exceptions and filled by PLy_generate_spi_exceptions.
 */
HTAB	   *PLy_spi_exceptions = NULL;


/* helpers that create the exception classes and attach them to the module */
static void PLy_add_exceptions(PyObject *plpy);
static void PLy_generate_spi_exceptions(PyObject *mod, PyObject *base);

/* module functions */
/* logging functions: accept positional and keyword arguments */
static PyObject *PLy_debug(PyObject *self, PyObject *args, PyObject *kw);
static PyObject *PLy_log(PyObject *self, PyObject *args, PyObject *kw);
static PyObject *PLy_info(PyObject *self, PyObject *args, PyObject *kw);
static PyObject *PLy_notice(PyObject *self, PyObject *args, PyObject *kw);
static PyObject *PLy_warning(PyObject *self, PyObject *args, PyObject *kw);
static PyObject *PLy_error(PyObject *self, PyObject *args, PyObject *kw);
static PyObject *PLy_fatal(PyObject *self, PyObject *args, PyObject *kw);
/* quoting functions: accept positional arguments only */
static PyObject *PLy_quote_literal(PyObject *self, PyObject *args);
static PyObject *PLy_quote_nullable(PyObject *self, PyObject *args);
static PyObject *PLy_quote_ident(PyObject *self, PyObject *args);


/*
 * One entry of the list of all known exceptions, generated from
 * backend/utils/errcodes.txt
 */
typedef struct ExceptionMap
{
	/* name handed to PyErr_NewException when the class is created */
	char	   *name;
	/* attribute name under which the class is added to the spiexceptions module */
	char	   *classname;
	/* SQLSTATE code; used as the key in the PLy_spi_exceptions hash table */
	int			sqlstate;
} ExceptionMap;

/*
 * The exception list itself.  The entries are pulled in from spiexceptions.h;
 * the trailing all-NULL entry is the sentinel that ends iteration (the loop
 * over this table stops at the first entry whose name is NULL).
 */
static const ExceptionMap exception_map[] = {
#include "spiexceptions.h"
	{NULL, NULL, 0}
};

/*
 * Method table of the plpy module: maps each Python-visible function name to
 * its C implementation and calling convention.  The all-NULL entry ends the
 * table.
 */
static PyMethodDef PLy_methods[] = {
	/*
	 * logging methods
	 *
	 * These take keyword arguments as well as positional ones, hence the
	 * cast to PyCFunction.
	 */
	{"debug", (PyCFunction) PLy_debug, METH_VARARGS|METH_KEYWORDS, NULL},
	{"log", (PyCFunction) PLy_log, METH_VARARGS|METH_KEYWORDS, NULL},
	{"info", (PyCFunction) PLy_info, METH_VARARGS|METH_KEYWORDS, NULL},
	{"notice", (PyCFunction) PLy_notice, METH_VARARGS|METH_KEYWORDS, NULL},
	{"warning", (PyCFunction) PLy_warning, METH_VARARGS|METH_KEYWORDS, NULL},
	{"error", (PyCFunction) PLy_error, METH_VARARGS|METH_KEYWORDS, NULL},
	{"fatal", (PyCFunction) PLy_fatal, METH_VARARGS|METH_KEYWORDS, NULL},

	/*
	 * create a stored plan
	 */
	{"prepare", PLy_spi_prepare, METH_VARARGS, NULL},

	/*
	 * execute a plan or query
	 */
	{"execute", PLy_spi_execute, METH_VARARGS, NULL},

	/*
	 * escaping strings
	 */
	{"quote_literal", PLy_quote_literal, METH_VARARGS, NULL},
	{"quote_nullable", PLy_quote_nullable, METH_VARARGS, NULL},
	{"quote_ident", PLy_quote_ident, METH_VARARGS, NULL},

	/*
	 * create the subtransaction context manager
	 */
	{"subtransaction", PLy_subtransaction_new, METH_NOARGS, NULL},

	/*
	 * create a cursor
	 */
	{"cursor", PLy_cursor, METH_VARARGS, NULL},

	{NULL, NULL, 0, NULL}
};

/*
 * Method table of the spiexceptions module.  It holds only the terminating
 * entry: that module carries exception classes, not functions.
 */
static PyMethodDef PLy_exc_methods[] = {
	{NULL, NULL, 0, NULL}
};

#if PY_MAJOR_VERSION >= 3
/*
 * Module definition for "plpy" (Python 3 only).  The trailing members are
 * written out explicitly, in the same layout as PLy_exc_module.
 */
static PyModuleDef PLy_module = {
	PyModuleDef_HEAD_INIT,		/* m_base */
	"plpy",						/* m_name */
	NULL,						/* m_doc */
	-1,							/* m_size */
	PLy_methods,				/* m_methods */
	NULL,						/* m_reload */
	NULL,						/* m_traverse */
	NULL,						/* m_clear */
	NULL						/* m_free */
};

/*
 * Module definition for "spiexceptions" (Python 3 only); it is instantiated
 * by PLy_add_exceptions.
 */
static PyModuleDef PLy_exc_module = {
	PyModuleDef_HEAD_INIT,		/* m_base */
	"spiexceptions",			/* m_name */
	NULL,						/* m_doc */
	-1,							/* m_size */
	PLy_exc_methods,			/* m_methods */
	NULL,						/* m_reload */
	NULL,						/* m_traverse */
	NULL,						/* m_clear */
	NULL						/* m_free */
};

/*
 * Must have external linkage, because PyMODINIT_FUNC does dllexport on
 * Windows-like platforms.
 */
PyMODINIT_FUNC
PyInit_plpy(void)
{
	PyObject   *m;

	m = PyModule_Create(&PLy_module);
	if (m == NULL)
		return NULL;

	PLy_add_exceptions(m);

	return m;
}
#endif   /* PY_MAJOR_VERSION >= 3 */

/*
 * Initialize the plpy module and publish it as "plpy" in the namespace of
 * the __main__ module.
 */
void
PLy_init_plpy(void)
{
	PyObject   *main_module;
	PyObject   *main_namespace;
	PyObject   *plpy_module;

	/* run the initializers of the plan, result, subxact and cursor types */
	PLy_plan_init_type();
	PLy_result_init_type();
	PLy_subtransaction_init_type();
	PLy_cursor_init_type();

#if PY_MAJOR_VERSION < 3
	/* Python 2: create the module here and attach the exceptions to it */
	PLy_add_exceptions(Py_InitModule("plpy", PLy_methods));
#else
	/*
	 * Python 3: the exceptions are initialized in PyInit_plpy.  The module
	 * object returned by this call is not kept.
	 */
	PyModule_Create(&PLy_module);
#endif

	/* look up __main__ and its dictionary, then bind the name "plpy" there */
	main_module = PyImport_AddModule("__main__");
	main_namespace = PyModule_GetDict(main_module);

	plpy_module = PyImport_AddModule("plpy");
	if (plpy_module == NULL)
		PLy_elog(ERROR, "could not import \"plpy\" module");

	PyDict_SetItemString(main_namespace, "plpy", plpy_module);
	if (PyErr_Occurred())
		PLy_elog(ERROR, "could not import \"plpy\" module");
}

/*
 * Create the spiexceptions submodule and the base exception classes
 * (plpy.Error, plpy.Fatal, plpy.SPIError), attach them to the given plpy
 * module, and build the SQLSTATE -> exception class hash table.
 *
 * Any failure is reported with PLy_elog(ERROR).
 */
static void
PLy_add_exceptions(PyObject *plpy)
{
	PyObject   *excmod;
	HASHCTL		hash_ctl;

#if PY_MAJOR_VERSION < 3
	excmod = Py_InitModule("spiexceptions", PLy_exc_methods);
#else
	excmod = PyModule_Create(&PLy_exc_module);
#endif
	if (excmod == NULL)
		PLy_elog(ERROR, "could not create the spiexceptions module");

	/*
	 * PyModule_AddObject takes over one reference to the object it is given,
	 * so acquire one for it first.  Without this the reference count of the
	 * spiexceptions module has been seen to drop to zero, causing a Python
	 * assert failure when the garbage collector visits the module (observed
	 * on the buildfarm).
	 *
	 * This shouldn't cause a memory leak - we don't want this garbage
	 * collected, and this function shouldn't be called more than once per
	 * backend.
	 */
	Py_INCREF(excmod);
	if (PyModule_AddObject(plpy, "spiexceptions", excmod) < 0)
		PLy_elog(ERROR, "could not add the spiexceptions module");

	PLy_exc_error = PyErr_NewException("plpy.Error", NULL, NULL);
	PLy_exc_fatal = PyErr_NewException("plpy.Fatal", NULL, NULL);
	PLy_exc_spi_error = PyErr_NewException("plpy.SPIError", NULL, NULL);

	if (PLy_exc_error == NULL ||
		PLy_exc_fatal == NULL ||
		PLy_exc_spi_error == NULL)
		PLy_elog(ERROR, "could not create the base SPI exceptions");

	/*
	 * The global variables keep the references returned above; take one
	 * more for each class so the module can own its own.
	 */
	Py_INCREF(PLy_exc_error);
	Py_INCREF(PLy_exc_fatal);
	Py_INCREF(PLy_exc_spi_error);

	/* don't let a failed registration go unnoticed */
	if (PyModule_AddObject(plpy, "Error", PLy_exc_error) < 0 ||
		PyModule_AddObject(plpy, "Fatal", PLy_exc_fatal) < 0 ||
		PyModule_AddObject(plpy, "SPIError", PLy_exc_spi_error) < 0)
		PLy_elog(ERROR, "could not create the base SPI exceptions");

	/* lookup table from integer SQLSTATE code to exception class */
	memset(&hash_ctl, 0, sizeof(hash_ctl));
	hash_ctl.keysize = sizeof(int);
	hash_ctl.entrysize = sizeof(PLyExceptionEntry);
	PLy_spi_exceptions = hash_create("SPI exceptions", 256,
									 &hash_ctl, HASH_ELEM | HASH_BLOBS);

	PLy_generate_spi_exceptions(excmod, PLy_exc_spi_error);
}

/*
 * Add all the autogenerated exceptions as subclasses of SPIError
 *
 * For every entry of exception_map, create an exception class derived from
 * "base" whose "sqlstate" attribute holds the five-character SQLSTATE code,
 * add it to module "mod" under the entry's classname, and remember it in the
 * PLy_spi_exceptions hash table under the integer SQLSTATE.
 *
 * Any failure is reported with PLy_elog(ERROR).
 */
static void
PLy_generate_spi_exceptions(PyObject *mod, PyObject *base)
{
	int			i;

	for (i = 0; exception_map[i].name != NULL; i++)
	{
		bool		found;
		PyObject   *exc;
		PLyExceptionEntry *entry;
		PyObject   *sqlstate;
		PyObject   *dict = PyDict_New();

		if (dict == NULL)
			PLy_elog(ERROR, "could not generate SPI exceptions");

		sqlstate = PyString_FromString(unpack_sql_state(exception_map[i].sqlstate));
		if (sqlstate == NULL)
			PLy_elog(ERROR, "could not generate SPI exceptions");

		if (PyDict_SetItemString(dict, "sqlstate", sqlstate) < 0)
			PLy_elog(ERROR, "could not generate SPI exceptions");
		Py_DECREF(sqlstate);

		exc = PyErr_NewException(exception_map[i].name, base, dict);

		/*
		 * PyErr_NewException does not take over our reference to the
		 * dictionary, so release it here instead of leaking one per class.
		 */
		Py_DECREF(dict);

		if (exc == NULL)
			PLy_elog(ERROR, "could not generate SPI exceptions");

		/*
		 * PyModule_AddObject takes over the reference we got from
		 * PyErr_NewException; acquire a second one for the pointer stored in
		 * the hash table below.
		 */
		Py_INCREF(exc);
		if (PyModule_AddObject(mod, exception_map[i].classname, exc) < 0)
			PLy_elog(ERROR, "could not generate SPI exceptions");

		entry = hash_search(PLy_spi_exceptions, &exception_map[i].sqlstate,
							HASH_ENTER, &found);
		Assert(!found);
		entry->exc = exc;
	}
}


/*
 * the python interface to the elog function
 * don't confuse these with PLy_elog
 *
 * PLy_output does the actual work; each of the functions below only
 * supplies the elog level that goes with its name.
 */
static PyObject *PLy_output(volatile int level, PyObject *self,
										  PyObject *args, PyObject *kw);

/* plpy.debug(): report the message at level DEBUG2 */
static PyObject *
PLy_debug(PyObject *self, PyObject *args, PyObject *kw)
{
	return PLy_output(DEBUG2, self, args, kw);
}

/* plpy.log(): report the message at level LOG */
static PyObject *
PLy_log(PyObject *self, PyObject *args, PyObject *kw)
{
	return PLy_output(LOG, self, args, kw);
}

/* plpy.info(): report the message at level INFO */
static PyObject *
PLy_info(PyObject *self, PyObject *args, PyObject *kw)
{
	return PLy_output(INFO, self, args, kw);
}

/* plpy.notice(): report the message at level NOTICE */
static PyObject *
PLy_notice(PyObject *self, PyObject *args, PyObject *kw)
{
	return PLy_output(NOTICE, self, args, kw);
}

/* plpy.warning(): report the message at level WARNING */
static PyObject *
PLy_warning(PyObject *self, PyObject *args, PyObject *kw)
{
	return PLy_output(WARNING, self, args, kw);
}

/* plpy.error(): report the message at level ERROR */
static PyObject *
PLy_error(PyObject *self, PyObject *args, PyObject *kw)
{
	return PLy_output(ERROR, self, args, kw);
}

/* plpy.fatal(): report the message at level FATAL */
static PyObject *
PLy_fatal(PyObject *self, PyObject *args, PyObject *kw)
{
	return PLy_output(FATAL, self, args, kw);
}

/*
 * plpy.quote_literal(string)
 *
 * Returns the argument as processed by quote_literal_cstr, or NULL (with the
 * Python exception set by argument parsing) if the argument is unusable.
 */
static PyObject *
PLy_quote_literal(PyObject *self, PyObject *args)
{
	const char *input;
	PyObject   *result = NULL;

	if (PyArg_ParseTuple(args, "s", &input))
	{
		char	   *literal = quote_literal_cstr(input);

		result = PyString_FromString(literal);
		/* the Python string is built by now, the quoted buffer can go */
		pfree(literal);
	}

	return result;
}

/*
 * plpy.quote_nullable(string-or-None)
 *
 * Like plpy.quote_literal, except that None is accepted and yields the
 * unquoted string NULL.
 */
static PyObject *
PLy_quote_nullable(PyObject *self, PyObject *args)
{
	const char *input;
	PyObject   *result = NULL;

	/* the "z" format accepts None and stores a NULL pointer for it */
	if (PyArg_ParseTuple(args, "z", &input))
	{
		if (input != NULL)
		{
			char	   *literal = quote_literal_cstr(input);

			result = PyString_FromString(literal);
			pfree(literal);
		}
		else
			result = PyString_FromString("NULL");
	}

	return result;
}

/*
 * plpy.quote_ident(string)
 *
 * Returns the argument as processed by quote_identifier, or NULL (with the
 * Python exception set by argument parsing) if the argument is unusable.
 */
static PyObject *
PLy_quote_ident(PyObject *self, PyObject *args)
{
	const char *ident;

	if (!PyArg_ParseTuple(args, "s", &ident))
		return NULL;

	return PyString_FromString(quote_identifier(ident));
}

/*
 * Convert an arbitrary Python object to a C string via PyObject_Str.
 *
 * The result is a pstrdup'ed copy.  NULL is returned when obj is NULL or
 * when the conversion with PyObject_Str fails.
 */
static char *
object_to_string(PyObject *obj)
{
	PyObject   *text;
	char	   *copy;

	if (obj == NULL)
		return NULL;

	text = PyObject_Str(obj);
	if (text == NULL)
		return NULL;

	/* copy the characters before dropping the temporary string object */
	copy = pstrdup(PyString_AsString(text));
	Py_DECREF(text);

	return copy;
}

/*
 * PLy_output
 *
 * Common implementation of the plpy logging functions (debug, log, info,
 * notice, warning, error, fatal).  "level" is the elog level to report at.
 *
 * The positional arguments form the message: a single argument is converted
 * with PyObject_Str on its own, anything else is converted as the whole
 * argument tuple.  The keyword arguments message, detail, hint, sqlstate,
 * schema, table, column, datatype and constraint set the corresponding
 * fields of the report; any other keyword is an error, as is giving the
 * message both positionally and by keyword.
 *
 * All argument processing is done inside PG_TRY.  An ERROR raised while
 * validating the arguments is therefore converted into a plpy.Error Python
 * exception by the PG_CATCH block, exactly like an ERROR raised by the
 * report itself, instead of jumping out of this function past the
 * interpreter that called it.
 *
 * Returns a new reference to None, or NULL with a Python exception set.
 */
static PyObject *
PLy_output(volatile int level, PyObject *self, PyObject *args, PyObject *kw)
{
	char	   *volatile sqlstatestr = NULL;
	char	   *volatile message = NULL;
	char	   *volatile detail = NULL;
	char	   *volatile hint = NULL;
	char	   *volatile column = NULL;
	char	   *volatile constraint = NULL;
	char	   *volatile datatype = NULL;
	char	   *volatile table = NULL;
	char	   *volatile schema = NULL;
	volatile MemoryContext oldcontext;
	/* read in PG_CATCH, hence volatile; NULL whenever we hold no reference */
	PyObject   *volatile so = NULL;

	oldcontext = CurrentMemoryContext;
	PG_TRY();
	{
		int			sqlstate = 0;

		if (PyTuple_Size(args) == 1)
		{
			/*
			 * Treat single argument specially to avoid undesirable
			 * ('tuple',) decoration.
			 */
			PyObject   *o;

			if (!PyArg_UnpackTuple(args, "plpy.elog", 1, 1, &o))
				PLy_elog(ERROR, "could not unpack arguments in plpy.elog");
			so = PyObject_Str(o);
		}
		else
			so = PyObject_Str(args);

		if (so == NULL || ((message = PyString_AsString(so)) == NULL))
		{
			level = ERROR;
			message = dgettext(TEXTDOMAIN, "could not parse error message in plpy.elog");
		}
		/* take a private copy before the string object is released */
		message = pstrdup(message);

		Py_XDECREF(so);
		so = NULL;

		if (kw != NULL)
		{
			PyObject   *key;
			PyObject   *value;
			Py_ssize_t	pos = 0;

			while (PyDict_Next(kw, &pos, &key, &value))
			{
				char	   *keyword = PyString_AsString(key);

				/* never hand a NULL pointer to strcmp */
				if (keyword == NULL)
					PLy_elog(ERROR, "could not parse keyword argument name in plpy.elog");

				if (strcmp(keyword, "message") == 0)
				{
					/* the message should not be overwritten */
					if (PyTuple_Size(args) != 0)
						PLy_elog(ERROR, "the message is already specified");

					if (message)
						pfree(message);
					message = object_to_string(value);
				}
				else if (strcmp(keyword, "detail") == 0)
					detail = object_to_string(value);
				else if (strcmp(keyword, "hint") == 0)
					hint = object_to_string(value);
				else if (strcmp(keyword, "sqlstate") == 0)
					sqlstatestr = object_to_string(value);
				else if (strcmp(keyword, "schema") == 0)
					schema = object_to_string(value);
				else if (strcmp(keyword, "table") == 0)
					table = object_to_string(value);
				else if (strcmp(keyword, "column") == 0)
					column = object_to_string(value);
				else if (strcmp(keyword, "datatype") == 0)
					datatype = object_to_string(value);
				else if (strcmp(keyword, "constraint") == 0)
					constraint = object_to_string(value);
				else
					PLy_elog(ERROR, "'%s' is an invalid keyword argument for this function",
							 keyword);
			}
		}

		if (sqlstatestr != NULL)
		{
			/* a SQLSTATE is exactly five characters from [0-9A-Z] */
			if (strlen(sqlstatestr) != 5)
				PLy_elog(ERROR, "invalid SQLSTATE code");

			if (strspn(sqlstatestr, "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ") != 5)
				PLy_elog(ERROR, "invalid SQLSTATE code");

			sqlstate = MAKE_SQLSTATE(sqlstatestr[0],
									 sqlstatestr[1],
									 sqlstatestr[2],
									 sqlstatestr[3],
									 sqlstatestr[4]);
		}

		/* check the encoding of every string before it goes into the report */
		if (message != NULL)
			pg_verifymbstr(message, strlen(message), false);
		if (detail != NULL)
			pg_verifymbstr(detail, strlen(detail), false);
		if (hint != NULL)
			pg_verifymbstr(hint, strlen(hint), false);
		if (schema != NULL)
			pg_verifymbstr(schema, strlen(schema), false);
		if (table != NULL)
			pg_verifymbstr(table, strlen(table), false);
		if (column != NULL)
			pg_verifymbstr(column, strlen(column), false);
		if (datatype != NULL)
			pg_verifymbstr(datatype, strlen(datatype), false);
		if (constraint != NULL)
			pg_verifymbstr(constraint, strlen(constraint), false);

		ereport(level,
				((sqlstate != 0) ? errcode(sqlstate) : 0,
				 (message != NULL) ? errmsg_internal("%s", message) : 0,
				 (detail != NULL) ? errdetail_internal("%s", detail) : 0,
				 (hint != NULL) ? errhint("%s", hint) : 0,
				 (column != NULL) ?
				 err_generic_string(PG_DIAG_COLUMN_NAME, column) : 0,
				 (constraint != NULL) ?
				 err_generic_string(PG_DIAG_CONSTRAINT_NAME, constraint) : 0,
				 (datatype != NULL) ?
				 err_generic_string(PG_DIAG_DATATYPE_NAME, datatype) : 0,
				 (table != NULL) ?
				 err_generic_string(PG_DIAG_TABLE_NAME, table) : 0,
				 (schema != NULL) ?
				 err_generic_string(PG_DIAG_SCHEMA_NAME, schema) : 0));
	}
	PG_CATCH();
	{
		ErrorData	*edata;

		MemoryContextSwitchTo(oldcontext);
		edata = CopyErrorData();
		FlushErrorState();

		/* release the message object if the error struck while we held it */
		Py_XDECREF(so);

		/* hand the error to Python as a plpy.Error exception */
		PLy_exception_set_with_details(PLy_exc_error, edata);
		FreeErrorData(edata);

		return NULL;
	}
	PG_END_TRY();

	/*
	 * return a legal object so the interpreter will continue on its merry way
	 */
	Py_INCREF(Py_None);
	return Py_None;
}
