Skip to content

Commit e2f5401

Browse files
authored
test: add test database cleaner in subprocess (coder#19844)
Fixes coder/internal#927. Adds a small subprocess that outlives the testing process to clean up any leaked test databases.
1 parent 596fdcb commit e2f5401

File tree

4 files changed

+249
-1
lines changed

4 files changed

+249
-1
lines changed

coderd/database/dbtestutil/broker.go

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package dbtestutil
22

33
import (
4+
"context"
45
"database/sql"
56
_ "embed"
67
"fmt"
@@ -25,6 +26,8 @@ type Broker struct {
2526
uuid uuid.UUID
2627
coderTestingDB *sql.DB
2728
refCount int
29+
// we keep a reference to the stdin of the cleaner so that Go doesn't garbage collect it.
30+
cleanerFD any
2831
}
2932

3033
func (b *Broker) Create(t TBSubset, opts ...OpenOption) (ConnectionParams, error) {
@@ -142,7 +145,16 @@ func (b *Broker) init(t TBSubset) error {
142145
return xerrors.Errorf("ping '%s' database: %w", CoderTestingDBName, err)
143146
}
144147
b.coderTestingDB = coderTestingDB
145-
b.uuid = uuid.New()
148+
149+
if b.uuid == uuid.Nil {
150+
b.uuid = uuid.New()
151+
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
152+
defer cancel()
153+
b.cleanerFD, err = startCleaner(ctx, b.uuid, coderTestingParams.DSN())
154+
if err != nil {
155+
return xerrors.Errorf("start test db cleaner: %w", err)
156+
}
157+
}
146158
return nil
147159
}
148160

Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
package dbtestutil
2+
3+
import (
4+
"context"
5+
"database/sql"
6+
"fmt"
7+
"io"
8+
"os"
9+
"os/exec"
10+
"os/signal"
11+
"time"
12+
13+
"github.com/google/uuid"
14+
"golang.org/x/xerrors"
15+
16+
"cdr.dev/slog"
17+
"cdr.dev/slog/sloggers/sloghuman"
18+
"github.com/coder/retry"
19+
)
20+
21+
// Constants shared between the test process (startCleaner) and the cleaner subprocess (RunCleaner).
const (
	// cleanerRespOK is the exact text the cleaner writes to its stdout once it has initialized successfully.
	cleanerRespOK = "OK"

	// envCleanerParentUUID names the environment variable that carries the parent test process's UUID.
	envCleanerParentUUID = "DB_CLEANER_PARENT_UUID"
	// envCleanerDSN names the environment variable that carries the DSN the cleaner connects to.
	envCleanerDSN = "DB_CLEANER_DSN"
)
26+
27+
// Result of os.Getwd() as captured at package init time; consumed by startCleaner.
var (
	// originalWorkingDir is the working directory the process had when this package was initialized.
	originalWorkingDir string
	// errGettingWorkingDir is the error (if any) returned by os.Getwd() at package init.
	errGettingWorkingDir error
)

// init records the working directory before any test has a chance to change it.
//
// The `go run` invocation in startCleaner needs to run from somewhere inside the project tree so that it can resolve
// the cleaner command package. Some tests change the working directory while they run, so the directory is captured
// here, at package init, and later set on the subcommand before it starts.
func init() {
	wd, err := os.Getwd()
	originalWorkingDir = wd
	errGettingWorkingDir = err
}
39+
40+
// startCleaner starts the cleaner in a subprocess. holdThis is an opaque reference that needs to be kept from being
41+
// garbage collected until we are done with all test databases (e.g. the end of the process).
42+
func startCleaner(ctx context.Context, parentUUID uuid.UUID, dsn string) (holdThis any, err error) {
43+
cmd := exec.Command("go", "run", "github.com/coder/coder/v2/coderd/database/dbtestutil/cleanercmd")
44+
cmd.Env = append(os.Environ(),
45+
fmt.Sprintf("%s=%s", envCleanerParentUUID, parentUUID.String()),
46+
fmt.Sprintf("%s=%s", envCleanerDSN, dsn),
47+
)
48+
49+
// c.f. comment on `func init()` in this file.
50+
if errGettingWorkingDir != nil {
51+
return nil, xerrors.Errorf("failed to get working directory during init: %w", errGettingWorkingDir)
52+
}
53+
cmd.Dir = originalWorkingDir
54+
55+
// Here we don't actually use the reference to the stdin pipe, because we never write anything to it. When this
56+
// process exits, the pipe is closed by the OS and this triggers the cleaner to do its cleaning work. But, we do
57+
// need to hang on to a reference to it so that it doesn't get garbage collected and trigger cleanup early.
58+
stdin, err := cmd.StdinPipe()
59+
if err != nil {
60+
return nil, xerrors.Errorf("failed to open stdin pipe: %w", err)
61+
}
62+
stdout, err := cmd.StdoutPipe()
63+
if err != nil {
64+
return nil, xerrors.Errorf("failed to open stdout pipe: %w", err)
65+
}
66+
// uncomment this to see log output from the cleaner
67+
// cmd.Stderr = os.Stderr
68+
err = cmd.Start()
69+
if err != nil {
70+
return nil, xerrors.Errorf("failed to start broker: %w", err)
71+
}
72+
outCh := make(chan []byte, 1)
73+
errCh := make(chan error, 1)
74+
go func() {
75+
buf := make([]byte, 1024)
76+
n, readErr := stdout.Read(buf)
77+
if readErr != nil {
78+
errCh <- readErr
79+
return
80+
}
81+
outCh <- buf[:n]
82+
}()
83+
select {
84+
case <-ctx.Done():
85+
_ = cmd.Process.Kill()
86+
return nil, ctx.Err()
87+
case err := <-errCh:
88+
return nil, xerrors.Errorf("failed to read db test cleaner output: %w", err)
89+
case out := <-outCh:
90+
if string(out) != cleanerRespOK {
91+
return nil, xerrors.Errorf("db test cleaner error: %s", string(out))
92+
}
93+
return stdin, nil
94+
}
95+
}
96+
97+
type cleaner struct {
98+
parentUUID uuid.UUID
99+
logger slog.Logger
100+
db *sql.DB
101+
}
102+
103+
func (c *cleaner) init(ctx context.Context) error {
104+
var err error
105+
dsn := os.Getenv(envCleanerDSN)
106+
if dsn == "" {
107+
return xerrors.Errorf("DSN not set via env %s: %w", envCleanerDSN, err)
108+
}
109+
parentUUIDStr := os.Getenv(envCleanerParentUUID)
110+
c.parentUUID, err = uuid.Parse(parentUUIDStr)
111+
if err != nil {
112+
return xerrors.Errorf("failed to parse parent UUID '%s': %w", parentUUIDStr, err)
113+
}
114+
c.logger = slog.Make(sloghuman.Sink(os.Stderr)).
115+
Named("dbtestcleaner").
116+
Leveled(slog.LevelDebug).
117+
With(slog.F("parent_uuid", parentUUIDStr))
118+
119+
c.db, err = sql.Open("postgres", dsn)
120+
if err != nil {
121+
return xerrors.Errorf("couldn't open DB: %w", err)
122+
}
123+
for r := retry.New(10*time.Millisecond, 500*time.Millisecond); r.Wait(ctx); {
124+
err = c.db.PingContext(ctx)
125+
if err == nil {
126+
return nil
127+
}
128+
c.logger.Error(ctx, "failed to ping DB", slog.Error(err))
129+
}
130+
return ctx.Err()
131+
}
132+
133+
// waitAndClean waits for stdin to close then attempts to clean up any test databases with our parent's UUID. This
134+
// is best-effort. If we hit an error we exit.
135+
//
136+
// We log to stderr for debugging, but we don't expect this output to normally be available since the parent has
137+
// exited. Uncomment the line `cmd.Stderr = os.Stderr` in startCleaner() to see this output.
138+
func (c *cleaner) waitAndClean() {
139+
c.logger.Debug(context.Background(), "waiting for stdin to close")
140+
_, _ = io.ReadAll(os.Stdin) // here we're just waiting for stdin to close
141+
c.logger.Debug(context.Background(), "stdin closed")
142+
rows, err := c.db.Query(
143+
"SELECT name FROM test_databases WHERE process_uuid = $1 AND dropped_at IS NULL",
144+
c.parentUUID,
145+
)
146+
if err != nil {
147+
c.logger.Error(context.Background(), "error querying test databases", slog.Error(err))
148+
return
149+
}
150+
defer func() {
151+
_ = rows.Close()
152+
}()
153+
names := make([]string, 0)
154+
for rows.Next() {
155+
var name string
156+
if err := rows.Scan(&name); err != nil {
157+
continue
158+
}
159+
names = append(names, name)
160+
}
161+
if closeErr := rows.Close(); closeErr != nil {
162+
c.logger.Error(context.Background(), "error closing rows", slog.Error(closeErr))
163+
}
164+
c.logger.Debug(context.Background(), "queried names", slog.F("names", names))
165+
for _, name := range names {
166+
_, err := c.db.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %s", name))
167+
if err != nil {
168+
c.logger.Error(context.Background(), "error dropping database", slog.Error(err), slog.F("name", name))
169+
return
170+
}
171+
_, err = c.db.Exec("UPDATE test_databases SET dropped_at = CURRENT_TIMESTAMP WHERE name = $1", name)
172+
if err != nil {
173+
c.logger.Error(context.Background(), "error dropping database", slog.Error(err), slog.F("name", name))
174+
return
175+
}
176+
}
177+
c.logger.Debug(context.Background(), "finished cleaning")
178+
}
179+
180+
// RunCleaner runs the test database cleaning process. It takes no arguments but uses stdio and environment variables
181+
// for its operation. It is designed to be launched as the only task of a `main()` process, but is included in this
182+
// package to share constants with the parent code that launches it above.
183+
//
184+
// The cleaner is designed to run in a separate process from the main test suite, connected over stdio. If the main test
185+
// process ends (panics, times out, or is killed) without explicitly discarding the databases it clones, the cleaner
186+
// removes them so they don't leak beyond the test session. c.f. https://github.com/coder/internal/issues/927
187+
func RunCleaner() {
188+
c := cleaner{}
189+
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
190+
defer cancel()
191+
// canceling a test via the IDE sends us an interrupt signal. We only want to process that signal during init. After
192+
// we want to ignore the signal and do our cleaning.
193+
signalCtx, signalCancel := signal.NotifyContext(ctx, os.Interrupt)
194+
defer signalCancel()
195+
err := c.init(signalCtx)
196+
if err != nil {
197+
_, _ = fmt.Fprintf(os.Stdout, "failed to init: %s", err.Error())
198+
_ = os.Stdout.Close()
199+
return
200+
}
201+
_, _ = fmt.Fprint(os.Stdout, cleanerRespOK)
202+
_ = os.Stdout.Close()
203+
c.waitAndClean()
204+
}
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
package main
2+
3+
import "github.com/coder/coder/v2/coderd/database/dbtestutil"
4+
5+
func main() {
6+
dbtestutil.RunCleaner()
7+
}

coderd/database/dbtestutil/postgres_test.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package dbtestutil_test
33
import (
44
"database/sql"
55
"testing"
6+
"time"
67

78
_ "github.com/lib/pq"
89
"github.com/stretchr/testify/require"
@@ -110,3 +111,27 @@ func TestOpen_ValidDBFrom(t *testing.T) {
110111
require.True(t, rows.Next())
111112
require.NoError(t, rows.Close())
112113
}
114+
115+
func TestOpen_Panic(t *testing.T) {
116+
t.Skip("unskip this to manually test that we don't leak a database into postgres")
117+
t.Parallel()
118+
if !dbtestutil.WillUsePostgres() {
119+
t.Skip("this test requires postgres")
120+
}
121+
122+
_, err := dbtestutil.Open(t)
123+
require.NoError(t, err)
124+
panic("now check SELECT datname FROM pg_database;")
125+
}
126+
127+
func TestOpen_Timeout(t *testing.T) {
128+
t.Skip("unskip this and set a short timeout to manually test that we don't leak a database into postgres")
129+
t.Parallel()
130+
if !dbtestutil.WillUsePostgres() {
131+
t.Skip("this test requires postgres")
132+
}
133+
134+
_, err := dbtestutil.Open(t)
135+
require.NoError(t, err)
136+
time.Sleep(11 * time.Minute)
137+
}

0 commit comments

Comments
 (0)