|
16 | 16 | FIM_SUFFIX_TAG = "<|fim_suffix|>" |
17 | 17 | FIM_MIDDLE_TAG = "<|fim_middle|>" |
18 | 18 |
|
19 | | -language_rules = { |
| 19 | +LANGUAGES: list[Language] = ["python", "sql", "markdown"] |
| 20 | +language_rules: dict[Language, list[str]] = { |
20 | 21 | "python": [ |
21 | 22 | "For matplotlib: use plt.gca() as the last expression instead of plt.show().", |
22 | 23 | "For plotly: return the figure object directly.", |
|
33 | 34 | } |
34 | 35 |
|
35 | 36 |
|
| 37 | +language_rules_multiple_cells: dict[Language, list[str]] = { |
| 38 | + "sql": [ |
| 39 | + 'SQL cells start with df = mo.sql(f"""<your query>""") for DuckDB, or df = mo.sql(f"""<your query>""", engine=engine) for other SQL engines.', |
| 40 | + "This will automatically display the result in the UI. You do not need to return the dataframe in the cell.", |
| 41 | + "The SQL must use the syntax of the database engine specified in the `engine` variable. If no engine, then use duckdb syntax.", |
| 42 | + ] |
| 43 | +} |
| 44 | + |
| 45 | + |
36 | 46 | def _format_schema_info(tables: Optional[list[SchemaTable]]) -> str: |
37 | 47 | """Helper to format schema information from context""" |
38 | 48 | if not tables: |
@@ -166,10 +176,13 @@ def get_refactor_or_insert_notebook_cell_system_prompt( |
166 | 176 |
|
167 | 177 | if support_multiple_cells: |
168 | 178 | # Add all language rules for multi-cell scenarios |
169 | | - for lang in language_rules: |
170 | | - if len(language_rules[lang]) > 0: |
| 179 | + for lang in LANGUAGES: |
| 180 | + language_rule = language_rules_multiple_cells.get( |
| 181 | + lang, language_rules.get(lang, []) |
| 182 | + ) |
| 183 | + if language_rule: |
171 | 184 | system_prompt += ( |
172 | | - f"\n\n## Rules for {lang}:\n{_rules(language_rules[lang])}" |
| 185 | + f"\n\n## Rules for {lang}:\n{_rules(language_rule)}" |
173 | 186 | ) |
174 | 187 | elif language in language_rules and language_rules[language]: |
175 | 188 | system_prompt += ( |
|
0 commit comments