diff --git a/_mysql.c b/_mysql.c index 5b81c79d..49c8ea17 100644 --- a/_mysql.c +++ b/_mysql.c @@ -107,11 +107,11 @@ typedef struct { extern PyTypeObject _mysql_ResultObject_Type; -static int _mysql_server_init_done = 0; +static int _mysql_library_init_done = 0; #if MYSQL_VERSION_ID >= 40000 -#define check_server_init(x) if (!_mysql_server_init_done) { if (mysql_server_init(0, NULL, NULL)) { _mysql_Exception(NULL); return x; } else { _mysql_server_init_done = 1;} } +#define check_library_init(x) if (!_mysql_library_init_done) { if (mysql_server_init(0, NULL, NULL)) { _mysql_Exception(NULL); return x; } else { _mysql_library_init_done = 1;} } #else -#define check_server_init(x) if (!_mysql_server_init_done) _mysql_server_init_done = 1 +#define check_library_init(x) if (!_mysql_library_init_done) _mysql_library_init_done = 1 #endif #if MYSQL_VERSION_ID >= 50500 @@ -124,37 +124,63 @@ static int _mysql_server_init_done = 0; #define HAVE_MYSQL_OPT_TIMEOUTS 1 #endif +static void +_mysql_SetError(PyObject *type, long code, PyObject *message_obj) +{ + PyObject *args; + if (!message_obj) return; // error raised by PyString_* method in caller + + if (!(args = PyTuple_New(2))) return; // OOM raised by PyTuple_New + + PyTuple_SET_ITEM(args, 0, PyInt_FromLong(code)); + PyTuple_SET_ITEM(args, 1, message_obj); + if (PyTuple_GET_ITEM(args, 0) && PyTuple_GET_ITEM(args, 1)) { + PyErr_SetObject(type, args); + } // else arg tuple failed to be filled, and the cause is + // already raised by one of the Py*_From* calls. + + // either way, if we allocated a tuple we need to release our ref: + Py_DECREF(args); +} + +static void +_mysql_SetErrorString(PyObject *type, long code, const char *message) +{ +#ifdef IS_PY3K + PyObject *message_obj = PyUnicode_FromString(message); +#else + PyObject *message_obj = PyString_FromString(message); +#endif + _mysql_SetError(type, code, message_obj); +} + PyObject * _mysql_Exception(_mysql_ConnectionObject *c) { PyObject *t, *e; int merr; + const char *message; if (!(t = PyTuple_New(2))) return NULL; - if (!_mysql_server_init_done) { - e = _mysql_InternalError; - PyTuple_SET_ITEM(t, 0, PyInt_FromLong(-1L)); -#ifdef IS_PY3K - PyTuple_SET_ITEM(t, 1, PyUnicode_FromString("server not initialized")); -#else - PyTuple_SET_ITEM(t, 1, PyString_FromString("server not initialized")); -#endif - PyErr_SetObject(e, t); - Py_DECREF(t); + if (!_mysql_library_init_done) { + _mysql_SetErrorString(_mysql_InternalError, -1L, + "server/library not initialized"); return NULL; } merr = mysql_errno(&(c->connection)); - if (!merr) + message = mysql_error(&(c->connection)); + if (!merr) { e = _mysql_InterfaceError; - else if (merr > CR_MAX_ERROR) { - PyTuple_SET_ITEM(t, 0, PyInt_FromLong(-1L)); + } else if (merr > CR_MAX_ERROR) { + // the C library gave us an error type we don't understand #ifdef IS_PY3K - PyTuple_SET_ITEM(t, 1, PyUnicode_FromString("error totally whack")); + PyObject *message_obj = + PyUnicode_FromFormat("error totally whack: %s", message); #else - PyTuple_SET_ITEM(t, 1, PyString_FromString("error totally whack")); + PyObject *message_obj = + PyString_FromFormat("error totally whack: %s", message); #endif - PyErr_SetObject(_mysql_InterfaceError, t); - Py_DECREF(t); + _mysql_SetError(_mysql_InternalError, -1L, message_obj); return NULL; } else switch (merr) { @@ -239,26 +265,21 @@ _mysql_Exception(_mysql_ConnectionObject *c) e = _mysql_OperationalError; break; } - PyTuple_SET_ITEM(t, 0, PyInt_FromLong((long)merr)); -#ifdef IS_PY3K - PyTuple_SET_ITEM(t, 1, PyUnicode_FromString(mysql_error(&(c->connection)))); -#else - PyTuple_SET_ITEM(t, 1, PyString_FromString(mysql_error(&(c->connection)))); -#endif - PyErr_SetObject(e, t); + _mysql_SetErrorString(e, (long)merr, message); Py_DECREF(t); return NULL; } -static char _mysql_server_init__doc__[] = -"Initialize embedded server. If this client is not linked against\n\ -the embedded server library, this function does nothing.\n\ +static char _mysql_library_init__doc__[] = +"Initialize the mysql library, and start the embedded server if the\n\ +embedded server library is linked in. May be called automatically if any\n\ +other MySQLdb functionality requiring it is used.\n\ \n\ args -- sequence of command-line arguments\n\ groups -- sequence of groups to use in defaults files\n\ "; -static PyObject *_mysql_server_init( +static PyObject *_mysql_library_init( PyObject *self, PyObject *args, PyObject *kwargs) { @@ -267,7 +288,7 @@ static PyObject *_mysql_server_init( int cmd_argc=0, i, groupc; PyObject *cmd_args=NULL, *groups=NULL, *ret=NULL, *item; - if (_mysql_server_init_done) { + if (_mysql_library_init_done) { PyErr_SetString(_mysql_ProgrammingError, "already initialized"); return NULL; @@ -277,7 +298,6 @@ static PyObject *_mysql_server_init( &cmd_args, &groups)) return NULL; -#if MYSQL_VERSION_ID >= 40000 if (cmd_args) { if (!PySequence_Check(cmd_args)) { PyErr_SetString(PyExc_TypeError, @@ -340,6 +360,12 @@ static PyObject *_mysql_server_init( } /* even though this may block, don't give up the interpreter lock so that the server can't be initialized multiple times. */ +#if MYSQL_VERSION_ID >= 40110 + if (mysql_library_init(cmd_argc, cmd_args_c, groups_c)) { + _mysql_Exception(NULL); + goto finish; + } +#elif MYSQL_VERSION_ID >= 40000 if (mysql_server_init(cmd_argc, cmd_args_c, groups_c)) { _mysql_Exception(NULL); goto finish; @@ -347,31 +373,53 @@ static PyObject *_mysql_server_init( #endif ret = Py_None; Py_INCREF(Py_None); - _mysql_server_init_done = 1; + _mysql_library_init_done = 1; finish: PyMem_Free(groups_c); PyMem_Free(cmd_args_c); return ret; } -static char _mysql_server_end__doc__[] = -"Shut down embedded server. If not using an embedded server, this\n\ -does nothing."; +static char _mysql_server_init__doc__[] = +"Deprecated. Use library_init instead."; -static PyObject *_mysql_server_end( +static PyObject *_mysql_server_init( + PyObject *self, + PyObject *args, + PyObject *kwargs) { + + return _mysql_library_init(self, args, kwargs); +} + +static char _mysql_library_end__doc__[] = +"Finalize the underlying libmysql. Shuts down the embedded server if one\n\ +is linked in, and cleans up some memory otherwise."; + +static PyObject *_mysql_library_end( PyObject *self, PyObject *args) { - if (_mysql_server_init_done) { -#if MYSQL_VERSION_ID >= 40000 + if (_mysql_library_init_done) { +#if MYSQL_VERSION_ID >= 40110 + mysql_library_end(); +#elif MYSQL_VERSION_ID >= 40000 mysql_server_end(); #endif - _mysql_server_init_done = 0; + _mysql_library_init_done = 0; Py_INCREF(Py_None); return Py_None; } return _mysql_Exception(NULL); } +static char _mysql_server_end__doc__[] = +"Deprecated. Use library_end() instead."; + +static PyObject *_mysql_server_end( + PyObject *self, + PyObject *args) { + return _mysql_library_end(self, args); +} + #if MYSQL_VERSION_ID >= 32314 static char _mysql_thread_safe__doc__[] = "Indicates whether the client is compiled as thread-safe."; @@ -380,11 +428,48 @@ static PyObject *_mysql_thread_safe( PyObject *self, PyObject *args) { PyObject *flag; - if (!PyArg_ParseTuple(args, "")) return NULL; - check_server_init(NULL); + check_library_init(NULL); if (!(flag=PyInt_FromLong((long)mysql_thread_safe()))) return NULL; return flag; } + +static char _mysql_thread_init__doc__[] = +"Call early in each thread to initialize thread-specific internal data.\n\ +\n\ +Can be called in place of library_init() if no special arguments need to\n\ +be passed to the latter, and is implied by an explicit call to\n\ +library_init() for the calling thread if such a call is made.\n\ +\n\ +Safe to call multiple times from the same thread."; + +static PyObject *_mysql_thread_init( + PyObject *self, + PyObject *args) { + + check_library_init(NULL); + if (mysql_thread_init()) { + _mysql_SetError(_mysql_InternalError, -1L, + PyString_FromString("mysql_thread_init failed")); + return NULL; + } + + Py_INCREF(Py_None); + return Py_None; +} + +static char _mysql_thread_end__doc__[] = +"Call before exiting any background threads that use MySQL to free\n\ +thread-specific resources. Failing to call this function results in\n\ +a memory leak."; + +static PyObject *_mysql_thread_end( + PyObject *self, + PyObject *args) { + mysql_thread_end(); + + Py_INCREF(Py_None); + return Py_None; +} #endif static char _mysql_ResultObject__doc__[] = @@ -583,7 +668,7 @@ _mysql_ConnectionObject_Initialize( self->converter = NULL; self->open = 0; - check_server_init(-1); + check_library_init(-1); if (!PyArg_ParseTupleAndKeywords(args, kwargs, #ifdef HAVE_MYSQL_OPT_TIMEOUTS @@ -1122,7 +1207,7 @@ _mysql_escape_string( #if MYSQL_VERSION_ID < 32321 len = mysql_escape_string(out, in, size); #else - check_server_init(NULL); + check_library_init(NULL); if (self && self->open) len = mysql_real_escape_string(&(self->connection), out, in, size); else @@ -1173,7 +1258,7 @@ _mysql_string_literal( #if MYSQL_VERSION_ID < 32321 len = mysql_escape_string(out+1, in, size); #else - check_server_init(NULL); + check_library_init(NULL); if (self && self->open) len = mysql_real_escape_string(&(self->connection), out+1, in, size); else @@ -1785,7 +1870,7 @@ _mysql_get_client_info( PyObject *args) { if (!PyArg_ParseTuple(args, "")) return NULL; - check_server_init(NULL); + check_library_init(NULL); #ifdef IS_PY3K return PyUnicode_FromString(mysql_get_client_info()); #else @@ -2988,9 +3073,21 @@ _mysql_methods[] = { { "thread_safe", (PyCFunction)_mysql_thread_safe, - METH_VARARGS, + METH_NOARGS, _mysql_thread_safe__doc__ }, + { + "thread_init", + (PyCFunction)_mysql_thread_init, + METH_NOARGS, + _mysql_thread_init__doc__ + }, + { + "thread_end", + (PyCFunction)_mysql_thread_end, + METH_NOARGS, + _mysql_thread_end__doc__ + }, #endif { "server_init", @@ -3001,9 +3098,21 @@ _mysql_methods[] = { { "server_end", (PyCFunction)_mysql_server_end, - METH_VARARGS, + METH_NOARGS, _mysql_server_end__doc__ }, + { + "library_init", + (PyCFunction)_mysql_library_init, + METH_VARARGS | METH_KEYWORDS, + _mysql_library_init__doc__, + }, + { + "library_end", + (PyCFunction)_mysql_library_end, + METH_NOARGS, + _mysql_library_end__doc__, + }, {NULL, NULL} /* sentinel */ };