|
15 | 15 | NotSupportedError, ProgrammingError) |
16 | 16 |
|
17 | 17 |
|
18 | | -PY2 = sys.version_info[0] == 2 |
19 | | -if PY2: |
20 | | - text_type = unicode |
21 | | -else: |
22 | | - text_type = str |
23 | | - |
24 | | - |
25 | 18 | #: Regular expression for :meth:`Cursor.executemany`. |
26 | 19 | #: executemany only supports simple bulk insert. |
27 | 20 | #: You can use it to load large dataset. |
@@ -95,31 +88,28 @@ def __exit__(self, *exc_info): |
95 | 88 | del exc_info |
96 | 89 | self.close() |
97 | 90 |
|
98 | | - def _ensure_bytes(self, x, encoding=None): |
99 | | - if isinstance(x, text_type): |
100 | | - x = x.encode(encoding) |
101 | | - elif isinstance(x, (tuple, list)): |
102 | | - x = type(x)(self._ensure_bytes(v, encoding=encoding) for v in x) |
103 | | - return x |
104 | | - |
105 | 91 | def _escape_args(self, args, conn): |
106 | | - ensure_bytes = partial(self._ensure_bytes, encoding=conn.encoding) |
| 92 | + encoding = conn.encoding |
| 93 | + literal = conn.literal |
| 94 | + |
| 95 | + def ensure_bytes(x): |
| 96 | + if isinstance(x, unicode): |
| 97 | + return x.encode(encoding) |
| 98 | + elif isinstance(x, tuple): |
| 99 | + return tuple(map(ensure_bytes, x)) |
| 100 | + elif isinstance(x, list): |
| 101 | + return list(map(ensure_bytes, x)) |
| 102 | + return x |
107 | 103 |
|
108 | 104 | if isinstance(args, (tuple, list)): |
109 | | - if PY2: |
110 | | - args = tuple(map(ensure_bytes, args)) |
111 | | - return tuple(conn.literal(arg) for arg in args) |
| 105 | + return tuple(literal(ensure_bytes(arg)) for arg in args) |
112 | 106 | elif isinstance(args, dict): |
113 | | - if PY2: |
114 | | - args = dict((ensure_bytes(key), ensure_bytes(val)) for |
115 | | - (key, val) in args.items()) |
116 | | - return dict((key, conn.literal(val)) for (key, val) in args.items()) |
| 107 | + return {ensure_bytes(key): literal(ensure_bytes(val)) |
| 108 | + for (key, val) in args.items()} |
117 | 109 | else: |
118 | 110 | # If it's not a dictionary let's try escaping it anyways. |
119 | 111 | # Worst case it will throw a Value error |
120 | | - if PY2: |
121 | | - args = ensure_bytes(args) |
122 | | - return conn.literal(args) |
| 112 | + return literal(ensure_bytes(args)) |
123 | 113 |
|
124 | 114 | def _check_executed(self): |
125 | 115 | if not self._executed: |
@@ -186,31 +176,20 @@ def execute(self, query, args=None): |
186 | 176 | pass |
187 | 177 | db = self._get_db() |
188 | 178 |
|
189 | | - # NOTE: |
190 | | - # Python 2: query should be bytes when executing %. |
191 | | - # All unicode in args should be encoded to bytes on Python 2. |
192 | | - # Python 3: query should be str (unicode) when executing %. |
193 | | - # All bytes in args should be decoded with ascii and surrogateescape on Python 3. |
194 | | - # db.literal(obj) always returns str. |
195 | | - |
196 | | - if PY2 and isinstance(query, unicode): |
| 179 | + if isinstance(query, unicode): |
197 | 180 | query = query.encode(db.encoding) |
198 | 181 |
|
199 | 182 | if args is not None: |
200 | 183 | if isinstance(args, dict): |
201 | 184 | args = dict((key, db.literal(item)) for key, item in args.items()) |
202 | 185 | else: |
203 | 186 | args = tuple(map(db.literal, args)) |
204 | | - if not PY2 and isinstance(query, (bytes, bytearray)): |
205 | | - query = query.decode(db.encoding) |
206 | 187 | try: |
207 | 188 | query = query % args |
208 | 189 | except TypeError as m: |
209 | 190 | raise ProgrammingError(str(m)) |
210 | 191 |
|
211 | | - if isinstance(query, unicode): |
212 | | - query = query.encode(db.encoding, 'surrogateescape') |
213 | | - |
| 192 | + assert isinstance(query, (bytes, bytearray)) |
214 | 193 | res = self._query(query) |
215 | 194 | return res |
216 | 195 |
|
@@ -247,29 +226,19 @@ def executemany(self, query, args): |
247 | 226 | def _do_execute_many(self, prefix, values, postfix, args, max_stmt_length, encoding): |
248 | 227 | conn = self._get_db() |
249 | 228 | escape = self._escape_args |
250 | | - if isinstance(prefix, text_type): |
| 229 | + if isinstance(prefix, unicode): |
251 | 230 | prefix = prefix.encode(encoding) |
252 | | - if PY2 and isinstance(values, text_type): |
| 231 | + if isinstance(values, unicode): |
253 | 232 | values = values.encode(encoding) |
254 | | - if isinstance(postfix, text_type): |
| 233 | + if isinstance(postfix, unicode): |
255 | 234 | postfix = postfix.encode(encoding) |
256 | 235 | sql = bytearray(prefix) |
257 | 236 | args = iter(args) |
258 | 237 | v = values % escape(next(args), conn) |
259 | | - if isinstance(v, text_type): |
260 | | - if PY2: |
261 | | - v = v.encode(encoding) |
262 | | - else: |
263 | | - v = v.encode(encoding, 'surrogateescape') |
264 | 238 | sql += v |
265 | 239 | rows = 0 |
266 | 240 | for arg in args: |
267 | 241 | v = values % escape(arg, conn) |
268 | | - if isinstance(v, text_type): |
269 | | - if PY2: |
270 | | - v = v.encode(encoding) |
271 | | - else: |
272 | | - v = v.encode(encoding, 'surrogateescape') |
273 | 242 | if len(sql) + len(v) + len(postfix) + 1 > max_stmt_length: |
274 | 243 | rows += self.execute(sql + postfix) |
275 | 244 | sql = bytearray(prefix) |
@@ -308,22 +277,19 @@ def callproc(self, procname, args=()): |
308 | 277 | to advance through all result sets; otherwise you may get |
309 | 278 | disconnected. |
310 | 279 | """ |
311 | | - |
312 | 280 | db = self._get_db() |
| 281 | + if isinstance(procname, unicode): |
| 282 | + procname = procname.encode(db.encoding) |
313 | 283 | if args: |
314 | | - fmt = '@_{0}_%d=%s'.format(procname) |
315 | | - q = 'SET %s' % ','.join(fmt % (index, db.literal(arg)) |
316 | | - for index, arg in enumerate(args)) |
317 | | - if isinstance(q, unicode): |
318 | | - q = q.encode(db.encoding, 'surrogateescape') |
| 284 | + fmt = b'@_' + procname + b'_%d=%s' |
| 285 | + q = b'SET %s' % b','.join(fmt % (index, db.literal(arg)) |
| 286 | + for index, arg in enumerate(args)) |
319 | 287 | self._query(q) |
320 | 288 | self.nextset() |
321 | 289 |
|
322 | | - q = "CALL %s(%s)" % (procname, |
323 | | - ','.join(['@_%s_%d' % (procname, i) |
324 | | - for i in range(len(args))])) |
325 | | - if isinstance(q, unicode): |
326 | | - q = q.encode(db.encoding, 'surrogateescape') |
| 290 | + q = b"CALL %s(%s)" % (procname, |
| 291 | + b','.join([b'@_%s_%d' % (procname, i) |
| 292 | + for i in range(len(args))])) |
327 | 293 | self._query(q) |
328 | 294 | return args |
329 | 295 |
|
|
0 commit comments