Skip to content

Commit c8bf991

Browse files
zelchZephaniah E. Loss-Cutler-Hullkyleconroy
authored
Handle convertPatternInExpr. (#1088)
* Handle convertPatternInExpr. This can be handled just by calling convert on the subquery, and it is necessary for some kinds of queries. Like the following, without this we don't pickup the parameter, and thus can't handle the query at all: select * from foo where field in (select substring(?, 1, seq) from seq_1_to_19) * Add ast.In This is used for queries like the following: SELECT a, b from foo where foo.a in (select a from bar where bar.b = ?); SELECT a, b from foo where foo.a in (?, ?); SELECT a, b from foo where foo.a not in (?, ?); SELECT a, b from foo where ? in (foo.a, foo.b); The support is heavily based on that found in the fork found at: https://github.com/xiazemin/sqlc.git However it should be noted that this is not just a straight copy and paste job, there are several pieces from that branch for features that don't exist in the main sqlc repository, the sqlc repository has changed a bit since the fork, and the fork version does not handle the latter case at all. * Adapt convertPatternInExpr to use ast.In This is used for queries like the following: SELECT a, b from foo where foo.a in (select a from bar where bar.b = ?); SELECT a, b from foo where foo.a in (?, ?); SELECT a, b from foo where foo.a not in (?, ?); SELECT a, b from foo where ? in (foo.a, foo.b); The support is heavily based on that found in the fork found at: https://github.com/xiazemin/sqlc.git However it should be noted that this is not just a straight copy and paste job, there are several pieces from that branch for features that don't exist in the main sqlc repository, the sqlc repository has changed a bit since the fork, and the fork version does not handle the latter case at all. * Add test cases for convertPatternInExpr. This covers all the cases I can think of, though if I missed some, it's quite possible that they are broken in the code as well as absent in the tests. * Only check for duplicate c.id on query outputs. That is, we want to deduplicate the same column from showing up multiple times, and thus we want to avoid adding a suffix, when we are getting the results from a query, however we most definitely do not want to do this for query parameters. Especially since the logic for handling this in the query parameters is missing the deduplication, but also because the database itself expects the same number of parameters as we have placeholders in the query. * Make seen use fieldName instead of colName. fieldName is after any mangilng to make it fit go variable standards, and since that is the name that we are actually using, that is the name that we should be checking for conflicts. This only matters when we have multiple fields that differ only by things that are changed when we are converting to the go variable standards. * Update test files for A_2 vs A. This should now build a little better, now that we have fixed the generator. * Make seen use the pre-modification fieldName. Oops, this fixes a bug introduced in 885bd3d. * DefaultSchema support for ast.In. Needed now that we're merging in master. Co-authored-by: Zephaniah E. Loss-Cutler-Hull <warp@aehallh.com> Co-authored-by: Kyle Conroy <kyle@conroy.org>
1 parent c2dcd56 commit c8bf991

File tree

11 files changed

+375
-7
lines changed

11 files changed

+375
-7
lines changed

internal/codegen/golang/result.go

+8-6
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ func buildQueries(r *compiler.Result, settings config.CombinedSettings, structs
181181
gq.Arg = QueryValue{
182182
Emit: true,
183183
Name: "arg",
184-
Struct: columnsToStruct(r, gq.MethodName+"Params", cols, settings),
184+
Struct: columnsToStruct(r, gq.MethodName+"Params", cols, settings, false),
185185
SQLPackage: sqlpkg,
186186
}
187187
}
@@ -225,7 +225,7 @@ func buildQueries(r *compiler.Result, settings config.CombinedSettings, structs
225225
Column: c,
226226
})
227227
}
228-
gs = columnsToStruct(r, gq.MethodName+"Row", columns, settings)
228+
gs = columnsToStruct(r, gq.MethodName+"Row", columns, settings, true)
229229
emit = true
230230
}
231231
gq.Ret = QueryValue{
@@ -249,7 +249,7 @@ func buildQueries(r *compiler.Result, settings config.CombinedSettings, structs
249249
// JSON tags: count, count_2, count_2
250250
//
251251
// This is unlikely to happen, so don't fix it yet
252-
func columnsToStruct(r *compiler.Result, name string, columns []goColumn, settings config.CombinedSettings) *Struct {
252+
func columnsToStruct(r *compiler.Result, name string, columns []goColumn, settings config.CombinedSettings, useID bool) *Struct {
253253
gs := Struct{
254254
Name: name,
255255
}
@@ -259,12 +259,13 @@ func columnsToStruct(r *compiler.Result, name string, columns []goColumn, settin
259259
colName := columnName(c.Column, i)
260260
tagName := colName
261261
fieldName := StructName(colName, settings)
262+
baseFieldName := fieldName
262263
// Track suffixes by the ID of the column, so that columns referring to the same numbered parameter can be
263264
// reused.
264265
suffix := 0
265-
if o, ok := suffixes[c.id]; ok {
266+
if o, ok := suffixes[c.id]; ok && useID {
266267
suffix = o
267-
} else if v := seen[colName]; v > 0 {
268+
} else if v := seen[fieldName]; v > 0 {
268269
suffix = v + 1
269270
}
270271
suffixes[c.id] = suffix
@@ -284,7 +285,8 @@ func columnsToStruct(r *compiler.Result, name string, columns []goColumn, settin
284285
Type: goType(r, c.Column, settings),
285286
Tags: tags,
286287
})
287-
seen[colName]++
288+
seen[baseFieldName]++
288289
}
290+
289291
return &gs
290292
}

internal/compiler/find_params.go

+19
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,25 @@ func (p paramSearch) Visit(node ast.Node) astutils.Visitor {
142142
p.seen[n.Location] = struct{}{}
143143
}
144144
return nil
145+
146+
case *ast.In:
147+
if n.Sel == nil {
148+
p.parent = node
149+
} else {
150+
if sel, ok := n.Sel.(*ast.SelectStmt); ok && sel.FromClause != nil {
151+
from := sel.FromClause
152+
if schema, ok := from.Items[0].(*ast.RangeVar); ok && schema != nil {
153+
p.rangeVar = &ast.RangeVar{
154+
Catalogname: schema.Catalogname,
155+
Schemaname: schema.Schemaname,
156+
Relname: schema.Relname,
157+
}
158+
}
159+
}
160+
}
161+
if _, ok := n.Expr.(*ast.ParamRef); ok {
162+
p.Visit(n.Expr)
163+
}
145164
}
146165
return p
147166
}

internal/compiler/resolve.go

+98
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,104 @@ func resolveCatalogRefs(c *catalog.Catalog, qc *QueryCatalog, rvs []*ast.RangeVa
370370
case *ast.ParamRef:
371371
a = append(a, Parameter{Number: ref.ref.Number})
372372

373+
case *ast.In:
374+
if n == nil || n.List == nil {
375+
fmt.Println("ast.In is nil")
376+
continue
377+
}
378+
379+
number := 0
380+
if pr, ok := n.List[0].(*ast.ParamRef); ok {
381+
number = pr.Number
382+
}
383+
384+
location := 0
385+
var key, alias string
386+
var items []string
387+
388+
if left, ok := n.Expr.(*ast.ColumnRef); ok {
389+
location = left.Location
390+
items = stringSlice(left.Fields)
391+
} else if left, ok := n.Expr.(*ast.ParamRef); ok {
392+
if len(n.List) <= 0 {
393+
continue
394+
}
395+
if right, ok := n.List[0].(*ast.ColumnRef); ok {
396+
location = left.Location
397+
items = stringSlice(right.Fields)
398+
} else {
399+
continue
400+
}
401+
} else {
402+
continue
403+
}
404+
405+
switch len(items) {
406+
case 1:
407+
key = items[0]
408+
case 2:
409+
alias = items[0]
410+
key = items[1]
411+
default:
412+
panic("too many field items: " + strconv.Itoa(len(items)))
413+
}
414+
415+
var found int
416+
if n.Sel == nil {
417+
search := tables
418+
if alias != "" {
419+
if original, ok := aliasMap[alias]; ok {
420+
search = []*ast.TableName{original}
421+
} else {
422+
for _, fqn := range tables {
423+
if fqn.Name == alias {
424+
search = []*ast.TableName{fqn}
425+
}
426+
}
427+
}
428+
}
429+
430+
for _, table := range search {
431+
schema := table.Schema
432+
if schema == "" {
433+
schema = c.DefaultSchema
434+
}
435+
if c, ok := typeMap[schema][table.Name][key]; ok {
436+
found += 1
437+
if ref.name != "" {
438+
key = ref.name
439+
}
440+
a = append(a, Parameter{
441+
Number: number,
442+
Column: &Column{
443+
Name: parameterName(ref.ref.Number, key),
444+
DataType: dataType(&c.Type),
445+
NotNull: c.IsNotNull,
446+
IsArray: c.IsArray,
447+
Table: table,
448+
},
449+
})
450+
}
451+
}
452+
} else {
453+
fmt.Println("------------------------")
454+
}
455+
456+
if found == 0 {
457+
return nil, &sqlerr.Error{
458+
Code: "42703",
459+
Message: fmt.Sprintf("396: column \"%s\" does not exist", key),
460+
Location: location,
461+
}
462+
}
463+
if found > 1 {
464+
return nil, &sqlerr.Error{
465+
Code: "42703",
466+
Message: fmt.Sprintf("in same name column reference \"%s\" is ambiguous", key),
467+
Location: location,
468+
}
469+
}
470+
373471
default:
374472
fmt.Printf("unsupported reference type: %T", n)
375473
}

internal/endtoend/testdata/pattern_in_expr/mysql/go/db.go

+29
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/endtoend/testdata/pattern_in_expr/mysql/go/models.go

+17
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/endtoend/testdata/pattern_in_expr/mysql/go/query.sql.go

+127
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
CREATE TABLE foo (a text, b text);
2+
CREATE TABLE bar (a text, b text);
3+
/* name: FooByBarB :many */
4+
SELECT a, b from foo where foo.a in (select a from bar where bar.b = ?);
5+
6+
/* name: FooByList :many */
7+
SELECT a, b from foo where foo.a in (?, ?);
8+
9+
/* name: FooByNotList :many */
10+
SELECT a, b from foo where foo.a not in (?, ?);
11+
12+
/* name: FooByParamList :many */
13+
SELECT a, b from foo where ? in (foo.a, foo.b);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
{
2+
"version": "1",
3+
"packages": [
4+
{
5+
"engine": "mysql",
6+
"path": "go",
7+
"name": "querytest",
8+
"schema": "query.sql",
9+
"queries": "query.sql"
10+
}
11+
]
12+
}

0 commit comments

Comments
 (0)