Skip to content

Commit 4a0f8ad

Browse files
snapiritswast
andauthored
fix: compilation of a labeled custom FunctionElement when used in grouping (#1155)
* fix: fix grouped labels using custom functions If we have a labeled custom function that we are grouping by, and the function does not support the `default` dialect, we can not compile our query. * Update tests/unit/test_compiler.py --------- Co-authored-by: Tim Sweña (Swast) <tswast@gmail.com>
1 parent dc339ff commit 4a0f8ad

File tree

2 files changed

+41
-4
lines changed

2 files changed

+41
-4
lines changed

sqlalchemy_bigquery/base.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -343,10 +343,8 @@ def visit_label(self, *args, within_group_by=False, **kwargs):
343343
if within_group_by:
344344
column_label = args[0]
345345
sql_keywords = {"GROUPING SETS", "ROLLUP", "CUBE"}
346-
for keyword in sql_keywords:
347-
if keyword in str(column_label):
348-
break
349-
else: # for/else always happens unless break gets called
346+
label_str = column_label.compile(dialect=self.dialect).string
347+
if not any(keyword in label_str for keyword in sql_keywords):
350348
kwargs["render_label_as_label"] = column_label
351349

352350
return super(BigQueryCompiler, self).visit_label(*args, **kwargs)

tests/unit/test_compiler.py

+39
Original file line numberDiff line numberDiff line change
@@ -417,3 +417,42 @@ def test_complex_grouping_ops_vs_nested_grouping_ops(
417417
)
418418

419419
assert found_sql == expected_sql
420+
421+
422+
def test_label_compiler(faux_conn, metadata):
423+
class CustomLower(sqlalchemy.sql.functions.FunctionElement):
424+
name = "custom_lower"
425+
426+
@sqlalchemy.ext.compiler.compiles(CustomLower)
427+
def compile_custom_intersect(element, compiler, **kwargs):
428+
if compiler.dialect.name != "bigquery":
429+
# We only test with the BigQuery dialect, so this should never happen.
430+
raise sqlalchemy.exc.CompileError( # pragma: NO COVER
431+
f"custom_lower is not supported for dialect {compiler.dialect.name}"
432+
)
433+
434+
clauses = list(element.clauses)
435+
field = compiler.process(clauses[0], **kwargs)
436+
return f"LOWER({field})"
437+
438+
table1 = setup_table(
439+
faux_conn,
440+
"table1",
441+
metadata,
442+
sqlalchemy.Column("foo", sqlalchemy.String),
443+
sqlalchemy.Column("bar", sqlalchemy.Integer),
444+
)
445+
446+
lower_foo = CustomLower(table1.c.foo).label("some_label")
447+
q = (
448+
sqlalchemy.select(lower_foo, sqlalchemy.func.max(table1.c.bar))
449+
.select_from(table1)
450+
.group_by(lower_foo)
451+
)
452+
expected_sql = (
453+
"SELECT LOWER(`table1`.`foo`) AS `some_label`, max(`table1`.`bar`) AS `max_1` \n"
454+
"FROM `table1` GROUP BY `some_label`"
455+
)
456+
457+
found_sql = q.compile(faux_conn).string
458+
assert found_sql == expected_sql

0 commit comments

Comments
 (0)