# Copyright (c) 2017 The PyBigQuery Authors
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
"""Integration between SQLAlchemy and BigQuery."""
from __future__ import absolute_import
from __future__ import unicode_literals
import operator
from google import auth
import google.api_core.exceptions
from google.cloud.bigquery import dbapi
from google.cloud.bigquery.schema import SchemaField
from google.cloud.bigquery.table import TableReference
from google.api_core.exceptions import NotFound
from sqlalchemy.exc import NoSuchTableError
from sqlalchemy import types, util
from sqlalchemy.sql.compiler import (
SQLCompiler,
GenericTypeCompiler,
DDLCompiler,
IdentifierPreparer,
)
from sqlalchemy.engine.default import DefaultDialect, DefaultExecutionContext
from sqlalchemy.engine.base import Engine
from sqlalchemy.sql.schema import Column
from sqlalchemy.sql import elements
import re
from .parse_url import parse_url
from pybigquery import _helpers
FIELD_ILLEGAL_CHARACTERS = re.compile(r"[^\w]+")
class UniversalSet(object):
"""
Set containing everything
https://github.com/dropbox/PyHive/blob/master/pyhive/common.py
"""
def __contains__(self, item):
return True
class BigQueryIdentifierPreparer(IdentifierPreparer):
"""
Set containing everything
https://github.com/dropbox/PyHive/blob/master/pyhive/sqlalchemy_presto.py
"""
reserved_words = UniversalSet()
def __init__(self, dialect):
super(BigQueryIdentifierPreparer, self).__init__(
dialect, initial_quote="`",
)
def quote_column(self, value):
"""
Quote a column.
Fields are quoted separately from the record name.
"""
parts = value.split(".")
return ".".join(self.quote_identifier(x) for x in parts)
def quote(self, ident, force=None, column=False):
"""
Conditionally quote an identifier.
"""
force = getattr(ident, "quote", None)
if force is None:
if ident in self._strings:
return self._strings[ident]
else:
if self._requires_quotes(ident):
self._strings[ident] = (
self.quote_column(ident)
if column
else self.quote_identifier(ident)
)
else:
self._strings[ident] = ident
return self._strings[ident]
elif force:
return self.quote_column(ident) if column else self.quote_identifier(ident)
else:
return ident
def format_label(self, label, name=None):
name = name or label.name
# Fields must start with a letter or underscore
if not name[0].isalpha() and name[0] != "_":
name = "_" + name
# Fields must contain only letters, numbers, and underscores
name = FIELD_ILLEGAL_CHARACTERS.sub("_", name)
result = self.quote(name)
return result
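    # Illustrative sketch (not executed here): a label such as "2 amount due"
    # is rewritten to a legal BigQuery field name before quoting,
    # e.g. it renders as `_2_amount_due`.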
_type_map = {
"STRING": types.String,
"BOOLEAN": types.Boolean,
"INTEGER": types.Integer,
"FLOAT": types.Float,
"TIMESTAMP": types.TIMESTAMP,
"DATETIME": types.DATETIME,
"DATE": types.DATE,
"BYTES": types.BINARY,
"TIME": types.TIME,
"RECORD": types.JSON,
"NUMERIC": types.DECIMAL,
}
STRING = _type_map["STRING"]
BOOLEAN = _type_map["BOOLEAN"]
INTEGER = _type_map["INTEGER"]
FLOAT = _type_map["FLOAT"]
TIMESTAMP = _type_map["TIMESTAMP"]
DATETIME = _type_map["DATETIME"]
DATE = _type_map["DATE"]
BYTES = _type_map["BYTES"]
TIME = _type_map["TIME"]
RECORD = _type_map["RECORD"]
NUMERIC = _type_map["NUMERIC"]
class BigQueryExecutionContext(DefaultExecutionContext):
def create_cursor(self):
# Set arraysize
c = super(BigQueryExecutionContext, self).create_cursor()
if self.dialect.arraysize:
c.arraysize = self.dialect.arraysize
return c
class BigQueryCompiler(SQLCompiler):
def __init__(self, dialect, statement, column_keys=None, inline=False, **kwargs):
if isinstance(statement, Column):
kwargs["compile_kwargs"] = util.immutabledict({"include_table": False})
super(BigQueryCompiler, self).__init__(
dialect, statement, column_keys, inline, **kwargs
)
def visit_column(
self, column, add_to_result_map=None, include_table=True, **kwargs
):
name = orig_name = column.name
if name is None:
name = self._fallback_column_name(column)
is_literal = column.is_literal
if not is_literal and isinstance(name, elements._truncated_label):
name = self._truncated_identifier("colident", name)
if add_to_result_map is not None:
add_to_result_map(name, orig_name, (column, name, column.key), column.type)
if is_literal:
name = self.escape_literal_column(name)
else:
name = self.preparer.quote(name, column=True)
table = column.table
if table is None or not include_table or not table.named_with_column:
return name
else:
effective_schema = self.preparer.schema_for_object(table)
if effective_schema:
schema_prefix = self.preparer.quote_schema(effective_schema) + "."
else:
schema_prefix = ""
tablename = table.name
if isinstance(tablename, elements._truncated_label):
tablename = self._truncated_identifier("alias", tablename)
return schema_prefix + self.preparer.quote(tablename) + "." + name
def visit_label(self, *args, within_group_by=False, **kwargs):
# Use labels in GROUP BY clause.
#
# Flag set in the group_by_clause method. Works around missing
# equivalent to supports_simple_order_by_label for group by.
if within_group_by:
kwargs["render_label_as_label"] = args[0]
return super(BigQueryCompiler, self).visit_label(*args, **kwargs)
def group_by_clause(self, select, **kw):
return super(BigQueryCompiler, self).group_by_clause(
select, **kw, within_group_by=True
)
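    # Illustrative sketch (not executed here): grouping on a labeled expression
    # renders the label name rather than re-rendering the expression, roughly
    #
    #   SELECT `category` AS `cat`, count(*) ... GROUP BY `cat`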
class BigQueryTypeCompiler(GenericTypeCompiler):
def visit_integer(self, type_, **kw):
return "INT64"
def visit_float(self, type_, **kw):
return "FLOAT64"
def visit_text(self, type_, **kw):
return "STRING"
def visit_string(self, type_, **kw):
return "STRING"
def visit_ARRAY(self, type_, **kw):
return "ARRAY<{}>".format(self.process(type_.item_type, **kw))
def visit_BINARY(self, type_, **kw):
return "BYTES"
def visit_NUMERIC(self, type_, **kw):
return "NUMERIC"
def visit_DECIMAL(self, type_, **kw):
return "NUMERIC"
class BigQueryDDLCompiler(DDLCompiler):
# BigQuery has no support for foreign keys.
def visit_foreign_key_constraint(self, constraint):
return None
# BigQuery has no support for primary keys.
def visit_primary_key_constraint(self, constraint):
return None
def get_column_specification(self, column, **kwargs):
colspec = super(BigQueryDDLCompiler, self).get_column_specification(
column, **kwargs
)
if column.doc is not None:
colspec = "{} OPTIONS(description={})".format(
colspec, self.preparer.quote(column.doc)
)
return colspec
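        # Illustrative sketch (not executed here): a column declared as
        # Column("id", Integer, doc="primary identifier") has an
        # OPTIONS(description=...) clause appended to its column specification.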
def post_create_table(self, table):
bq_opts = table.dialect_options["bigquery"]
opts = []
if "description" in bq_opts:
opts.append(
"description={}".format(self.preparer.quote(bq_opts["description"]))
)
if "friendly_name" in bq_opts:
opts.append(
"friendly_name={}".format(self.preparer.quote(bq_opts["friendly_name"]))
)
if opts:
return "\nOPTIONS({})".format(", ".join(opts))
return ""
class BigQueryDialect(DefaultDialect):
name = "bigquery"
driver = "bigquery"
preparer = BigQueryIdentifierPreparer
statement_compiler = BigQueryCompiler
type_compiler = BigQueryTypeCompiler
ddl_compiler = BigQueryDDLCompiler
execution_ctx_cls = BigQueryExecutionContext
supports_alter = False
supports_pk_autoincrement = False
supports_default_values = False
supports_empty_insert = False
    supports_multivalues_insert = True
supports_unicode_statements = True
supports_unicode_binds = True
supports_native_decimal = True
returns_unicode_strings = True
description_encoding = None
supports_native_boolean = True
supports_simple_order_by_label = True
postfetch_lastrowid = False
def __init__(
self,
arraysize=5000,
credentials_path=None,
location=None,
credentials_info=None,
*args,
**kwargs
):
super(BigQueryDialect, self).__init__(*args, **kwargs)
self.arraysize = arraysize
self.credentials_path = credentials_path
self.credentials_info = credentials_info
self.location = location
self.dataset_id = None
@classmethod
def dbapi(cls):
return dbapi
@staticmethod
def _build_formatted_table_id(table):
"""Build '<dataset_id>.<table_id>' string using given table."""
return "{}.{}".format(table.reference.dataset_id, table.table_id)
@staticmethod
def _add_default_dataset_to_job_config(job_config, project_id, dataset_id):
# If dataset_id is set, then we know the job_config isn't None
if dataset_id:
# If project_id is missing, use default project_id for the current environment
if not project_id:
_, project_id = auth.default()
job_config.default_dataset = "{}.{}".format(project_id, dataset_id)
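        # Illustrative sketch (not executed here): with project "my-project" and
        # dataset "my_dataset", the default dataset becomes
        # "my-project.my_dataset", so unqualified table names in queries resolve
        # against that dataset.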
def create_connect_args(self, url):
(
project_id,
location,
dataset_id,
arraysize,
credentials_path,
default_query_job_config,
) = parse_url(url)
self.arraysize = self.arraysize or arraysize
self.location = location or self.location
self.credentials_path = credentials_path or self.credentials_path
self.dataset_id = dataset_id
self._add_default_dataset_to_job_config(
default_query_job_config, project_id, dataset_id
)
client = _helpers.create_bigquery_client(
credentials_path=self.credentials_path,
credentials_info=self.credentials_info,
project_id=project_id,
location=self.location,
default_query_job_config=default_query_job_config,
)
return ([client], {})
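        # Illustrative usage sketch (not executed here); "my-project" and
        # "my_dataset" are placeholder names:
        #
        #   from sqlalchemy import create_engine
        #   engine = create_engine("bigquery://my-project/my_dataset")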
def _json_deserializer(self, row):
"""JSON deserializer for RECORD types.
The DB-API layer already deserializes JSON to a dictionary, so this
just returns the input.
"""
return row
def _get_table_or_view_names(self, connection, table_type, schema=None):
current_schema = schema or self.dataset_id
get_table_name = (
self._build_formatted_table_id
if self.dataset_id is None
else operator.attrgetter("table_id")
)
client = connection.connection._client
datasets = client.list_datasets()
result = []
for dataset in datasets:
if current_schema is not None and current_schema != dataset.dataset_id:
continue
try:
tables = client.list_tables(dataset.reference)
for table in tables:
if table_type == table.table_type:
result.append(get_table_name(table))
except google.api_core.exceptions.NotFound:
# It's possible that the dataset was deleted between when we
# fetched the list of datasets and when we try to list the
# tables from it. See:
# https://github.com/googleapis/python-bigquery-sqlalchemy/issues/105
pass
return result
@staticmethod
def _split_table_name(full_table_name):
# Split full_table_name to get project, dataset and table name
dataset = None
table_name = None
project = None
table_name_split = full_table_name.split(".")
if len(table_name_split) == 1:
table_name = full_table_name
elif len(table_name_split) == 2:
dataset, table_name = table_name_split
elif len(table_name_split) == 3:
project, dataset, table_name = table_name_split
else:
raise ValueError(
"Did not understand table_name: {}".format(full_table_name)
)
return (project, dataset, table_name)
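        # Illustrative sketch (not executed here):
        #
        #   _split_table_name("dataset.table")          # -> (None, "dataset", "table")
        #   _split_table_name("project.dataset.table")  # -> ("project", "dataset", "table")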
def _table_reference(
self, provided_schema_name, provided_table_name, client_project
):
project_id_from_table, dataset_id_from_table, table_id = self._split_table_name(
provided_table_name
)
project_id_from_schema = None
dataset_id_from_schema = None
if provided_schema_name is not None:
provided_schema_name_split = provided_schema_name.split(".")
if len(provided_schema_name_split) == 0:
pass
elif len(provided_schema_name_split) == 1:
if dataset_id_from_table:
project_id_from_schema = provided_schema_name_split[0]
else:
dataset_id_from_schema = provided_schema_name_split[0]
elif len(provided_schema_name_split) == 2:
project_id_from_schema = provided_schema_name_split[0]
dataset_id_from_schema = provided_schema_name_split[1]
else:
raise ValueError(
"Did not understand schema: {}".format(provided_schema_name)
)
if (
dataset_id_from_schema
and dataset_id_from_table
and dataset_id_from_schema != dataset_id_from_table
):
raise ValueError(
"dataset_id specified in schema and table_name disagree: "
"got {} in schema, and {} in table_name".format(
dataset_id_from_schema, dataset_id_from_table
)
)
if (
project_id_from_schema
and project_id_from_table
and project_id_from_schema != project_id_from_table
):
raise ValueError(
"project_id specified in schema and table_name disagree: "
"got {} in schema, and {} in table_name".format(
project_id_from_schema, project_id_from_table
)
)
project_id = project_id_from_schema or project_id_from_table or client_project
dataset_id = dataset_id_from_schema or dataset_id_from_table or self.dataset_id
table_ref = TableReference.from_string(
"{}.{}.{}".format(project_id, dataset_id, table_id)
)
return table_ref
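        # Illustrative sketch (not executed here): project and dataset come from
        # the schema first, then from the table name, then from the client or
        # URL defaults, e.g.
        #
        #   _table_reference("other-project.other_dataset", "t", "client-project")
        #   # -> TableReference for "other-project.other_dataset.t"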
def _get_table(self, connection, table_name, schema=None):
if isinstance(connection, Engine):
connection = connection.connect()
client = connection.connection._client
table_ref = self._table_reference(schema, table_name, client.project)
try:
table = client.get_table(table_ref)
except NotFound:
raise NoSuchTableError(table_name)
return table
def has_table(self, connection, table_name, schema=None):
try:
self._get_table(connection, table_name, schema)
return True
except NoSuchTableError:
return False
def _get_columns_helper(self, columns, cur_columns):
"""
Recurse into record type and return all the nested field names.
As contributed by @sumedhsakdeo on issue #17
"""
results = []
for col in columns:
results += [
SchemaField(
name=".".join(col.name for col in cur_columns + [col]),
field_type=col.field_type,
mode=col.mode,
description=col.description,
fields=col.fields,
)
]
if col.field_type == "RECORD":
cur_columns.append(col)
results += self._get_columns_helper(col.fields, cur_columns)
cur_columns.pop()
return results
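        # Illustrative sketch (not executed here): a RECORD column "address"
        # with sub-fields "city" and "zip" yields flattened entries named
        # "address", "address.city", and "address.zip".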
def get_columns(self, connection, table_name, schema=None, **kw):
table = self._get_table(connection, table_name, schema)
columns = self._get_columns_helper(table.schema, [])
result = []
for col in columns:
try:
coltype = _type_map[col.field_type]
except KeyError:
util.warn(
"Did not recognize type '%s' of column '%s'"
% (col.field_type, col.name)
)
coltype = types.NullType
result.append(
{
"name": col.name,
"type": types.ARRAY(coltype) if col.mode == "REPEATED" else coltype,
"nullable": col.mode == "NULLABLE" or col.mode == "REPEATED",
"comment": col.description,
"default": None,
}
)
return result
def get_table_comment(self, connection, table_name, schema=None, **kw):
table = self._get_table(connection, table_name, schema)
return {
"text": table.description,
}
def get_foreign_keys(self, connection, table_name, schema=None, **kw):
# BigQuery has no support for foreign keys.
return []
def get_pk_constraint(self, connection, table_name, schema=None, **kw):
# BigQuery has no support for primary keys.
return {"constrained_columns": []}
def get_indexes(self, connection, table_name, schema=None, **kw):
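        # BigQuery has no real indexes; time partitioning and clustering fields
        # are surfaced as pseudo-indexes ("partition" and "clustering") so that
        # reflection tools can display them.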
table = self._get_table(connection, table_name, schema)
indexes = []
if table.time_partitioning:
indexes.append(
{
"name": "partition",
"column_names": [table.time_partitioning.field],
"unique": False,
}
)
if table.clustering_fields:
indexes.append(
{
"name": "clustering",
"column_names": table.clustering_fields,
"unique": False,
}
)
return indexes
def get_schema_names(self, connection, **kw):
if isinstance(connection, Engine):
connection = connection.connect()
datasets = connection.connection._client.list_datasets()
if self.dataset_id is not None:
return [d.dataset_id for d in datasets if d.dataset_id == self.dataset_id]
else:
return [d.dataset_id for d in datasets]
def get_table_names(self, connection, schema=None, **kw):
if isinstance(connection, Engine):
connection = connection.connect()
return self._get_table_or_view_names(connection, "TABLE", schema)
def get_view_names(self, connection, schema=None, **kw):
if isinstance(connection, Engine):
connection = connection.connect()
return self._get_table_or_view_names(connection, "VIEW", schema)
def do_rollback(self, dbapi_connection):
# BigQuery has no support for transactions.
pass
def _check_unicode_returns(self, connection, additional_tests=None):
# requests gives back Unicode strings
return True
def _check_unicode_description(self, connection):
# requests gives back Unicode strings
return True