@@ -845,7 +845,7 @@ func querierClassDef(name string, connectionAnnotation *pyast.Node) *pyast.Class
845
845
Arg : "self" ,
846
846
},
847
847
{
848
- Arg : "connection " ,
848
+ Arg : "conn " ,
849
849
Annotation : connectionAnnotation ,
850
850
},
851
851
},
@@ -855,9 +855,9 @@ func querierClassDef(name string, connectionAnnotation *pyast.Node) *pyast.Class
855
855
Node : & pyast.Node_Assign {
856
856
Assign : & pyast.Assign {
857
857
Targets : []* pyast.Node {
858
- poet .Attribute (poet .Name ("self" ), "_connection " ),
858
+ poet .Attribute (poet .Name ("self" ), "_conn " ),
859
859
},
860
- Value : poet .Name ("connection " ),
860
+ Value : poet .Name ("conn " ),
861
861
},
862
862
},
863
863
},
@@ -869,80 +869,12 @@ func querierClassDef(name string, connectionAnnotation *pyast.Node) *pyast.Class
869
869
}
870
870
}
871
871
872
- func buildQueryTree (ctx * pyTmplCtx , i * importer , source string ) * pyast.Node {
873
- mod := moduleNode (ctx .SqlcVersion , source )
874
- std , pkg := i .queryImportSpecs (source )
875
- mod .Body = append (mod .Body , buildImportGroup (std ), buildImportGroup (pkg ))
876
- mod .Body = append (mod .Body , & pyast.Node {
877
- Node : & pyast.Node_ImportGroup {
878
- ImportGroup : & pyast.ImportGroup {
879
- Imports : []* pyast.Node {
880
- {
881
- Node : & pyast.Node_ImportFrom {
882
- ImportFrom : & pyast.ImportFrom {
883
- Module : ctx .C .Package ,
884
- Names : []* pyast.Node {
885
- poet .Alias ("models" ),
886
- },
887
- },
888
- },
889
- },
890
- },
891
- },
892
- },
893
- })
894
-
895
- for _ , q := range ctx .Queries {
896
- if ! ctx .OutputQuery (q .SourceName ) {
897
- continue
898
- }
899
- queryText := fmt .Sprintf ("-- name: %s \\ \\ %s\n %s\n " , q .MethodName , q .Cmd , q .SQL )
900
- mod .Body = append (mod .Body , assignNode (q .ConstantName , poet .Constant (queryText )))
901
-
902
- // Generate params structures
903
- for _ , arg := range q .Args {
904
- if arg .EmitStruct () {
905
- var def * pyast.ClassDef
906
- if ctx .C .EmitPydanticModels {
907
- def = pydanticNode (arg .Struct .Name )
908
- } else {
909
- def = dataclassNode (arg .Struct .Name )
910
- }
911
-
912
- // We need a copy as we want to make sure that nullable params are at the end of the dataclass
913
- fields := make ([]Field , len (arg .Struct .Fields ))
914
- copy (fields , arg .Struct .Fields )
915
-
916
- // Place all nullable fields at the end and try to keep the original order as much as possible
917
- sort .SliceStable (fields , func (i int , j int ) bool {
918
- return (fields [j ].Type .IsNull && fields [i ].Type .IsNull != fields [j ].Type .IsNull ) || i < j
919
- })
920
-
921
- for _ , f := range fields {
922
- def .Body = append (def .Body , fieldNode (f , true ))
923
- }
924
- mod .Body = append (mod .Body , poet .Node (def ))
925
- }
926
- }
927
- if q .Ret .EmitStruct () {
928
- var def * pyast.ClassDef
929
- if ctx .C .EmitPydanticModels {
930
- def = pydanticNode (q .Ret .Struct .Name )
931
- } else {
932
- def = dataclassNode (q .Ret .Struct .Name )
933
- }
934
- for _ , f := range q .Ret .Struct .Fields {
935
- def .Body = append (def .Body , fieldNode (f , false ))
936
- }
937
- mod .Body = append (mod .Body , poet .Node (def ))
938
- }
939
- }
940
-
872
+ func buildQuerierClass (ctx * pyTmplCtx , isAsync bool ) []* pyast.Node {
941
873
functions := make ([]* pyast.Node , 0 , 10 )
942
874
943
875
// Define some reused types based on async or sync code
944
876
var connectionAnnotation * pyast.Node
945
- if ctx . C . EmitAsync {
877
+ if isAsync {
946
878
connectionAnnotation = typeRefNode ("sqlalchemy" , "ext" , "asyncio" , "AsyncConnection" )
947
879
} else {
948
880
connectionAnnotation = typeRefNode ("sqlalchemy" , "engine" , "Connection" )
@@ -951,9 +883,9 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node {
951
883
// We need to figure out how to access the SQLAlchemy connectionVar object
952
884
var connectionVar * pyast.Node
953
885
if ctx .C .EmitModule {
954
- connectionVar = poet .Name ("connection " )
886
+ connectionVar = poet .Name ("conn " )
955
887
} else {
956
- connectionVar = poet .Attribute (poet .Name ("self" ), "_connection " )
888
+ connectionVar = poet .Attribute (poet .Name ("self" ), "_conn " )
957
889
}
958
890
959
891
// We loop through all queries and build our query functions
@@ -968,7 +900,7 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node {
968
900
969
901
if ctx .C .EmitModule {
970
902
f .Args .Args = append (f .Args .Args , & pyast.Arg {
971
- Arg : "connection " ,
903
+ Arg : "conn " ,
972
904
Annotation : connectionAnnotation ,
973
905
})
974
906
} else {
@@ -980,7 +912,7 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node {
980
912
q .AddArgs (f .Args )
981
913
982
914
exec := poet .Expr (connMethodNode (poet .Attribute (connectionVar , "execute" ), q .ConstantName , q .ArgDictNode ()))
983
- if ctx . C . EmitAsync {
915
+ if isAsync {
984
916
exec = poet .Await (exec )
985
917
}
986
918
@@ -1017,7 +949,7 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node {
1017
949
f .Returns = subscriptNode ("Optional" , q .Ret .Annotation ())
1018
950
case ":many" :
1019
951
if ctx .C .EmitGenerators {
1020
- if ctx . C . EmitAsync {
952
+ if isAsync {
1021
953
// If we are using generators and async, we are switching to stream implementation
1022
954
exec = poet .Await (connMethodNode (poet .Attribute (connectionVar , "stream" ), q .ConstantName , q .ArgDictNode ()))
1023
955
@@ -1094,8 +1026,8 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node {
1094
1026
panic ("unknown cmd " + q .Cmd )
1095
1027
}
1096
1028
1097
- // If we are emitting async code, we have to swap our sync func for an async one and fix the connection annotation.
1098
- if ctx . C . EmitAsync {
1029
+ // If we are emitting async code, we have to swap our sync func for an async one and fix the conn annotation.
1030
+ if isAsync {
1099
1031
functions = append (functions , poet .Node (& pyast.AsyncFunctionDef {
1100
1032
Name : f .Name ,
1101
1033
Args : f .Args ,
@@ -1107,13 +1039,115 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node {
1107
1039
}
1108
1040
}
1109
1041
1110
- // Lets see how to add all functions
1042
+ return functions
1043
+ }
1044
+
1045
+ func buildQueryTree (ctx * pyTmplCtx , i * importer , source string ) * pyast.Node {
1046
+ mod := moduleNode (ctx .SqlcVersion , source )
1047
+ std , pkg := i .queryImportSpecs (source )
1048
+ mod .Body = append (mod .Body , buildImportGroup (std ), buildImportGroup (pkg ))
1049
+ mod .Body = append (mod .Body , & pyast.Node {
1050
+ Node : & pyast.Node_ImportGroup {
1051
+ ImportGroup : & pyast.ImportGroup {
1052
+ Imports : []* pyast.Node {
1053
+ {
1054
+ Node : & pyast.Node_ImportFrom {
1055
+ ImportFrom : & pyast.ImportFrom {
1056
+ Module : ctx .C .Package ,
1057
+ Names : []* pyast.Node {
1058
+ poet .Alias ("models" ),
1059
+ },
1060
+ },
1061
+ },
1062
+ },
1063
+ },
1064
+ },
1065
+ },
1066
+ })
1067
+
1068
+ for _ , q := range ctx .Queries {
1069
+ if ! ctx .OutputQuery (q .SourceName ) {
1070
+ continue
1071
+ }
1072
+ queryText := fmt .Sprintf ("-- name: %s \\ \\ %s\n %s\n " , q .MethodName , q .Cmd , q .SQL )
1073
+ mod .Body = append (mod .Body , assignNode (q .ConstantName , poet .Constant (queryText )))
1074
+
1075
+ // Generate params structures
1076
+ for _ , arg := range q .Args {
1077
+ if arg .EmitStruct () {
1078
+ var def * pyast.ClassDef
1079
+ if ctx .C .EmitPydanticModels {
1080
+ def = pydanticNode (arg .Struct .Name )
1081
+ } else {
1082
+ def = dataclassNode (arg .Struct .Name )
1083
+ }
1084
+
1085
+ // We need a copy as we want to make sure that nullable params are at the end of the dataclass
1086
+ fields := make ([]Field , len (arg .Struct .Fields ))
1087
+ copy (fields , arg .Struct .Fields )
1088
+
1089
+ // Place all nullable fields at the end and try to keep the original order as much as possible
1090
+ sort .SliceStable (fields , func (i int , j int ) bool {
1091
+ return (fields [j ].Type .IsNull && fields [i ].Type .IsNull != fields [j ].Type .IsNull ) || i < j
1092
+ })
1093
+
1094
+ for _ , f := range fields {
1095
+ def .Body = append (def .Body , fieldNode (f , true ))
1096
+ }
1097
+ mod .Body = append (mod .Body , poet .Node (def ))
1098
+ }
1099
+ }
1100
+ if q .Ret .EmitStruct () {
1101
+ var def * pyast.ClassDef
1102
+ if ctx .C .EmitPydanticModels {
1103
+ def = pydanticNode (q .Ret .Struct .Name )
1104
+ } else {
1105
+ def = dataclassNode (q .Ret .Struct .Name )
1106
+ }
1107
+ for _ , f := range q .Ret .Struct .Fields {
1108
+ def .Body = append (def .Body , fieldNode (f , false ))
1109
+ }
1110
+ mod .Body = append (mod .Body , poet .Node (def ))
1111
+ }
1112
+ }
1113
+
1114
+ // Lets see how to add all functions, we can either add them to the module directly or from within a class.
1111
1115
if ctx .C .EmitModule {
1112
- mod .Body = append (mod .Body , functions ... )
1116
+ mod .Body = append (mod .Body , buildQuerierClass ( ctx , ctx . C . EmitAsync ) ... )
1113
1117
} else {
1114
- cls := querierClassDef ("Querier" , connectionAnnotation )
1115
- cls .Body = append (cls .Body , functions ... )
1116
- mod .Body = append (mod .Body , poet .Node (cls ))
1118
+ asyncConnectionAnnotation := typeRefNode ("sqlalchemy" , "ext" , "asyncio" , "AsyncConnection" )
1119
+ syncConnectionAnnotation := typeRefNode ("sqlalchemy" , "engine" , "Connection" )
1120
+
1121
+ // NOTE: For backwards compatibility we support generating multiple classes, but this is definitely suboptimal.
1122
+ // It is much better to use the `emit_async: bool` config to select what type to emit
1123
+ if ctx .C .EmitAsyncQuerier || ctx .C .EmitSyncQuerier {
1124
+
1125
+ // When using these backwards compatible settings we force behavior!
1126
+ ctx .C .EmitModule = false
1127
+ ctx .C .EmitGenerators = true
1128
+
1129
+ if ctx .C .EmitSyncQuerier {
1130
+ cls := querierClassDef ("Querier" , syncConnectionAnnotation )
1131
+ cls .Body = append (cls .Body , buildQuerierClass (ctx , false )... )
1132
+ mod .Body = append (mod .Body , poet .Node (cls ))
1133
+ }
1134
+ if ctx .C .EmitAsyncQuerier {
1135
+ cls := querierClassDef ("AsyncQuerier" , asyncConnectionAnnotation )
1136
+ cls .Body = append (cls .Body , buildQuerierClass (ctx , true )... )
1137
+ mod .Body = append (mod .Body , poet .Node (cls ))
1138
+ }
1139
+ } else {
1140
+ var connectionAnnotation * pyast.Node
1141
+ if ctx .C .EmitAsync {
1142
+ connectionAnnotation = asyncConnectionAnnotation
1143
+ } else {
1144
+ connectionAnnotation = syncConnectionAnnotation
1145
+ }
1146
+
1147
+ cls := querierClassDef ("Querier" , connectionAnnotation )
1148
+ cls .Body = append (cls .Body , buildQuerierClass (ctx , ctx .C .EmitAsync )... )
1149
+ mod .Body = append (mod .Body , poet .Node (cls ))
1150
+ }
1117
1151
}
1118
1152
1119
1153
return poet .Node (mod )
@@ -1150,14 +1184,6 @@ func Generate(_ context.Context, req *plugin.GenerateRequest) (*plugin.GenerateR
1150
1184
}
1151
1185
}
1152
1186
1153
- // TODO: Remove when when we drop support for deprecated EmitSyncQuerier and EmitAsyncQuerier options
1154
- if conf .EmitAsyncQuerier || conf .EmitSyncQuerier {
1155
- conf .EmitModule = false
1156
- conf .EmitGenerators = true
1157
- conf .EmitAsync = conf .EmitAsyncQuerier
1158
- // TODO/NOTE: We now have a breaking change because we emit only one flavor. What do we want to do?
1159
- }
1160
-
1161
1187
enums := buildEnums (req )
1162
1188
models := buildModels (conf , req )
1163
1189
queries , err := buildQueries (conf , req , models )
0 commit comments