Skip to content

Commit 3e51313

Browse files
committed
fix: ensure we are fully backwards compatible
1 parent 4a7f0df commit 3e51313

File tree

13 files changed

+136
-108
lines changed

13 files changed

+136
-108
lines changed

.gitignore

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
bin
2-
*.wasm
32

43
# Devenv
54
.envrc

internal/config.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ package python
44
type Config struct {
55
EmitAsync bool `json:"emit_async"` // Emits async code instead of sync
66
EmitExactTableNames bool `json:"emit_exact_table_names"`
7-
EmitGenerators bool `json:"emit_generators"` // Will we use generators or lists, defaults to true
7+
EmitGenerators bool `json:"emit_generators"` // Will we use generators or lists, defaults to false
88
EmitModule bool `json:"emit_module"` // If true emits functions in module, else wraps in a class.
99
EmitPydanticModels bool `json:"emit_pydantic_models"`
1010
EmitSyncQuerier bool `json:"emit_sync_querier"` // DEPRECATED ALIAS FOR: emit_type = 'class', emit_generators = True

internal/endtoend/endtoend_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,9 @@ func TestGenerate(t *testing.T) {
100100
cmd := exec.Command(sqlc, "diff")
101101
cmd.Dir = dir
102102
got, err := cmd.CombinedOutput()
103+
// TODO: We are diffing patches! Does this make sense and what should we provide to the end user?
103104
if diff := cmp.Diff(string(want), string(got)); diff != "" {
104-
t.Errorf("sqlc diff mismatch (-want +got):\n%s", diff)
105+
t.Errorf("sqlc diff mismatch (-want +got):\n%s", string(got))
105106
}
106107
if len(want) == 0 && err != nil {
107108
t.Error(err)

internal/endtoend/testdata/emit_pydantic_models/sqlc.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ plugins:
33
- name: py
44
wasm:
55
url: file://../../../../bin/sqlc-gen-python.wasm
6-
sha256: "a6c5d174c407007c3717eea36ff0882744346e6ba991f92f71d6ab2895204c0e"
6+
sha256: "c97fad53818679a948c68f3ffe84530d7ca4999f636d3f3d89202c6c08ee224d"
77
sql:
88
- schema: schema.sql
99
queries: query.sql

internal/endtoend/testdata/exec_result/sqlc.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ plugins:
33
- name: py
44
wasm:
55
url: file://../../../../bin/sqlc-gen-python.wasm
6-
sha256: "a6c5d174c407007c3717eea36ff0882744346e6ba991f92f71d6ab2895204c0e"
6+
sha256: "c97fad53818679a948c68f3ffe84530d7ca4999f636d3f3d89202c6c08ee224d"
77
sql:
88
- schema: schema.sql
99
queries: query.sql

internal/endtoend/testdata/exec_rows/sqlc.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ plugins:
33
- name: py
44
wasm:
55
url: file://../../../../bin/sqlc-gen-python.wasm
6-
sha256: "a6c5d174c407007c3717eea36ff0882744346e6ba991f92f71d6ab2895204c0e"
6+
sha256: "c97fad53818679a948c68f3ffe84530d7ca4999f636d3f3d89202c6c08ee224d"
77
sql:
88
- schema: schema.sql
99
queries: query.sql

internal/endtoend/testdata/inflection_exclude_table_names/sqlc.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ plugins:
33
- name: py
44
wasm:
55
url: file://../../../../bin/sqlc-gen-python.wasm
6-
sha256: "a6c5d174c407007c3717eea36ff0882744346e6ba991f92f71d6ab2895204c0e"
6+
sha256: "c97fad53818679a948c68f3ffe84530d7ca4999f636d3f3d89202c6c08ee224d"
77
sql:
88
- schema: schema.sql
99
queries: query.sql

internal/endtoend/testdata/query_parameter_limit_two/sqlc.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ plugins:
33
- name: py
44
wasm:
55
url: file://../../../../bin/sqlc-gen-python.wasm
6-
sha256: "a6c5d174c407007c3717eea36ff0882744346e6ba991f92f71d6ab2895204c0e"
6+
sha256: "c97fad53818679a948c68f3ffe84530d7ca4999f636d3f3d89202c6c08ee224d"
77
sql:
88
- schema: schema.sql
99
queries: query.sql

internal/endtoend/testdata/query_parameter_limit_undefined/sqlc.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ plugins:
33
- name: py
44
wasm:
55
url: file://../../../../bin/sqlc-gen-python.wasm
6-
sha256: "a6c5d174c407007c3717eea36ff0882744346e6ba991f92f71d6ab2895204c0e"
6+
sha256: "c97fad53818679a948c68f3ffe84530d7ca4999f636d3f3d89202c6c08ee224d"
77
sql:
88
- schema: schema.sql
99
queries: query.sql

internal/endtoend/testdata/query_parameter_limit_zero/sqlc.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ plugins:
33
- name: py
44
wasm:
55
url: file://../../../../bin/sqlc-gen-python.wasm
6-
sha256: "a6c5d174c407007c3717eea36ff0882744346e6ba991f92f71d6ab2895204c0e"
6+
sha256: "c97fad53818679a948c68f3ffe84530d7ca4999f636d3f3d89202c6c08ee224d"
77
sql:
88
- schema: schema.sql
99
queries: query.sql

internal/endtoend/testdata/query_parameter_no_limit/sqlc.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ plugins:
33
- name: py
44
wasm:
55
url: file://../../../../bin/sqlc-gen-python.wasm
6-
sha256: "a6c5d174c407007c3717eea36ff0882744346e6ba991f92f71d6ab2895204c0e"
6+
sha256: "c97fad53818679a948c68f3ffe84530d7ca4999f636d3f3d89202c6c08ee224d"
77
sql:
88
- schema: schema.sql
99
queries: query.sql

internal/gen.go

Lines changed: 119 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -845,7 +845,7 @@ func querierClassDef(name string, connectionAnnotation *pyast.Node) *pyast.Class
845845
Arg: "self",
846846
},
847847
{
848-
Arg: "connection",
848+
Arg: "conn",
849849
Annotation: connectionAnnotation,
850850
},
851851
},
@@ -855,9 +855,9 @@ func querierClassDef(name string, connectionAnnotation *pyast.Node) *pyast.Class
855855
Node: &pyast.Node_Assign{
856856
Assign: &pyast.Assign{
857857
Targets: []*pyast.Node{
858-
poet.Attribute(poet.Name("self"), "_connection"),
858+
poet.Attribute(poet.Name("self"), "_conn"),
859859
},
860-
Value: poet.Name("connection"),
860+
Value: poet.Name("conn"),
861861
},
862862
},
863863
},
@@ -869,80 +869,12 @@ func querierClassDef(name string, connectionAnnotation *pyast.Node) *pyast.Class
869869
}
870870
}
871871

