diff --git a/docs/reference/query-annotations.md b/docs/reference/query-annotations.md index 15d2128cd0..cb0b299310 100644 --- a/docs/reference/query-annotations.md +++ b/docs/reference/query-annotations.md @@ -94,3 +94,110 @@ func (q *Queries) GetAuthor(ctx context.Context, id int64) (Author, error) { // ... } ``` + +## `:batchexec` + +__NOTE: This command only works with PostgreSQL using the `pgx` driver and outputting Go code.__ + +The generated method will return a batch object. The batch object will have +the following methods: +- `Exec`, that takes a `func(int, error)` parameter, +- `Close`, to close the batch operation early. + +```sql +-- name: DeleteBook :batchexec +DELETE FROM books +WHERE book_id = $1; +``` + +```go +type DeleteBookBatchResults struct { + br pgx.BatchResults + ind int +} +func (q *Queries) DeleteBook(ctx context.Context, bookID []int32) *DeleteBookBatchResults { + //... +} +func (b *DeleteBookBatchResults) Exec(f func(int, error)) { + //... +} +func (b *DeleteBookBatchResults) Close() error { + //... +} +``` + +## `:batchmany` + +__NOTE: This command only works with PostgreSQL using the `pgx` driver and outputting Go code.__ + +The generated method will return a batch object. The batch object will have +the following methods: +- `Query`, that takes a `func(int, []T, error)` parameter, where `T` is your query's return type +- `Close`, to close the batch operation early. + +```sql +-- name: BooksByTitleYear :batchmany +SELECT * FROM books +WHERE title = $1 AND year = $2; +``` + +```go +type BooksByTitleYearBatchResults struct { + br pgx.BatchResults + ind int +} +type BooksByTitleYearParams struct { + Title string `json:"title"` + Year int32 `json:"year"` +} +func (q *Queries) BooksByTitleYear(ctx context.Context, arg []BooksByTitleYearParams) *BooksByTitleYearBatchResults { + //... +} +func (b *BooksByTitleYearBatchResults) Query(f func(int, []Book, error)) { + //... +} +func (b *BooksByTitleYearBatchResults) Close() error { + //... 
+} +``` + +## `:batchone` + +__NOTE: This command only works with PostgreSQL using the `pgx` driver and outputting Go code.__ + +The generated method will return a batch object. The batch object will have +the following methods: +- `QueryRow`, that takes a `func(int, T, error)` parameter, where `T` is your query's return type +- `Close`, to close the batch operation early. + +```sql +-- name: CreateBook :batchone +INSERT INTO books ( + author_id, + isbn +) VALUES ( + $1, + $2 +) +RETURNING book_id, author_id, isbn +``` + +```go +type CreateBookBatchResults struct { + br pgx.BatchResults + ind int +} +type CreateBookParams struct { + AuthorID int32 `json:"author_id"` + Isbn string `json:"isbn"` +} +func (q *Queries) CreateBook(ctx context.Context, arg []CreateBookParams) *CreateBookBatchResults { + //... +} +func (b *CreateBookBatchResults) QueryRow(f func(int, Book, error)) { + //... +} +func (b *CreateBookBatchResults) Close() error { + //... +} +``` diff --git a/examples/batch/postgresql/batch.go b/examples/batch/postgresql/batch.go new file mode 100644 index 0000000000..db8518feaf --- /dev/null +++ b/examples/batch/postgresql/batch.go @@ -0,0 +1,237 @@ +// Code generated by sqlc. DO NOT EDIT. +// source: batch.go +package batch + +import ( + "context" + "time" + + "github.com/jackc/pgx/v4" +) + +const booksByYear = `-- name: BooksByYear :batchmany +SELECT book_id, author_id, isbn, book_type, title, year, available, tags FROM books +WHERE year = $1 +` + +type BooksByYearBatchResults struct { + br pgx.BatchResults + ind int +} + +func (q *Queries) BooksByYear(ctx context.Context, year []int32) *BooksByYearBatchResults { + batch := &pgx.Batch{} + for _, a := range year { + vals := []interface{}{ + a, + } + batch.Queue(booksByYear, vals...) 
+ } + br := q.db.SendBatch(ctx, batch) + return &BooksByYearBatchResults{br, 0} +} + +func (b *BooksByYearBatchResults) Query(f func(int, []Book, error)) { + for { + rows, err := b.br.Query() + if err != nil && (err.Error() == "no result" || err.Error() == "batch already closed") { + break + } + defer rows.Close() + var items []Book + for rows.Next() { + var i Book + if err := rows.Scan( + &i.BookID, + &i.AuthorID, + &i.Isbn, + &i.BookType, + &i.Title, + &i.Year, + &i.Available, + &i.Tags, + ); err != nil { + break + } + items = append(items, i) + } + + if f != nil { + f(b.ind, items, rows.Err()) + } + b.ind++ + } +} + +func (b *BooksByYearBatchResults) Close() error { + return b.br.Close() +} + +const createBook = `-- name: CreateBook :batchone +INSERT INTO books ( + author_id, + isbn, + book_type, + title, + year, + available, + tags +) VALUES ( + $1, + $2, + $3, + $4, + $5, + $6, + $7 +) +RETURNING book_id, author_id, isbn, book_type, title, year, available, tags +` + +type CreateBookBatchResults struct { + br pgx.BatchResults + ind int +} + +type CreateBookParams struct { + AuthorID int32 `json:"author_id"` + Isbn string `json:"isbn"` + BookType BookType `json:"book_type"` + Title string `json:"title"` + Year int32 `json:"year"` + Available time.Time `json:"available"` + Tags []string `json:"tags"` +} + +func (q *Queries) CreateBook(ctx context.Context, arg []CreateBookParams) *CreateBookBatchResults { + batch := &pgx.Batch{} + for _, a := range arg { + vals := []interface{}{ + a.AuthorID, + a.Isbn, + a.BookType, + a.Title, + a.Year, + a.Available, + a.Tags, + } + batch.Queue(createBook, vals...) 
+ } + br := q.db.SendBatch(ctx, batch) + return &CreateBookBatchResults{br, 0} +} + +func (b *CreateBookBatchResults) QueryRow(f func(int, Book, error)) { + for { + row := b.br.QueryRow() + var i Book + err := row.Scan( + &i.BookID, + &i.AuthorID, + &i.Isbn, + &i.BookType, + &i.Title, + &i.Year, + &i.Available, + &i.Tags, + ) + if err != nil && (err.Error() == "no result" || err.Error() == "batch already closed") { + break + } + if f != nil { + f(b.ind, i, err) + } + b.ind++ + } +} + +func (b *CreateBookBatchResults) Close() error { + return b.br.Close() +} + +const deleteBook = `-- name: DeleteBook :batchexec +DELETE FROM books +WHERE book_id = $1 +` + +type DeleteBookBatchResults struct { + br pgx.BatchResults + ind int +} + +func (q *Queries) DeleteBook(ctx context.Context, bookID []int32) *DeleteBookBatchResults { + batch := &pgx.Batch{} + for _, a := range bookID { + vals := []interface{}{ + a, + } + batch.Queue(deleteBook, vals...) + } + br := q.db.SendBatch(ctx, batch) + return &DeleteBookBatchResults{br, 0} +} + +func (b *DeleteBookBatchResults) Exec(f func(int, error)) { + for { + _, err := b.br.Exec() + if err != nil && (err.Error() == "no result" || err.Error() == "batch already closed") { + break + } + if f != nil { + f(b.ind, err) + } + b.ind++ + } +} + +func (b *DeleteBookBatchResults) Close() error { + return b.br.Close() +} + +const updateBook = `-- name: UpdateBook :batchexec +UPDATE books +SET title = $1, tags = $2 +WHERE book_id = $3 +` + +type UpdateBookBatchResults struct { + br pgx.BatchResults + ind int +} + +type UpdateBookParams struct { + Title string `json:"title"` + Tags []string `json:"tags"` + BookID int32 `json:"book_id"` +} + +func (q *Queries) UpdateBook(ctx context.Context, arg []UpdateBookParams) *UpdateBookBatchResults { + batch := &pgx.Batch{} + for _, a := range arg { + vals := []interface{}{ + a.Title, + a.Tags, + a.BookID, + } + batch.Queue(updateBook, vals...) 
+ } + br := q.db.SendBatch(ctx, batch) + return &UpdateBookBatchResults{br, 0} +} + +func (b *UpdateBookBatchResults) Exec(f func(int, error)) { + for { + _, err := b.br.Exec() + if err != nil && (err.Error() == "no result" || err.Error() == "batch already closed") { + break + } + if f != nil { + f(b.ind, err) + } + b.ind++ + } +} + +func (b *UpdateBookBatchResults) Close() error { + return b.br.Close() +} diff --git a/examples/batch/postgresql/db.go b/examples/batch/postgresql/db.go new file mode 100644 index 0000000000..e5e1ad5bbe --- /dev/null +++ b/examples/batch/postgresql/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. + +package batch + +import ( + "context" + + "github.com/jackc/pgconn" + "github.com/jackc/pgx/v4" +) + +type DBTX interface { + Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error) + Query(context.Context, string, ...interface{}) (pgx.Rows, error) + QueryRow(context.Context, string, ...interface{}) pgx.Row + SendBatch(context.Context, *pgx.Batch) pgx.BatchResults +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx pgx.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/examples/batch/postgresql/db_test.go b/examples/batch/postgresql/db_test.go new file mode 100644 index 0000000000..6293831efe --- /dev/null +++ b/examples/batch/postgresql/db_test.go @@ -0,0 +1,140 @@ +//go:build examples +// +build examples + +package batch + +import ( + "context" + "testing" + "time" + + "github.com/kyleconroy/sqlc/internal/sqltest" +) + +func TestBatchBooks(t *testing.T) { + db, cleanup := sqltest.PostgreSQLPgx(t, []string{"schema.sql"}) + defer cleanup() + + ctx := context.Background() + dq := New(db) + + // create an author + a, err := dq.CreateAuthor(ctx, "Unknown Master") + if err != nil { + t.Fatal(err) + } + + now := time.Now() + + // batch insert new books + newBooksParams := []CreateBookParams{ + { + AuthorID: a.AuthorID, + 
Isbn: "1", + Title: "my book title", + BookType: BookTypeFICTION, + Year: 2016, + Available: now, + Tags: []string{}, + }, + { + AuthorID: a.AuthorID, + Isbn: "2", + Title: "the second book", + BookType: BookTypeFICTION, + Year: 2016, + Available: now, + Tags: []string{"cool", "unique"}, + }, + { + AuthorID: a.AuthorID, + Isbn: "3", + Title: "the third book", + BookType: BookTypeFICTION, + Year: 2001, + Available: now, + Tags: []string{"cool"}, + }, + { + AuthorID: a.AuthorID, + Isbn: "4", + Title: "4th place finisher", + BookType: BookTypeNONFICTION, + Year: 2011, + Available: now, + Tags: []string{"other"}, + }, + } + newBooks := make([]Book, len(newBooksParams)) + var cnt int + dq.CreateBook(ctx, newBooksParams).QueryRow(func(i int, b Book, err error) { + if err != nil { + t.Fatalf("failed inserting book (%s): %s", b.Title, err) + } + newBooks[i] = b + cnt = i + }) + // first i was 0, so add 1 + cnt++ + numBooksExpected := len(newBooks) + if cnt != numBooksExpected { + t.Fatalf("expected to insert %d books; got %d", numBooksExpected, cnt) + } + + // batch update the title and tags + updateBooksParams := []UpdateBookParams{ + { + BookID: newBooks[1].BookID, + Title: "changed second title", + Tags: []string{"cool", "disastor"}, + }, + } + dq.UpdateBook(ctx, updateBooksParams).Exec(func(i int, err error) { + if err != nil { + t.Fatalf("error updating book %d: %s", updateBooksParams[i].BookID, err) + } + }) + + // batch many to retrieve books by year + selectBooksByTitleYearParams := []int32{2001, 2016} + var books0 []Book + dq.BooksByYear(ctx, selectBooksByTitleYearParams).Query(func(i int, books []Book, err error) { + if err != nil { + t.Fatal(err) + } + t.Logf("num books for %d: %d", selectBooksByTitleYearParams[i], len(books)) + books0 = append(books0, books...) 
+ }) + + for _, book := range books0 { + t.Logf("Book %d (%s): %s available: %s\n", book.BookID, book.BookType, book.Title, book.Available.Format(time.RFC822Z)) + author, err := dq.GetAuthor(ctx, book.AuthorID) + if err != nil { + t.Fatal(err) + } + t.Logf("Book %d author: %s\n", book.BookID, author.Name) + } + + // batch delete books + deleteBooksParams := make([]int32, len(newBooks)) + for i, book := range newBooks { + deleteBooksParams[i] = book.BookID + } + batchDelete := dq.DeleteBook(ctx, deleteBooksParams) + numDeletesProcessed := 0 + batchDelete.Exec(func(i int, err error) { + numDeletesProcessed++ + if err != nil { + t.Fatalf("error deleting book %d: %s", deleteBooksParams[i], err) + } + if i == len(deleteBooksParams)-3 { + // close batch operation before processing all errors from delete operation + if err := batchDelete.Close(); err != nil { + t.Fatalf("failed to close batch operation: %s", err) + } + } + }) + if numDeletesProcessed != 2 { + t.Fatalf("expected Close to short-circuit record processing (expected 2; got %d)", numDeletesProcessed) + } +} diff --git a/examples/batch/postgresql/models.go b/examples/batch/postgresql/models.go new file mode 100644 index 0000000000..07c2297605 --- /dev/null +++ b/examples/batch/postgresql/models.go @@ -0,0 +1,43 @@ +// Code generated by sqlc. DO NOT EDIT. 
+ +package batch + +import ( + "fmt" + "time" +) + +type BookType string + +const ( + BookTypeFICTION BookType = "FICTION" + BookTypeNONFICTION BookType = "NONFICTION" +) + +func (e *BookType) Scan(src interface{}) error { + switch s := src.(type) { + case []byte: + *e = BookType(s) + case string: + *e = BookType(s) + default: + return fmt.Errorf("unsupported scan type for BookType: %T", src) + } + return nil +} + +type Author struct { + AuthorID int32 `json:"author_id"` + Name string `json:"name"` +} + +type Book struct { + BookID int32 `json:"book_id"` + AuthorID int32 `json:"author_id"` + Isbn string `json:"isbn"` + BookType BookType `json:"book_type"` + Title string `json:"title"` + Year int32 `json:"year"` + Available time.Time `json:"available"` + Tags []string `json:"tags"` +} diff --git a/examples/batch/postgresql/querier.go b/examples/batch/postgresql/querier.go new file mode 100644 index 0000000000..96ee010456 --- /dev/null +++ b/examples/batch/postgresql/querier.go @@ -0,0 +1,18 @@ +// Code generated by sqlc. DO NOT EDIT. 
+ +package batch + +import ( + "context" +) + +type Querier interface { + BooksByYear(ctx context.Context, year []int32) *BooksByYearBatchResults + CreateAuthor(ctx context.Context, name string) (Author, error) + CreateBook(ctx context.Context, arg []CreateBookParams) *CreateBookBatchResults + DeleteBook(ctx context.Context, bookID []int32) *DeleteBookBatchResults + GetAuthor(ctx context.Context, authorID int32) (Author, error) + UpdateBook(ctx context.Context, arg []UpdateBookParams) *UpdateBookBatchResults +} + +var _ Querier = (*Queries)(nil) diff --git a/examples/batch/postgresql/query.sql b/examples/batch/postgresql/query.sql new file mode 100644 index 0000000000..e8d53d2b9e --- /dev/null +++ b/examples/batch/postgresql/query.sql @@ -0,0 +1,40 @@ +-- name: GetAuthor :one +SELECT * FROM authors +WHERE author_id = $1; + +-- name: DeleteBook :batchexec +DELETE FROM books +WHERE book_id = $1; + +-- name: BooksByYear :batchmany +SELECT * FROM books +WHERE year = $1; + +-- name: CreateAuthor :one +INSERT INTO authors (name) VALUES ($1) +RETURNING *; + +-- name: CreateBook :batchone +INSERT INTO books ( + author_id, + isbn, + book_type, + title, + year, + available, + tags +) VALUES ( + $1, + $2, + $3, + $4, + $5, + $6, + $7 +) +RETURNING *; + +-- name: UpdateBook :batchexec +UPDATE books +SET title = $1, tags = $2 +WHERE book_id = $3; diff --git a/examples/batch/postgresql/query.sql.go b/examples/batch/postgresql/query.sql.go new file mode 100644 index 0000000000..8e5f2dd2d4 --- /dev/null +++ b/examples/batch/postgresql/query.sql.go @@ -0,0 +1,32 @@ +// Code generated by sqlc. DO NOT EDIT. 
+// source: query.sql + +package batch + +import ( + "context" +) + +const createAuthor = `-- name: CreateAuthor :one +INSERT INTO authors (name) VALUES ($1) +RETURNING author_id, name +` + +func (q *Queries) CreateAuthor(ctx context.Context, name string) (Author, error) { + row := q.db.QueryRow(ctx, createAuthor, name) + var i Author + err := row.Scan(&i.AuthorID, &i.Name) + return i, err +} + +const getAuthor = `-- name: GetAuthor :one +SELECT author_id, name FROM authors +WHERE author_id = $1 +` + +func (q *Queries) GetAuthor(ctx context.Context, authorID int32) (Author, error) { + row := q.db.QueryRow(ctx, getAuthor, authorID) + var i Author + err := row.Scan(&i.AuthorID, &i.Name) + return i, err +} diff --git a/examples/batch/postgresql/schema.sql b/examples/batch/postgresql/schema.sql new file mode 100644 index 0000000000..700b6658e7 --- /dev/null +++ b/examples/batch/postgresql/schema.sql @@ -0,0 +1,20 @@ +CREATE TABLE authors ( + author_id SERIAL PRIMARY KEY, + name text NOT NULL DEFAULT '' +); + +CREATE TYPE book_type AS ENUM ( + 'FICTION', + 'NONFICTION' +); + +CREATE TABLE books ( + book_id SERIAL PRIMARY KEY, + author_id integer NOT NULL REFERENCES authors(author_id), + isbn text NOT NULL DEFAULT '' UNIQUE, + book_type book_type NOT NULL DEFAULT 'FICTION', + title text NOT NULL DEFAULT '', + year integer NOT NULL DEFAULT 2000, + available timestamp with time zone NOT NULL DEFAULT 'NOW()', + tags varchar[] NOT NULL DEFAULT '{}' +); diff --git a/examples/batch/sqlc.json b/examples/batch/sqlc.json new file mode 100644 index 0000000000..f56bad9b5d --- /dev/null +++ b/examples/batch/sqlc.json @@ -0,0 +1,16 @@ +{ + "version": "1", + "packages": [ + { + "path": "postgresql", + "name": "batch", + "schema": "postgresql/schema.sql", + "queries": "postgresql/query.sql", + "engine": "postgresql", + "sql_package": "pgx/v4", + "emit_json_tags": true, + "emit_prepared_queries": true, + "emit_interface": true + } + ] +} \ No newline at end of file diff --git 
a/internal/codegen/golang/gen.go b/internal/codegen/golang/gen.go index 7e427d0356..522cdb1b4f 100644 --- a/internal/codegen/golang/gen.go +++ b/internal/codegen/golang/gen.go @@ -40,6 +40,7 @@ type tmplCtx struct { EmitEmptySlices bool EmitMethodsWithDBArgument bool UsesCopyFrom bool + UsesBatch bool } func (t *tmplCtx) OutputQuery(sourceName string) bool { @@ -69,6 +70,7 @@ func generate(settings config.CombinedSettings, enums []Enum, structs []Struct, "comment": codegen.DoubleSlashComment, "escape": codegen.EscapeBacktick, "imports": i.Imports, + "hasPrefix": strings.HasPrefix, } tmpl := template.Must( @@ -91,6 +93,7 @@ func generate(settings config.CombinedSettings, enums []Enum, structs []Struct, EmitEmptySlices: golang.EmitEmptySlices, EmitMethodsWithDBArgument: golang.EmitMethodsWithDBArgument, UsesCopyFrom: usesCopyFrom(queries), + UsesBatch: usesBatch(queries), SQLPackage: SQLPackageFromString(golang.SQLPackage), Q: "`", Package: golang.Package, @@ -103,6 +106,10 @@ func generate(settings config.CombinedSettings, enums []Enum, structs []Struct, return nil, errors.New(":copyfrom is only supported by pgx") } + if tctx.UsesBatch && tctx.SQLPackage != SQLPackagePGX { + return nil, errors.New(":batch* commands are only supported by pgx") + } + output := map[string]string{} execute := func(name, templateName string) error { @@ -146,6 +153,8 @@ func generate(settings config.CombinedSettings, enums []Enum, structs []Struct, copyfromFileName := "copyfrom.go" // TODO(Jille): Make this configurable. 
+ batchFileName := "batch.go" + if err := execute(dbFileName, "dbFile"); err != nil { return nil, err } @@ -162,6 +171,11 @@ func generate(settings config.CombinedSettings, enums []Enum, structs []Struct, return nil, err } } + if tctx.UsesBatch { + if err := execute(batchFileName, "batchFile"); err != nil { + return nil, err + } + } files := map[string]struct{}{} for _, gq := range queries { @@ -184,3 +198,14 @@ func usesCopyFrom(queries []Query) bool { } return false } + +func usesBatch(queries []Query) bool { + for _, q := range queries { + for _, cmd := range []string{metadata.CmdBatchExec, metadata.CmdBatchMany, metadata.CmdBatchOne} { + if q.Cmd == cmd { + return true + } + } + } + return false +} diff --git a/internal/codegen/golang/imports.go b/internal/codegen/golang/imports.go index 93ff4d4376..a17bd18d1e 100644 --- a/internal/codegen/golang/imports.go +++ b/internal/codegen/golang/imports.go @@ -90,6 +90,7 @@ func (i *importer) Imports(filename string) [][]ImportSpec { querierFileName = i.Settings.Go.OutputQuerierFileName } copyfromFileName := "copyfrom.go" + batchFileName := "batch.go" switch filename { case dbFileName: @@ -100,6 +101,8 @@ func (i *importer) Imports(filename string) [][]ImportSpec { return mergeImports(i.interfaceImports()) case copyfromFileName: return mergeImports(i.copyfromImports()) + case batchFileName: + return mergeImports(i.batchImports(filename)) default: return mergeImports(i.queryImports(filename)) } @@ -284,6 +287,9 @@ func (i *importer) queryImports(filename string) fileImports { var gq []Query anyNonCopyFrom := false for _, query := range i.Queries { + if usesBatch([]Query{query}) { + continue + } if query.SourceName == filename { gq = append(gq, query) if query.Cmd != metadata.CmdCopyFrom { @@ -392,3 +398,45 @@ func (i *importer) copyfromImports() fileImports { return sortedImports(std, pkg) } + +func (i *importer) batchImports(filename string) fileImports { + std, pkg := buildImports(i.Settings, i.Queries, func(name 
string) bool { + for _, q := range i.Queries { + if !usesBatch([]Query{q}) { + continue + } + if q.hasRetType() { + if q.Ret.EmitStruct() { + for _, f := range q.Ret.Struct.Fields { + fType := strings.TrimPrefix(f.Type, "[]") + if strings.HasPrefix(fType, name) { + return true + } + } + } + if strings.HasPrefix(q.Ret.Type(), name) { + return true + } + } + if !q.Arg.isEmpty() { + if q.Arg.EmitStruct() { + for _, f := range q.Arg.Struct.Fields { + fType := strings.TrimPrefix(f.Type, "[]") + if strings.HasPrefix(fType, name) { + return true + } + } + } + if strings.HasPrefix(q.Arg.Type(), name) { + return true + } + } + } + return false + }) + + std["context"] = struct{}{} + pkg[ImportSpec{Path: "github.com/jackc/pgx/v4"}] = struct{}{} + + return sortedImports(std, pkg) +} diff --git a/internal/codegen/golang/templates/pgx/batchCode.tmpl b/internal/codegen/golang/templates/pgx/batchCode.tmpl new file mode 100644 index 0000000000..c48ac5aecc --- /dev/null +++ b/internal/codegen/golang/templates/pgx/batchCode.tmpl @@ -0,0 +1,113 @@ +{{define "batchCodePgx"}} +{{range .GoQueries}} +{{if eq (hasPrefix .Cmd ":batch") true }} +const {{.ConstantName}} = {{$.Q}}-- name: {{.MethodName}} {{.Cmd}} +{{escape .SQL}} +{{$.Q}} + +type {{.MethodName}}BatchResults struct { + br pgx.BatchResults + ind int +} + +{{if .Arg.EmitStruct}} +type {{.Arg.Type}} struct { {{- range .Arg.Struct.Fields}} + {{.Name}} {{.Type}} {{if or ($.EmitJSONTags) ($.EmitDBTags)}}{{$.Q}}{{.Tag}}{{$.Q}}{{end}} + {{- end}} +} +{{end}} + +{{if .Ret.EmitStruct}} +type {{.Ret.Type}} struct { {{- range .Ret.Struct.Fields}} + {{.Name}} {{.Type}} {{if or ($.EmitJSONTags) ($.EmitDBTags)}}{{$.Q}}{{.Tag}}{{$.Q}}{{end}} + {{- end}} +} +{{end}} + +{{range .Comments}}//{{.}} +{{end -}} +func (q *Queries) {{.MethodName}}(ctx context.Context, {{ if $.EmitMethodsWithDBArgument}}db DBTX,{{end}} {{.Arg.SlicePair}}) *{{.MethodName}}BatchResults { + batch := &pgx.Batch{} + for _, a := range {{index .Arg.Name}} { + vals := 
[]interface{}{ + {{- if .Arg.Struct }} + {{- range .Arg.Struct.Fields }} + a.{{.Name}}, + {{- end }} + {{- else }} + a, + {{- end }} + } + batch.Queue({{.ConstantName}}, vals...) + } + br := {{if not $.EmitMethodsWithDBArgument}}q.{{end}}db.SendBatch(ctx, batch) + return &{{.MethodName}}BatchResults{br,0} +} + +{{if eq .Cmd ":batchexec"}} +func (b *{{.MethodName}}BatchResults) Exec(f func(int, error)) { + for { + _, err := b.br.Exec() + if err != nil && (err.Error() == "no result" || err.Error() == "batch already closed"){ + break + } + if f != nil { + f(b.ind, err) + } + b.ind++ + } +} +{{end}} + +{{if eq .Cmd ":batchmany"}} +func (b *{{.MethodName}}BatchResults) Query(f func(int, []{{.Ret.DefineType}}, error)) { + for { + rows, err := b.br.Query() + if err != nil && (err.Error() == "no result" || err.Error() == "batch already closed") { + break + } + defer rows.Close() + {{- if $.EmitEmptySlices}} + items := []{{.Ret.DefineType}}{} + {{else}} + var items []{{.Ret.DefineType}} + {{end -}} + for rows.Next() { + var {{.Ret.Name}} {{.Ret.Type}} + if err := rows.Scan({{.Ret.Scan}}); err != nil { + break + } + items = append(items, {{.Ret.ReturnName}}) + } + + if f != nil { + f(b.ind, items, rows.Err()) + } + b.ind++ + } +} +{{end}} + +{{if eq .Cmd ":batchone"}} +func (b *{{.MethodName}}BatchResults) QueryRow(f func(int, {{.Ret.DefineType}}, error)) { + for { + row := b.br.QueryRow() + var {{.Ret.Name}} {{.Ret.Type}} + err := row.Scan({{.Ret.Scan}}) + if err != nil && (err.Error() == "no result" || err.Error() == "batch already closed") { + break + } + if f != nil { + f(b.ind, {{.Ret.ReturnName}}, err) + } + b.ind++ + } +} +{{end}} + +func (b *{{.MethodName}}BatchResults) Close() error { + return b.br.Close() +} +{{end}} +{{end}} +{{end}} diff --git a/internal/codegen/golang/templates/pgx/dbCode.tmpl b/internal/codegen/golang/templates/pgx/dbCode.tmpl index 00d624f8b9..236554d9f2 100644 --- a/internal/codegen/golang/templates/pgx/dbCode.tmpl +++ 
b/internal/codegen/golang/templates/pgx/dbCode.tmpl @@ -7,6 +7,9 @@ type DBTX interface { {{- if .UsesCopyFrom }} CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error) {{- end }} +{{- if .UsesBatch }} + SendBatch(context.Context, *pgx.Batch) pgx.BatchResults +{{- end }} } {{ if .EmitMethodsWithDBArgument}} diff --git a/internal/codegen/golang/templates/pgx/interfaceCode.tmpl b/internal/codegen/golang/templates/pgx/interfaceCode.tmpl index c8bedb997c..705963b48d 100644 --- a/internal/codegen/golang/templates/pgx/interfaceCode.tmpl +++ b/internal/codegen/golang/templates/pgx/interfaceCode.tmpl @@ -32,6 +32,12 @@ {{- else if eq .Cmd ":copyfrom" }} {{.MethodName}}(ctx context.Context, {{.Arg.SlicePair}}) (int64, error) {{- end}} + {{- if and (or (eq .Cmd ":batchexec") (eq .Cmd ":batchmany") (eq .Cmd ":batchone")) ($dbtxParam) }} + {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.SlicePair}}) *{{.MethodName}}BatchResults + {{- else if or (eq .Cmd ":batchexec") (eq .Cmd ":batchmany") (eq .Cmd ":batchone") }} + {{.MethodName}}(ctx context.Context, {{.Arg.SlicePair}}) *{{.MethodName}}BatchResults + {{- end}} + {{- end}} } diff --git a/internal/codegen/golang/templates/pgx/queryCode.tmpl b/internal/codegen/golang/templates/pgx/queryCode.tmpl index da27d64be9..230844a58b 100644 --- a/internal/codegen/golang/templates/pgx/queryCode.tmpl +++ b/internal/codegen/golang/templates/pgx/queryCode.tmpl @@ -1,12 +1,13 @@ {{define "queryCodePgx"}} {{range .GoQueries}} {{if $.OutputQuery .SourceName}} -{{if ne .Cmd ":copyfrom"}} +{{if and (ne .Cmd ":copyfrom") (ne (hasPrefix .Cmd ":batch") true)}} const {{.ConstantName}} = {{$.Q}}-- name: {{.MethodName}} {{.Cmd}} {{escape .SQL}} {{$.Q}} {{end}} +{{if ne (hasPrefix .Cmd ":batch") true}} {{if .Arg.EmitStruct}} type {{.Arg.Type}} struct { {{- range .Arg.Struct.Fields}} {{.Name}} {{.Type}} {{if or ($.EmitJSONTags) ($.EmitDBTags)}}{{$.Q}}{{.Tag}}{{$.Q}}{{end}} @@ -20,6 
+21,7 @@ type {{.Ret.Type}} struct { {{- range .Ret.Struct.Fields}} {{- end}} } {{end}} +{{end}} {{if eq .Cmd ":one"}} {{range .Comments}}//{{.}} diff --git a/internal/codegen/golang/templates/template.tmpl b/internal/codegen/golang/templates/template.tmpl index b0d7d3f933..da22f89273 100644 --- a/internal/codegen/golang/templates/template.tmpl +++ b/internal/codegen/golang/templates/template.tmpl @@ -137,3 +137,22 @@ import ( {{- template "copyfromCodePgx" .}} {{end}} {{end}} + +{{define "batchFile"}}// Code generated by sqlc. DO NOT EDIT. +// source: {{.SourceName}} +package {{.Package}} + +import ( + {{range imports .SourceName}} + {{range .}}{{.}} + {{end}} + {{end}} +) +{{template "batchCode" . }} +{{end}} + +{{define "batchCode"}} +{{if eq .SQLPackage "pgx/v4"}} + {{- template "batchCodePgx" .}} +{{end}} +{{end}} \ No newline at end of file diff --git a/internal/endtoend/testdata/batch/postgresql/pgx/go/batch.go b/internal/endtoend/testdata/batch/postgresql/pgx/go/batch.go new file mode 100644 index 0000000000..e6ff1f51fd --- /dev/null +++ b/internal/endtoend/testdata/batch/postgresql/pgx/go/batch.go @@ -0,0 +1,152 @@ +// Code generated by sqlc. DO NOT EDIT. +// source: batch.go +package querytest + +import ( + "context" + "database/sql" + + "github.com/jackc/pgx/v4" +) + +const getValues = `-- name: GetValues :batchmany +SELECT a, b +FROM myschema.foo +WHERE b = $1 +` + +type GetValuesBatchResults struct { + br pgx.BatchResults + ind int +} + +func (q *Queries) GetValues(ctx context.Context, b []sql.NullInt32) *GetValuesBatchResults { + batch := &pgx.Batch{} + for _, a := range b { + vals := []interface{}{ + a, + } + batch.Queue(getValues, vals...) 
+ } + br := q.db.SendBatch(ctx, batch) + return &GetValuesBatchResults{br, 0} +} + +func (b *GetValuesBatchResults) Query(f func(int, []MyschemaFoo, error)) { + for { + rows, err := b.br.Query() + if err != nil && (err.Error() == "no result" || err.Error() == "batch already closed") { + break + } + defer rows.Close() + var items []MyschemaFoo + for rows.Next() { + var i MyschemaFoo + if err := rows.Scan(&i.A, &i.B); err != nil { + break + } + items = append(items, i) + } + + if f != nil { + f(b.ind, items, rows.Err()) + } + b.ind++ + } +} + +func (b *GetValuesBatchResults) Close() error { + return b.br.Close() +} + +const insertValues = `-- name: InsertValues :batchone +INSERT INTO myschema.foo (a, b) +VALUES ($1, $2) +RETURNING a +` + +type InsertValuesBatchResults struct { + br pgx.BatchResults + ind int +} + +type InsertValuesParams struct { + A sql.NullString + B sql.NullInt32 +} + +func (q *Queries) InsertValues(ctx context.Context, arg []InsertValuesParams) *InsertValuesBatchResults { + batch := &pgx.Batch{} + for _, a := range arg { + vals := []interface{}{ + a.A, + a.B, + } + batch.Queue(insertValues, vals...) 
+ } + br := q.db.SendBatch(ctx, batch) + return &InsertValuesBatchResults{br, 0} +} + +func (b *InsertValuesBatchResults) QueryRow(f func(int, sql.NullString, error)) { + for { + row := b.br.QueryRow() + var a sql.NullString + err := row.Scan(&a) + if err != nil && (err.Error() == "no result" || err.Error() == "batch already closed") { + break + } + if f != nil { + f(b.ind, a, err) + } + b.ind++ + } +} + +func (b *InsertValuesBatchResults) Close() error { + return b.br.Close() +} + +const updateValues = `-- name: UpdateValues :batchexec +UPDATE myschema.foo SET a = $1, b = $2 +` + +type UpdateValuesBatchResults struct { + br pgx.BatchResults + ind int +} + +type UpdateValuesParams struct { + A sql.NullString + B sql.NullInt32 +} + +func (q *Queries) UpdateValues(ctx context.Context, arg []UpdateValuesParams) *UpdateValuesBatchResults { + batch := &pgx.Batch{} + for _, a := range arg { + vals := []interface{}{ + a.A, + a.B, + } + batch.Queue(updateValues, vals...) + } + br := q.db.SendBatch(ctx, batch) + return &UpdateValuesBatchResults{br, 0} +} + +func (b *UpdateValuesBatchResults) Exec(f func(int, error)) { + for { + _, err := b.br.Exec() + if err != nil && (err.Error() == "no result" || err.Error() == "batch already closed") { + break + } + if f != nil { + f(b.ind, err) + } + b.ind++ + } +} + +func (b *UpdateValuesBatchResults) Close() error { + return b.br.Close() +} diff --git a/internal/endtoend/testdata/batch/postgresql/pgx/go/db.go b/internal/endtoend/testdata/batch/postgresql/pgx/go/db.go new file mode 100644 index 0000000000..20a53615b1 --- /dev/null +++ b/internal/endtoend/testdata/batch/postgresql/pgx/go/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. 
+ +package querytest + +import ( + "context" + + "github.com/jackc/pgconn" + "github.com/jackc/pgx/v4" +) + +type DBTX interface { + Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error) + Query(context.Context, string, ...interface{}) (pgx.Rows, error) + QueryRow(context.Context, string, ...interface{}) pgx.Row + SendBatch(context.Context, *pgx.Batch) pgx.BatchResults +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx pgx.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/batch/postgresql/pgx/go/models.go b/internal/endtoend/testdata/batch/postgresql/pgx/go/models.go new file mode 100644 index 0000000000..319a4b26a2 --- /dev/null +++ b/internal/endtoend/testdata/batch/postgresql/pgx/go/models.go @@ -0,0 +1,12 @@ +// Code generated by sqlc. DO NOT EDIT. + +package querytest + +import ( + "database/sql" +) + +type MyschemaFoo struct { + A sql.NullString + B sql.NullInt32 +} diff --git a/internal/endtoend/testdata/batch/postgresql/pgx/go/query.sql.go b/internal/endtoend/testdata/batch/postgresql/pgx/go/query.sql.go new file mode 100644 index 0000000000..97309c281f --- /dev/null +++ b/internal/endtoend/testdata/batch/postgresql/pgx/go/query.sql.go @@ -0,0 +1,6 @@ +// Code generated by sqlc. DO NOT EDIT. 
+// source: query.sql + +package querytest + +import () diff --git a/internal/endtoend/testdata/batch/postgresql/pgx/query.sql b/internal/endtoend/testdata/batch/postgresql/pgx/query.sql new file mode 100644 index 0000000000..2c6b776fd5 --- /dev/null +++ b/internal/endtoend/testdata/batch/postgresql/pgx/query.sql @@ -0,0 +1,15 @@ +CREATE SCHEMA myschema; +CREATE TABLE myschema.foo (a text, b integer); + +-- name: InsertValues :batchone +INSERT INTO myschema.foo (a, b) +VALUES ($1, $2) +RETURNING a; + +-- name: GetValues :batchmany +SELECT * +FROM myschema.foo +WHERE b = $1; + +-- name: UpdateValues :batchexec +UPDATE myschema.foo SET a = $1, b = $2; diff --git a/internal/endtoend/testdata/batch/postgresql/pgx/sqlc.json b/internal/endtoend/testdata/batch/postgresql/pgx/sqlc.json new file mode 100644 index 0000000000..9403bd0279 --- /dev/null +++ b/internal/endtoend/testdata/batch/postgresql/pgx/sqlc.json @@ -0,0 +1,13 @@ +{ + "version": "1", + "packages": [ + { + "path": "go", + "engine": "postgresql", + "sql_package": "pgx/v4", + "name": "querytest", + "schema": "query.sql", + "queries": "query.sql" + } + ] +} diff --git a/internal/endtoend/testdata/batch_imports/postgresql/pgx/go/batch.go b/internal/endtoend/testdata/batch_imports/postgresql/pgx/go/batch.go new file mode 100644 index 0000000000..97f4a4a0b5 --- /dev/null +++ b/internal/endtoend/testdata/batch_imports/postgresql/pgx/go/batch.go @@ -0,0 +1,108 @@ +// Code generated by sqlc. DO NOT EDIT. +// source: batch.go +package querytest + +import ( + "context" + "database/sql" + + "github.com/jackc/pgx/v4" +) + +const getValues = `-- name: GetValues :batchmany +SELECT a, b +FROM myschema.foo +WHERE b = $1 +` + +type GetValuesBatchResults struct { + br pgx.BatchResults + ind int +} + +func (q *Queries) GetValues(ctx context.Context, b []sql.NullInt32) *GetValuesBatchResults { + batch := &pgx.Batch{} + for _, a := range b { + vals := []interface{}{ + a, + } + batch.Queue(getValues, vals...) 
+ } + br := q.db.SendBatch(ctx, batch) + return &GetValuesBatchResults{br, 0} +} + +func (b *GetValuesBatchResults) Query(f func(int, []MyschemaFoo, error)) { + for { + rows, err := b.br.Query() + if err != nil && (err.Error() == "no result" || err.Error() == "batch already closed") { + break + } + defer rows.Close() + var items []MyschemaFoo + for rows.Next() { + var i MyschemaFoo + if err := rows.Scan(&i.A, &i.B); err != nil { + break + } + items = append(items, i) + } + + if f != nil { + f(b.ind, items, rows.Err()) + } + b.ind++ + } +} + +func (b *GetValuesBatchResults) Close() error { + return b.br.Close() +} + +const insertValues = `-- name: InsertValues :batchone +INSERT INTO myschema.foo (a, b) +VALUES ($1, $2) +RETURNING a +` + +type InsertValuesBatchResults struct { + br pgx.BatchResults + ind int +} + +type InsertValuesParams struct { + A sql.NullString + B sql.NullInt32 +} + +func (q *Queries) InsertValues(ctx context.Context, arg []InsertValuesParams) *InsertValuesBatchResults { + batch := &pgx.Batch{} + for _, a := range arg { + vals := []interface{}{ + a.A, + a.B, + } + batch.Queue(insertValues, vals...) + } + br := q.db.SendBatch(ctx, batch) + return &InsertValuesBatchResults{br, 0} +} + +func (b *InsertValuesBatchResults) QueryRow(f func(int, sql.NullString, error)) { + for { + row := b.br.QueryRow() + var a sql.NullString + err := row.Scan(&a) + if err != nil && (err.Error() == "no result" || err.Error() == "batch already closed") { + break + } + if f != nil { + f(b.ind, a, err) + } + b.ind++ + } +} + +func (b *InsertValuesBatchResults) Close() error { + return b.br.Close() +} diff --git a/internal/endtoend/testdata/batch_imports/postgresql/pgx/go/db.go b/internal/endtoend/testdata/batch_imports/postgresql/pgx/go/db.go new file mode 100644 index 0000000000..20a53615b1 --- /dev/null +++ b/internal/endtoend/testdata/batch_imports/postgresql/pgx/go/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. 
+ +package querytest + +import ( + "context" + + "github.com/jackc/pgconn" + "github.com/jackc/pgx/v4" +) + +type DBTX interface { + Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error) + Query(context.Context, string, ...interface{}) (pgx.Rows, error) + QueryRow(context.Context, string, ...interface{}) pgx.Row + SendBatch(context.Context, *pgx.Batch) pgx.BatchResults +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx pgx.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/batch_imports/postgresql/pgx/go/models.go b/internal/endtoend/testdata/batch_imports/postgresql/pgx/go/models.go new file mode 100644 index 0000000000..319a4b26a2 --- /dev/null +++ b/internal/endtoend/testdata/batch_imports/postgresql/pgx/go/models.go @@ -0,0 +1,12 @@ +// Code generated by sqlc. DO NOT EDIT. + +package querytest + +import ( + "database/sql" +) + +type MyschemaFoo struct { + A sql.NullString + B sql.NullInt32 +} diff --git a/internal/endtoend/testdata/batch_imports/postgresql/pgx/go/query.sql.go b/internal/endtoend/testdata/batch_imports/postgresql/pgx/go/query.sql.go new file mode 100644 index 0000000000..825afc6ac5 --- /dev/null +++ b/internal/endtoend/testdata/batch_imports/postgresql/pgx/go/query.sql.go @@ -0,0 +1,23 @@ +// Code generated by sqlc. DO NOT EDIT. 
+// source: query.sql + +package querytest + +import ( + "context" + "database/sql" +) + +const updateValues = `-- name: UpdateValues :exec +UPDATE myschema.foo SET a = $1, b = $2 +` + +type UpdateValuesParams struct { + A sql.NullString + B sql.NullInt32 +} + +func (q *Queries) UpdateValues(ctx context.Context, arg UpdateValuesParams) error { + _, err := q.db.Exec(ctx, updateValues, arg.A, arg.B) + return err +} diff --git a/internal/endtoend/testdata/batch_imports/postgresql/pgx/query.sql b/internal/endtoend/testdata/batch_imports/postgresql/pgx/query.sql new file mode 100644 index 0000000000..706cf8ef94 --- /dev/null +++ b/internal/endtoend/testdata/batch_imports/postgresql/pgx/query.sql @@ -0,0 +1,15 @@ +CREATE SCHEMA myschema; +CREATE TABLE myschema.foo (a text, b integer); + +-- name: InsertValues :batchone +INSERT INTO myschema.foo (a, b) +VALUES ($1, $2) +RETURNING a; + +-- name: GetValues :batchmany +SELECT * +FROM myschema.foo +WHERE b = $1; + +-- name: UpdateValues :exec +UPDATE myschema.foo SET a = $1, b = $2; diff --git a/internal/endtoend/testdata/batch_imports/postgresql/pgx/sqlc.json b/internal/endtoend/testdata/batch_imports/postgresql/pgx/sqlc.json new file mode 100644 index 0000000000..9403bd0279 --- /dev/null +++ b/internal/endtoend/testdata/batch_imports/postgresql/pgx/sqlc.json @@ -0,0 +1,13 @@ +{ + "version": "1", + "packages": [ + { + "path": "go", + "engine": "postgresql", + "sql_package": "pgx/v4", + "name": "querytest", + "schema": "query.sql", + "queries": "query.sql" + } + ] +} diff --git a/internal/metadata/meta.go b/internal/metadata/meta.go index 8fc2f3b32e..67a9e351c9 100644 --- a/internal/metadata/meta.go +++ b/internal/metadata/meta.go @@ -19,6 +19,9 @@ const ( CmdMany = ":many" CmdOne = ":one" CmdCopyFrom = ":copyfrom" + CmdBatchExec = ":batchexec" + CmdBatchMany = ":batchmany" + CmdBatchOne = ":batchone" ) // A query name must be a valid Go identifier @@ -80,7 +83,7 @@ func Parse(t string, commentStyle CommentSyntax) (string, 
string, error) { part = part[:len(part)-1] // removes the trailing "*/" element } if len(part) == 2 { - return "", "", fmt.Errorf("missing query type [':one', ':many', ':exec', ':execrows', ':execresult', ':copyfrom']: %s", line) + return "", "", fmt.Errorf("missing query type [':one', ':many', ':exec', ':execrows', ':execresult', ':copyfrom', ':batchexec', ':batchmany', ':batchone']: %s", line) } if len(part) != 4 { return "", "", fmt.Errorf("invalid query comment: %s", line) @@ -88,7 +91,7 @@ func Parse(t string, commentStyle CommentSyntax) (string, string, error) { queryName := part[2] queryType := strings.TrimSpace(part[3]) switch queryType { - case CmdOne, CmdMany, CmdExec, CmdExecResult, CmdExecRows, CmdCopyFrom: + case CmdOne, CmdMany, CmdExec, CmdExecResult, CmdExecRows, CmdCopyFrom, CmdBatchExec, CmdBatchMany, CmdBatchOne: default: return "", "", fmt.Errorf("invalid query type: %s", queryType) } diff --git a/internal/sql/validate/cmd.go b/internal/sql/validate/cmd.go index aed44d4a97..0dcc176383 100644 --- a/internal/sql/validate/cmd.go +++ b/internal/sql/validate/cmd.go @@ -44,10 +44,21 @@ func validateCopyfrom(n ast.Node) error { return nil } +func validateBatch(n ast.Node) error { + nums, _, _ := ParamRef(n) + if len(nums) == 0 { + return errors.New(":batch* commands require parameters") + } + return nil +} + func Cmd(n ast.Node, name, cmd string) error { if cmd == metadata.CmdCopyFrom { return validateCopyfrom(n) } + if (cmd == metadata.CmdBatchExec || cmd == metadata.CmdBatchMany) || cmd == metadata.CmdBatchOne { + return validateBatch(n) + } if !(cmd == metadata.CmdMany || cmd == metadata.CmdOne) { return nil } diff --git a/internal/sqltest/pgx.go b/internal/sqltest/pgx.go new file mode 100644 index 0000000000..13ecefe9e0 --- /dev/null +++ b/internal/sqltest/pgx.go @@ -0,0 +1,88 @@ +package sqltest + +import ( + "context" + "fmt" + "math/rand" + "os" + "path/filepath" + "testing" + "time" + + "github.com/kyleconroy/sqlc/internal/sql/sqlpath" + + 
"github.com/jackc/pgx/v4" +) + +func init() { + rand.Seed(time.Now().UnixNano()) +} + +func PostgreSQLPgx(t *testing.T, migrations []string) (*pgx.Conn, func()) { + t.Helper() + + pgUser := os.Getenv("PG_USER") + pgHost := os.Getenv("PG_HOST") + pgPort := os.Getenv("PG_PORT") + pgPass := os.Getenv("PG_PASSWORD") + pgDB := os.Getenv("PG_DATABASE") + + if pgUser == "" { + pgUser = "postgres" + } + + if pgPass == "" { + pgPass = "mysecretpassword" + } + + if pgPort == "" { + pgPort = "5432" + } + + if pgHost == "" { + pgHost = "127.0.0.1" + } + + if pgDB == "" { + pgDB = "dinotest" + } + + source := fmt.Sprintf("postgres://%s:%s@%s:%s/%s?sslmode=disable", pgUser, pgPass, pgHost, pgPort, pgDB) + t.Logf("db: %s", source) + + db, err := pgx.Connect(context.Background(), source) + if err != nil { + t.Fatal(err) + } + + // For each test, pick a new schema name at random. + schema := "sqltest_postgresql_" + id() + if _, err := db.Exec(context.Background(), "CREATE SCHEMA "+schema); err != nil { + t.Fatal(err) + } + + sdb, err := pgx.Connect(context.Background(), source+"&search_path="+schema) + if err != nil { + t.Fatal(err) + } + + files, err := sqlpath.Glob(migrations) + if err != nil { + t.Fatal(err) + } + for _, f := range files { + blob, err := os.ReadFile(f) + if err != nil { + t.Fatal(err) + } + if _, err := sdb.Exec(context.Background(), string(blob)); err != nil { + t.Fatalf("%s: %s", filepath.Base(f), err) + } + } + + return sdb, func() { + if _, err := db.Exec(context.Background(), "DROP SCHEMA "+schema+" CASCADE"); err != nil { + t.Fatal(err) + } + } +}