@@ -943,7 +943,7 @@ _mysql_escape_string(
943943{
944944 PyObject * str ;
945945 char * in , * out ;
946- int len ;
946+ unsigned long len ;
947947 Py_ssize_t size ;
948948 if (!PyArg_ParseTuple (args , "s#:escape_string" , & in , & size )) return NULL ;
949949 str = PyBytes_FromStringAndSize ((char * ) NULL , size * 2 + 1 );
@@ -980,35 +980,52 @@ _mysql_string_literal(
980980 _mysql_ConnectionObject * self ,
981981 PyObject * o )
982982{
983- PyObject * str , * s ;
984- char * in , * out ;
985- unsigned long len ;
986- Py_ssize_t size ;
983+ PyObject * s ; // input string or bytes. need to decref.
987984
988985 if (self && PyModule_Check ((PyObject * )self ))
989986 self = NULL ;
990987
991988 if (PyBytes_Check (o )) {
992989 s = o ;
993990 Py_INCREF (s );
994- } else {
995- s = PyObject_Str (o );
996- if (!s ) return NULL ;
997- {
998- PyObject * t = PyUnicode_AsASCIIString (s );
999- Py_DECREF (s );
1000- if (!t ) return NULL ;
991+ }
992+ else {
993+ PyObject * t = PyObject_Str (o );
994+ if (!t ) return NULL ;
995+
996+ const char * encoding = (self && self -> open ) ?
997+ _get_encoding (& self -> connection ) : utf8 ;
998+ if (encoding == utf8 ) {
1001999 s = t ;
10021000 }
1001+ else {
1002+ s = PyUnicode_AsEncodedString (t , encoding , "strict" );
1003+ Py_DECREF (t );
1004+ if (!s ) return NULL ;
1005+ }
10031006 }
1004- in = PyBytes_AsString (s );
1005- size = PyBytes_GET_SIZE (s );
1006- str = PyBytes_FromStringAndSize ((char * ) NULL , size * 2 + 3 );
1007+
1008+ // Prepare input string (in, size)
1009+ const char * in ;
1010+ Py_ssize_t size ;
1011+ if (PyUnicode_Check (s )) {
1012+ in = PyUnicode_AsUTF8AndSize (s , & size );
1013+ } else {
1014+ assert (PyBytes_Check (s ));
1015+ in = PyBytes_AsString (s );
1016+ size = PyBytes_GET_SIZE (s );
1017+ }
1018+
1019+ // Prepare output buffer (str, out)
1020+ PyObject * str = PyBytes_FromStringAndSize ((char * ) NULL , size * 2 + 3 );
10071021 if (!str ) {
10081022 Py_DECREF (s );
10091023 return PyErr_NoMemory ();
10101024 }
1011- out = PyBytes_AS_STRING (str );
1025+ char * out = PyBytes_AS_STRING (str );
1026+
1027+ // escape
1028+ unsigned long len ;
10121029 if (self && self -> open ) {
10131030#if MYSQL_VERSION_ID >= 50707 && !defined(MARIADB_BASE_VERSION ) && !defined(MARIADB_VERSION_ID )
10141031 len = mysql_real_escape_string_quote (& (self -> connection ), out + 1 , in , size , '\'' );
@@ -1018,10 +1035,14 @@ _mysql_string_literal(
10181035 } else {
10191036 len = mysql_escape_string (out + 1 , in , size );
10201037 }
1021- * out = * (out + len + 1 ) = '\'' ;
1022- if (_PyBytes_Resize (& str , len + 2 ) < 0 ) return NULL ;
1038+
10231039 Py_DECREF (s );
1024- return (str );
1040+ * out = * (out + len + 1 ) = '\'' ;
1041+ if (_PyBytes_Resize (& str , len + 2 ) < 0 ) {
1042+ Py_DECREF (str );
1043+ return NULL ;
1044+ }
1045+ return str ;
10251046}
10261047
10271048static PyObject *
@@ -1499,8 +1520,9 @@ _mysql_ResultObject_discard(
14991520 // do nothing
15001521 }
15011522 Py_END_ALLOW_THREADS
1502- if (mysql_errno (self -> conn )) {
1503- return _mysql_Exception (self -> conn );
1523+ _mysql_ConnectionObject * conn = (_mysql_ConnectionObject * )self -> conn ;
1524+ if (mysql_errno (& conn -> connection )) {
1525+ return _mysql_Exception (conn );
15041526 }
15051527 Py_RETURN_NONE ;
15061528}
0 commit comments