872-
func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node {
873-
mod := moduleNode(ctx.SqlcVersion, source)
874-
std, pkg := i.queryImportSpecs(source)
875-
mod.Body = append(mod.Body, buildImportGroup(std), buildImportGroup(pkg))
876-
mod.Body = append(mod.Body, &pyast.Node{
877-
Node: &pyast.Node_ImportGroup{
878-
ImportGroup: &pyast.ImportGroup{
879-
Imports: []*pyast.Node{
880-
{
881-
Node: &pyast.Node_ImportFrom{
882-
ImportFrom: &pyast.ImportFrom{
883-
Module: ctx.C.Package,
884-
Names: []*pyast.Node{
885-
poet.Alias("models"),
886-
},
887-
},
888-
},
889-
},
890-
},
891-
},
892-
},
893-
})
894-
895-
for _, q := range ctx.Queries {
896-
if !ctx.OutputQuery(q.SourceName) {
897-
continue
898-
}
899-
queryText := fmt.Sprintf("-- name: %s \\\\%s\n%s\n", q.MethodName, q.Cmd, q.SQL)
900-
mod.Body = append(mod.Body, assignNode(q.ConstantName, poet.Constant(queryText)))
901-
902-
// Generate params structures
903-
for _, arg := range q.Args {
904-
if arg.EmitStruct() {
905-
var def *pyast.ClassDef
906-
if ctx.C.EmitPydanticModels {
907-
def = pydanticNode(arg.Struct.Name)
908-
} else {
909-
def = dataclassNode(arg.Struct.Name)
910-
}
911-
912-
// We need a copy as we want to make sure that nullable params are at the end of the dataclass
913-
fields := make([]Field, len(arg.Struct.Fields))
914-
copy(fields, arg.Struct.Fields)
915-
916-
// Place all nullable fields at the end and try to keep the original order as much as possible
917-
sort.SliceStable(fields, func(i int, j int) bool {
918-
return (fields[j].Type.IsNull && fields[i].Type.IsNull != fields[j].Type.IsNull) || i < j
919-
})
920-
921-
for _, f := range fields {
922-
def.Body = append(def.Body, fieldNode(f, true))
923-
}
924-
mod.Body = append(mod.Body, poet.Node(def))
925-
}
926-
}
927-
if q.Ret.EmitStruct() {
928-
var def *pyast.ClassDef
929-
if ctx.C.EmitPydanticModels {
930-
def = pydanticNode(q.Ret.Struct.Name)
931-
} else {
932-
def = dataclassNode(q.Ret.Struct.Name)
933-
}
934-
for _, f := range q.Ret.Struct.Fields {
935-
def.Body = append(def.Body, fieldNode(f, false))
936-
}
937-
mod.Body = append(mod.Body, poet.Node(def))
938-
}
939-
}
940-
872+
func buildQuerierClass(ctx *pyTmplCtx, isAsync bool) []*pyast.Node {
941873
functions := make([]*pyast.Node, 0, 10)
942874

943875
// Define some reused types based on async or sync code
944876
var connectionAnnotation *pyast.Node
945-
if ctx.C.EmitAsync {
877+
if isAsync {
946878
connectionAnnotation = typeRefNode("sqlalchemy", "ext", "asyncio", "AsyncConnection")
947879
} else {
948880
connectionAnnotation = typeRefNode("sqlalchemy", "engine", "Connection")
@@ -951,9 +883,9 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node {
951883
// We need to figure out how to access the SQLAlchemy connectionVar object
952884
var connectionVar *pyast.Node
953885
if ctx.C.EmitModule {
954-
connectionVar = poet.Name("connection")
886+
connectionVar = poet.Name("conn")
955887
} else {
956-
connectionVar = poet.Attribute(poet.Name("self"), "_connection")
888+
connectionVar = poet.Attribute(poet.Name("self"), "_conn")
957889
}
958890

959891
// We loop through all queries and build our query functions
@@ -968,7 +900,7 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node {
968900

969901
if ctx.C.EmitModule {
970902
f.Args.Args = append(f.Args.Args, &pyast.Arg{
971-
Arg: "connection",
903+
Arg: "conn",
972904
Annotation: connectionAnnotation,
973905
})
974906
} else {
@@ -980,7 +912,7 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node {
980912
q.AddArgs(f.Args)
981913

982914
exec := poet.Expr(connMethodNode(poet.Attribute(connectionVar, "execute"), q.ConstantName, q.ArgDictNode()))
983-
if ctx.C.EmitAsync {
915+
if isAsync {
984916
exec = poet.Await(exec)
985917
}
986918

@@ -1017,7 +949,7 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node {
1017949
f.Returns = subscriptNode("Optional", q.Ret.Annotation())
1018950
case ":many":
1019951
if ctx.C.EmitGenerators {
1020-
if ctx.C.EmitAsync {
952+
if isAsync {
1021953
// If we are using generators and async, we are switching to stream implementation
1022954
exec = poet.Await(connMethodNode(poet.Attribute(connectionVar, "stream"), q.ConstantName, q.ArgDictNode()))
1023955

@@ -1094,8 +1026,8 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node {
10941026
panic("unknown cmd " + q.Cmd)
10951027
}
10961028

1097-
// If we are emitting async code, we have to swap our sync func for an async one and fix the connection annotation.
1098-
if ctx.C.EmitAsync {
1029+
// If we are emitting async code, we have to swap our sync func for an async one and fix the conn annotation.
1030+
if isAsync {
10991031
functions = append(functions, poet.Node(&pyast.AsyncFunctionDef{
11001032
Name: f.Name,
11011033
Args: f.Args,
@@ -1107,13 +1039,115 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node {
11071039
}
11081040
}
11091041

1110-
// Lets see how to add all functions
1042+
return functions
1043+
}
1044+
1045+
func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node {
1046+
mod := moduleNode(ctx.SqlcVersion, source)
1047+
std, pkg := i.queryImportSpecs(source)
1048+
mod.Body = append(mod.Body, buildImportGroup(std), buildImportGroup(pkg))
1049+
mod.Body = append(mod.Body, &pyast.Node{
1050+
Node: &pyast.Node_ImportGroup{
1051+
ImportGroup: &pyast.ImportGroup{
1052+
Imports: []*pyast.Node{
1053+
{
1054+
Node: &pyast.Node_ImportFrom{
1055+
ImportFrom: &pyast.ImportFrom{
1056+
Module: ctx.C.Package,
1057+
Names: []*pyast.Node{
1058+
poet.Alias("models"),
1059+
},
1060+
},
1061+
},
1062+
},
1063+
},
1064+
},
1065+
},
1066+
})
1067+
1068+
for _, q := range ctx.Queries {
1069+
if !ctx.OutputQuery(q.SourceName) {
1070+
continue
1071+
}
1072+
queryText := fmt.Sprintf("-- name: %s \\\\%s\n%s\n", q.MethodName, q.Cmd, q.SQL)
1073+
mod.Body = append(mod.Body, assignNode(q.ConstantName, poet.Constant(queryText)))
1074+
1075+
// Generate params structures
1076+
for _, arg := range q.Args {
1077+
if arg.EmitStruct() {
1078+
var def *pyast.ClassDef
1079+
if ctx.C.EmitPydanticModels {
1080+
def = pydanticNode(arg.Struct.Name)
1081+
} else {
1082+
def = dataclassNode(arg.Struct.Name)
1083+
}
1084+
1085+
// We need a copy as we want to make sure that nullable params are at the end of the dataclass
1086+
fields := make([]Field, len(arg.Struct.Fields))
1087+
copy(fields, arg.Struct.Fields)
1088+
1089+
// Place all nullable fields at the end and try to keep the original order as much as possible
1090+
sort.SliceStable(fields, func(i int, j int) bool {
1091+
return (fields[j].Type.IsNull && fields[i].Type.IsNull != fields[j].Type.IsNull) || i < j
1092+
})
1093+
1094+
for _, f := range fields {
1095+
def.Body = append(def.Body, fieldNode(f, true))
1096+
}
1097+
mod.Body = append(mod.Body, poet.Node(def))
1098+
}
1099+
}
1100+
if q.Ret.EmitStruct() {
1101+
var def *pyast.ClassDef
1102+
if ctx.C.EmitPydanticModels {
1103+
def = pydanticNode(q.Ret.Struct.Name)
1104+
} else {
1105+
def = dataclassNode(q.Ret.Struct.Name)
1106+
}
1107+
for _, f := range q.Ret.Struct.Fields {
1108+
def.Body = append(def.Body, fieldNode(f, false))
1109+
}
1110+
mod.Body = append(mod.Body, poet.Node(def))
1111+
}
1112+
}
1113+
1114+
// Lets see how to add all functions, we can either add them to the module directly or from within a class.
11111115
if ctx.C.EmitModule {
1112-
mod.Body = append(mod.Body, functions...)
1116+
mod.Body = append(mod.Body, buildQuerierClass(ctx, ctx.C.EmitAsync)...)
11131117
} else {
1114-
cls := querierClassDef("Querier", connectionAnnotation)
1115-
cls.Body = append(cls.Body, functions...)
1116-
mod.Body = append(mod.Body, poet.Node(cls))
1118+
asyncConnectionAnnotation := typeRefNode("sqlalchemy", "ext", "asyncio", "AsyncConnection")
1119+
syncConnectionAnnotation := typeRefNode("sqlalchemy", "engine", "Connection")
1120+
1121+
// NOTE: For backwards compatibility we support generating multiple classes, but this is definitely suboptimal.
1122+
// It is much better to use the `emit_async: bool` config to select what type to emit
1123+
if ctx.C.EmitAsyncQuerier || ctx.C.EmitSyncQuerier {
1124+
1125+
// When using these backwards compatible settings we force behavior!
1126+
ctx.C.EmitModule = false
1127+
ctx.C.EmitGenerators = true
1128+
1129+
if ctx.C.EmitSyncQuerier {
1130+
cls := querierClassDef("Querier", syncConnectionAnnotation)
1131+
cls.Body = append(cls.Body, buildQuerierClass(ctx, false)...)
1132+
mod.Body = append(mod.Body, poet.Node(cls))
1133+
}
1134+
if ctx.C.EmitAsyncQuerier {
1135+
cls := querierClassDef("AsyncQuerier", asyncConnectionAnnotation)
1136+
cls.Body = append(cls.Body, buildQuerierClass(ctx, true)...)
1137+
mod.Body = append(mod.Body, poet.Node(cls))
1138+
}
1139+
} else {
1140+
var connectionAnnotation *pyast.Node
1141+
if ctx.C.EmitAsync {
1142+
connectionAnnotation = asyncConnectionAnnotation
1143+
} else {
1144+
connectionAnnotation = syncConnectionAnnotation
1145+
}
1146+
1147+
cls := querierClassDef("Querier", connectionAnnotation)
1148+
cls.Body = append(cls.Body, buildQuerierClass(ctx, ctx.C.EmitAsync)...)
1149+
mod.Body = append(mod.Body, poet.Node(cls))
1150+
}
11171151
}
11181152

11191153
return poet.Node(mod)
@@ -1150,14 +1184,6 @@ func Generate(_ context.Context, req *plugin.GenerateRequest) (*plugin.GenerateR
11501184
}
11511185
}
11521186

1153-
// TODO: Remove when when we drop support for deprecated EmitSyncQuerier and EmitAsyncQuerier options
1154-
if conf.EmitAsyncQuerier || conf.EmitSyncQuerier {
1155-
conf.EmitModule = false
1156-
conf.EmitGenerators = true
1157-
conf.EmitAsync = conf.EmitAsyncQuerier
1158-
// TODO/NOTE: We now have a breaking change because we emit only one flavor. What do we want to do?
1159-
}
1160-
11611187
enums := buildEnums(req)
11621188
models := buildModels(conf, req)
11631189
queries, err := buildQueries(conf, req, models)

0 commit comments

Comments
 (0)