Skip to content

Commit 21c2a70

Browse files
authored
generate different sql code based on dialect (#6915)
## 📝 Summary

<!-- Provide a concise summary of what this pull request is addressing. If this PR fixes any issues, list them here by number (e.g., Fixes #123). -->

Fixes #6875. Might also tackle #6872.

## 🔍 Description of Changes

<!-- Detail the specific changes made in this pull request. Explain the problem addressed and how it was resolved. If applicable, provide before and after comparisons, screenshots, or any relevant details to help reviewers understand the changes easily. -->

## 📋 Checklist

- [x] I have read the [contributor guidelines](https://github.com/marimo-team/marimo/blob/main/CONTRIBUTING.md).
- [ ] For large changes, or changes that affect the public API: this change was discussed or approved through an issue, on [Discord](https://marimo.io/discord?ref=pr), or the community [discussions](https://github.com/marimo-team/marimo/discussions) (Please provide a link if applicable).
- [x] I have added tests for the changes made.
- [x] I have run the code and verified that it works as expected.
1 parent b6f1652 commit 21c2a70

File tree

8 files changed

+493
-147
lines changed

8 files changed

+493
-147
lines changed

frontend/src/components/datasources/__tests__/utils.test.ts

Lines changed: 319 additions & 138 deletions
Large diffs are not rendered by default.

frontend/src/components/datasources/datasources.tsx

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,7 @@ export const DataSources: React.FC = () => {
214214
databaseName={database.name}
215215
hasSearch={hasSearch}
216216
searchValue={searchValue}
217+
dialect={connection.dialect}
217218
/>
218219
</DatabaseItem>
219220
))}
@@ -340,6 +341,7 @@ const SchemaList: React.FC<{
340341
schemas: DatabaseSchema[];
341342
defaultSchema?: string | null;
342343
defaultDatabase?: string | null;
344+
dialect: string;
343345
engineName: string;
344346
databaseName: string;
345347
hasSearch: boolean;
@@ -348,6 +350,7 @@ const SchemaList: React.FC<{
348350
schemas,
349351
defaultSchema,
350352
defaultDatabase,
353+
dialect,
351354
engineName,
352355
databaseName,
353356
hasSearch,
@@ -384,6 +387,7 @@ const SchemaList: React.FC<{
384387
schema: schema.name,
385388
defaultSchema: defaultSchema,
386389
defaultDatabase: defaultDatabase,
390+
dialect: dialect,
387391
}}
388392
/>
389393
</SchemaItem>

frontend/src/components/datasources/utils.ts

Lines changed: 93 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,90 @@
11
/* Copyright 2024 Marimo. All rights reserved. */
2+
3+
import { BigQueryDialect } from "@marimo-team/codemirror-sql/dialects";
4+
import { isKnownDialect } from "@/core/codemirror/language/languages/sql/utils";
25
import type { SQLTableContext } from "@/core/datasets/data-source-connections";
36
import { DUCKDB_ENGINE } from "@/core/datasets/engines";
47
import type { DataTable, DataType } from "@/core/kernel/messages";
8+
import { logNever } from "@/utils/assertNever";
59
import type { ColumnHeaderStatsKey } from "../data-table/types";
610

711
// Some databases have no schemas (e.g. ClickHouse), so we don't show one
812
/**
 * Reports whether a schema name stands for "no schema".
 *
 * @param schemaName - Schema name to inspect.
 * @returns `true` only when the name is exactly the empty string.
 */
export function isSchemaless(schemaName: string) {
  const isEmptyName = schemaName === "";
  return isEmptyName;
}
1115

16+
/**
 * Dialect-specific hooks used when generating a SQL snippet for a table.
 */
interface SqlCodeFormatter {
  /** Renders the table name the way the dialect expects it (e.g. quoted). */
  formatTableName: (tableName: string) => string;
  /** Builds the SELECT statement for the given column and formatted table name. */
  formatSelectClause: (columnName: string, tableName: string) => string;
}
26+
27+
/**
 * Fallback formatter: leaves the table name untouched and caps the result
 * with a trailing `LIMIT 100`.
 */
const defaultFormatter: SqlCodeFormatter = {
  formatTableName: (name: string): string => {
    return name;
  },
  formatSelectClause: (column: string, source: string): string => {
    return `SELECT ${column} FROM ${source} LIMIT 100`;
  },
};
32+
33+
/**
 * Selects the SQL formatting rules for a connection's dialect.
 *
 * The dialect name is matched case-insensitively. Unknown dialects, and known
 * dialects that need no special handling, use `defaultFormatter`.
 *
 * @param dialect - Dialect name reported for the connection.
 * @returns The formatter used to render the table name and SELECT statement.
 */
function getFormatter(dialect: string): SqlCodeFormatter {
  // Normalize into a local instead of reassigning the parameter.
  const normalized = dialect.toLowerCase();
  if (!isKnownDialect(normalized)) {
    return defaultFormatter;
  }

  switch (normalized) {
    case "bigquery": {
      // `identifierQuotes` is optional on the dialect spec and may list more
      // than one character. Use the first one; `||` (not `??`) so an empty
      // string also falls back to BigQuery's backtick.
      const quote = BigQueryDialect.spec.identifierQuotes?.charAt(0) || "`";
      return {
        // BigQuery uses backticks for identifiers
        formatTableName: (tableName: string) => `${quote}${tableName}${quote}`,
        formatSelectClause: defaultFormatter.formatSelectClause,
      };
    }
    case "mssql":
    case "sqlserver":
      return {
        formatTableName: defaultFormatter.formatTableName,
        // SQL Server caps rows with TOP rather than a trailing LIMIT.
        formatSelectClause: (columnName: string, tableName: string) =>
          `SELECT TOP 100 ${columnName} FROM ${tableName}`,
      };
    case "timescaledb":
      return {
        // TimescaleDB uses double quotes for identifiers. Each dot-separated
        // part is quoted on its own, and an embedded double quote is doubled
        // so it cannot terminate the identifier early.
        // NOTE(review): a part that itself contains "." cannot be told apart
        // from a separator here, because the name arrives already joined.
        formatTableName: (tableName: string) =>
          tableName
            .split(".")
            .map((part) => `"${part.replace(/"/g, '""')}"`)
            .join("."),
        formatSelectClause: defaultFormatter.formatSelectClause,
      };
    case "postgresql":
    case "postgres":
    case "db2":
    case "mysql":
    case "sqlite":
    case "duckdb":
    case "mariadb":
    case "cassandra":
    case "noql":
    case "athena":
    case "hive":
    case "redshift":
    case "snowflake":
    case "flink":
    case "mongodb":
    case "oracle":
    case "oracledb":
      return defaultFormatter;
    default:
      // Exhaustiveness check: every known dialect must be handled above.
      logNever(normalized);
      return defaultFormatter;
  }
}
87+
1288
export function sqlCode({
1389
table,
1490
columnName,
@@ -19,8 +95,14 @@ export function sqlCode({
1995
sqlTableContext?: SQLTableContext;
2096
}) {
2197
if (sqlTableContext) {
22-
const { engine, schema, defaultSchema, defaultDatabase, database } =
23-
sqlTableContext;
98+
const {
99+
engine,
100+
schema,
101+
defaultSchema,
102+
defaultDatabase,
103+
database,
104+
dialect,
105+
} = sqlTableContext;
24106
let tableName = table.name;
25107

26108
// Set the fully qualified table name based on schema and database
@@ -39,11 +121,18 @@ export function sqlCode({
39121
}
40122
}
41123

124+
const formatter = getFormatter(dialect);
125+
const formattedTableName = formatter.formatTableName(tableName);
126+
const selectClause = formatter.formatSelectClause(
127+
columnName,
128+
formattedTableName,
129+
);
130+
42131
if (engine === DUCKDB_ENGINE) {
43-
return `_df = mo.sql(f"SELECT ${columnName} FROM ${tableName} LIMIT 100")`;
132+
return `_df = mo.sql(f"""\n${selectClause}\n""")`;
44133
}
45134

46-
return `_df = mo.sql(f"SELECT ${columnName} FROM ${tableName} LIMIT 100", engine=${engine})`;
135+
return `_df = mo.sql(f"""\n${selectClause}\n""", engine=${engine})`;
47136
}
48137

49138
return `_df = mo.sql(f'SELECT "${columnName}" FROM ${table.name} LIMIT 100')`;

frontend/src/core/codemirror/language/languages/sql/completion-store.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ class SQLCompletionStore {
134134
if (!connection) {
135135
return ModifiedStandardSQL;
136136
}
137-
return guessDialect(connection) ?? ModifiedStandardSQL;
137+
return guessDialect(connection);
138138
}
139139

140140
getCompletionSource(connectionName: ConnectionName): SQLConfig | null {
@@ -152,7 +152,7 @@ class SQLCompletionStore {
152152
const schema = this.cache.getOrCreate(connection);
153153

154154
return {
155-
dialect: guessDialect(connection) ?? ModifiedStandardSQL,
155+
dialect: guessDialect(connection),
156156
schema: schema.shouldAddLocalTables
157157
? { ...schema.schema, ...getTablesMap() }
158158
: schema.schema,

frontend/src/core/codemirror/language/languages/sql/sql.ts

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ import type { HotkeyProvider } from "@/core/hotkeys/hotkeys";
3838
import type { ValidateSQLResult } from "@/core/kernel/messages";
3939
import { store } from "@/core/state/jotai";
4040
import { resolvedThemeAtom } from "@/theme/useTheme";
41+
import { logNever } from "@/utils/assertNever";
4142
import { Logger } from "@/utils/Logger";
4243
import { variableCompletionSource } from "../../embedded/embedded-python";
4344
import { languageMetadataField } from "../../metadata";
@@ -52,6 +53,7 @@ import {
5253
} from "./completion-sources";
5354
import { SCHEMA_CACHE } from "./completion-store";
5455
import { getSQLMode, type SQLMode } from "./sql-mode";
56+
import { isKnownDialect } from "./utils";
5557

5658
const DEFAULT_DIALECT = DuckDBDialect;
5759
const DEFAULT_PARSER_DIALECT = "DuckDB";
@@ -353,6 +355,11 @@ function connectionNameToParserDialect(
353355
): ParserDialects | null {
354356
const dialect =
355357
SCHEMA_CACHE.getInternalDialect(connectionName)?.toLowerCase();
358+
359+
if (!dialect || !isKnownDialect(dialect)) {
360+
return null;
361+
}
362+
356363
switch (dialect) {
357364
case "postgresql":
358365
case "postgres":
@@ -385,8 +392,15 @@ function connectionNameToParserDialect(
385392
case "flink":
386393
return "FlinkSQL";
387394
case "mongodb":
395+
case "noql":
388396
return "Noql";
397+
case "oracle":
398+
case "oracledb":
399+
case "timescaledb":
400+
Logger.debug("Unsupported dialect", { dialect });
401+
return null;
389402
default:
403+
logNever(dialect);
390404
return null;
391405
}
392406
}

frontend/src/core/codemirror/language/languages/sql/utils.ts

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,52 @@ import {
1717
DuckDBDialect,
1818
} from "@marimo-team/codemirror-sql/dialects";
1919
import type { DataSourceConnection } from "@/core/kernel/messages";
20+
import { logNever } from "@/utils/assertNever";
21+
import { Logger } from "@/utils/Logger";
22+
23+
/**
 * Dialect names this module recognises, all lower-case.
 * Kept as a tuple so the `KnownDialect` union can be derived from it.
 */
const KNOWN_DIALECTS_ARRAY = [
  "postgresql",
  "postgres",
  "db2",
  "mysql",
  "sqlite",
  "mssql",
  "sqlserver",
  "duckdb",
  "mariadb",
  "cassandra",
  "noql",
  "athena",
  "bigquery",
  "hive",
  "redshift",
  "snowflake",
  "flink",
  "mongodb",
  "oracle",
  "oracledb",
  "timescaledb",
] as const;

/** Union of every literal listed in `KNOWN_DIALECTS_ARRAY`. */
type KnownDialect = (typeof KNOWN_DIALECTS_ARRAY)[number];

/** Set view of the same names, used for membership checks. */
const KNOWN_DIALECTS: ReadonlySet<string> = new Set(KNOWN_DIALECTS_ARRAY);

/**
 * Type guard for dialect names.
 *
 * The comparison is exact (case-sensitive), so callers are expected to
 * lower-case the value first if they want a case-insensitive match.
 *
 * @param dialect - Candidate dialect name.
 * @returns `true` when the name is one of the known dialects.
 */
export function isKnownDialect(dialect: string): dialect is KnownDialect {
  const recognised = KNOWN_DIALECTS.has(dialect);
  return recognised;
}
2052

2153
/**
2254
* Guess the CodeMirror SQL dialect from the backend connection dialect.
55+
* If unknown, return the standard SQL dialect.
2356
*/
2457
export function guessDialect(
2558
connection: Pick<DataSourceConnection, "dialect">,
26-
): SQLDialect | undefined {
27-
switch (connection.dialect) {
59+
): SQLDialect {
60+
const dialect = connection.dialect;
61+
if (!isKnownDialect(dialect)) {
62+
return ModifiedStandardSQL;
63+
}
64+
65+
switch (dialect) {
2866
case "postgresql":
2967
case "postgres":
3068
return PostgreSQL;
@@ -46,8 +84,21 @@ export function guessDialect(
4684
return PLSQL;
4785
case "bigquery":
4886
return BigQueryDialect;
87+
case "timescaledb":
88+
return PostgreSQL; // TimescaleDB is a PostgreSQL dialect
89+
case "athena":
90+
case "db2":
91+
case "hive":
92+
case "redshift":
93+
case "snowflake":
94+
case "flink":
95+
case "mongodb":
96+
case "noql":
97+
Logger.debug("Unsupported dialect", { dialect });
98+
return ModifiedStandardSQL;
4999
default:
50-
return undefined;
100+
logNever(dialect);
101+
return ModifiedStandardSQL;
51102
}
52103
}
53104

frontend/src/core/datasets/__tests__/data-source.test.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,7 @@ describe("add table list", () => {
270270
engine: "conn1" as ConnectionName,
271271
database: "db1",
272272
schema: "public",
273+
dialect: "sqlite",
273274
});
274275

275276
const conn1 = newState.connectionsMap.get("conn1" as ConnectionName);
@@ -283,6 +284,7 @@ describe("add table list", () => {
283284
engine: "conn1" as ConnectionName,
284285
database: "db1",
285286
schema: "public",
287+
dialect: "sqlite",
286288
};
287289

288290
const tableList: DataTable[] = [
@@ -344,6 +346,7 @@ describe("add table list", () => {
344346
engine: "conn1" as ConnectionName,
345347
database: "db1",
346348
schema: "non_existent",
349+
dialect: "sqlite",
347350
});
348351

349352
const conn1 = newState.connectionsMap.get("conn1" as ConnectionName);
@@ -407,6 +410,7 @@ describe("add table", () => {
407410
engine: "conn1" as ConnectionName,
408411
database: "db1",
409412
schema: "public",
413+
dialect: "sqlite",
410414
});
411415

412416
const conn1 = newState.connectionsMap.get("conn1" as ConnectionName);
@@ -420,6 +424,7 @@ describe("add table", () => {
420424
engine: "conn1" as ConnectionName,
421425
database: "db1",
422426
schema: "public",
427+
dialect: "sqlite",
423428
};
424429

425430
const table: DataTable = {
@@ -476,6 +481,7 @@ describe("add table", () => {
476481
engine: "conn1" as ConnectionName,
477482
database: "db1",
478483
schema: "non_existent",
484+
dialect: "sqlite",
479485
});
480486

481487
const conn1 = newState.connectionsMap.get("conn1" as ConnectionName);

frontend/src/core/datasets/data-source-connections.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ export interface SQLTableContext {
4747
engine: string;
4848
database: string;
4949
schema: string;
50+
dialect: string;
5051
defaultSchema?: string | null;
5152
defaultDatabase?: string | null;
5253
}

0 commit comments

Comments
 (0)