diff --git a/README.md b/README.md index f0703c4..b0da401 100644 --- a/README.md +++ b/README.md @@ -222,7 +222,7 @@ func TestTrap(t *testing.T) { count++ }) call := trap.MustWait(ctx) - call.Release() + call.MustRelease(ctx) if call.Duration != time.Hour { t.Fatal("wrong duration") } @@ -268,15 +268,15 @@ func TestTrap2(t *testing.T) { }(mClock) // start - trap.MustWait(ctx).Release() + trap.MustWait(ctx).MustRelease(ctx) // phase 1 call := trap.MustWait(ctx) mClock.Advance(3*time.Second).MustWait(ctx) - call.Release() + call.MustRelease(ctx) // phase 2 call = trap.MustWait(ctx) mClock.Advance(5*time.Second).MustWait(ctx) - call.Release() + call.MustRelease(ctx) <-done // Now logs contains []string{"Phase 1 took 3s", "Phase 2 took 5s"} @@ -302,7 +302,7 @@ go func(){ }() call := trap.MustWait(ctx) mClock.Advance(time.Second).MustWait(ctx) -call.Release() +call.MustRelease(ctx) // call.Tags contains []string{"foo", "bar"} gotFoo := <-foo // 1s after start @@ -478,8 +478,8 @@ func TestTicker(t *testing.T) { trap := mClock.Trap().TickerFunc() defer trap.Close() // stop trapping at end go runMyTicker(mClock) // async calls TickerFunc() - call := trap.Wait(context.Background()) // waits for a call and blocks its return - call.Release() // allow the TickerFunc() call to return + call := trap.MustWait(context.Background()) // waits for a call and blocks its return + call.MustRelease(ctx) // allow the TickerFunc() call to return // optionally check the duration using call.Duration // Move the clock forward 1 tick mClock.Advance(time.Second).MustWait(context.Background()) @@ -527,9 +527,9 @@ go func(clock quartz.Clock) { measurement = clock.Since(start) }(mClock) -c := trap.Wait(ctx) +c := trap.MustWait(ctx) mClock.Advance(5*time.Second) -c.Release() +c.MustRelease(ctx) ``` We wait until we trap the `clock.Since()` call, which implies that `clock.Now()` has completed, then @@ -617,10 +617,10 @@ func TestInactivityTimer_Late(t *testing.T) { // Trigger the AfterFunc w := mClock.Advance(10*time.Minute) - c := trap.Wait(ctx) + c := trap.MustWait(ctx) // Advance the clock a few ms to simulate a busy system mClock.Advance(3*time.Millisecond) - c.Release() // Until() returns + c.MustRelease(ctx) // Until() returns w.MustWait(ctx) // Wait for the AfterFunc to wrap up // Assert that the timeoutLocked() function was called diff --git a/example_test.go b/example_test.go index f798cdd..489aea7 100644 --- a/example_test.go +++ b/example_test.go @@ -65,7 +65,7 @@ func TestExampleTickerFunc(t *testing.T) { // it's good practice to release calls before any possible t.Fatal() calls // so that we don't leave dangling goroutines waiting for the call to be // released. - call.Release() + call.MustRelease(ctx) if call.Duration != time.Hour { t.Fatal("unexpected duration") } @@ -122,7 +122,7 @@ func TestExampleLatencyMeasurer(t *testing.T) { w := mClock.Advance(10 * time.Second) // triggers first tick c := trap.MustWait(ctx) // call to Since() mClock.Advance(33 * time.Millisecond) - c.Release() + c.MustRelease(ctx) w.MustWait(ctx) if l := lm.LastLatency(); l != 33*time.Millisecond { @@ -133,7 +133,7 @@ func TestExampleLatencyMeasurer(t *testing.T) { d, w2 := mClock.AdvanceNext() c = trap.MustWait(ctx) mClock.Advance(17 * time.Millisecond) - c.Release() + c.MustRelease(ctx) w2.MustWait(ctx) expectedD := 10*time.Second - 33*time.Millisecond diff --git a/mock.go b/mock.go index 5d81270..1b0c48c 100644 --- a/mock.go +++ b/mock.go @@ -13,8 +13,9 @@ import ( // Mock is the testing implementation of Clock. It tracks a time that monotonically increases // during a test, triggering any timers or tickers automatically. type Mock struct { - tb testing.TB - mu sync.Mutex + tb testing.TB + mu sync.Mutex + testOver bool // cur is the current time cur time.Time @@ -190,13 +191,16 @@ func (m *Mock) removeEventLocked(e event) { } } -func (m *Mock) matchCallLocked(c *Call) { +func (m *Mock) matchCallLocked(c *apiCall) { var traps []*Trap for _, t := range m.traps { if t.matches(c) { traps = append(traps, t) } } + if !m.testOver { + m.tb.Logf("Mock Clock - %s call, matched %d traps", c, len(traps)) + } if len(traps) == 0 { return } @@ -260,6 +264,9 @@ func (m *Mock) Advance(d time.Duration) AdvanceWaiter { m.tb.Helper() w := AdvanceWaiter{tb: m.tb, ch: make(chan struct{})} m.mu.Lock() + if !m.testOver { + m.tb.Logf("Mock Clock - Advance(%s)", d) + } fin := m.cur.Add(d) // nextTime.IsZero implies no events scheduled. if m.nextTime.IsZero() || fin.Before(m.nextTime) { @@ -307,6 +314,9 @@ func (m *Mock) Set(t time.Time) AdvanceWaiter { m.tb.Helper() w := AdvanceWaiter{tb: m.tb, ch: make(chan struct{})} m.mu.Lock() + if !m.testOver { + m.tb.Logf("Mock Clock - Set(%s)", t) + } if t.Before(m.cur) { defer close(w.ch) defer m.mu.Unlock() @@ -343,6 +353,9 @@ func (m *Mock) Set(t time.Time) AdvanceWaiter { // wait for the timer/tick event(s) to finish. func (m *Mock) AdvanceNext() (time.Duration, AdvanceWaiter) { m.mu.Lock() + if !m.testOver { + m.tb.Logf("Mock Clock - AdvanceNext()") + } m.tb.Helper() w := AdvanceWaiter{tb: m.tb, ch: make(chan struct{})} if m.nextTime.IsZero() { @@ -431,11 +444,14 @@ func (m *Mock) Trap() Trapper { func (m *Mock) newTrap(fn clockFunction, tags []string) *Trap { m.mu.Lock() defer m.mu.Unlock() + if !m.testOver { + m.tb.Logf("Mock Clock - Trap %s(..., %v)", fn, tags) + } tr := &Trap{ fn: fn, tags: tags, mock: m, - calls: make(chan *Call), + calls: make(chan *apiCall), done: make(chan struct{}), } m.traps = append(m.traps, tr) @@ -450,10 +466,17 @@ func NewMock(tb testing.TB) *Mock { if err != nil { panic(err) } - return &Mock{ + m := &Mock{ tb: tb, cur: cur, } + tb.Cleanup(func() { + m.mu.Lock() + defer m.mu.Unlock() + m.testOver = true + tb.Logf("Mock Clock - test cleanup; will no longer log clock events") + }) + return m } var _ Clock = &Mock{} @@ -557,9 +580,41 @@ const ( clockFunctionUntil ) -type callArg func(c *Call) - -type Call struct { +func (c clockFunction) String() string { + switch c { + case clockFunctionNewTimer: + return "NewTimer" + case clockFunctionAfterFunc: + return "AfterFunc" + case clockFunctionTimerStop: + return "Timer.Stop" + case clockFunctionTimerReset: + return "Timer.Reset" + case clockFunctionTickerFunc: + return "TickerFunc" + case clockFunctionTickerFuncWait: + return "TickerFunc.Wait" + case clockFunctionNewTicker: + return "NewTicker" + case clockFunctionTickerReset: + return "Ticker.Reset" + case clockFunctionTickerStop: + return "Ticker.Stop" + case clockFunctionNow: + return "Now" + case clockFunctionSince: + return "Since" + case clockFunctionUntil: + return "Until" + default: + return fmt.Sprintf("Unknown clockFunction(%d)", c) + } +} + +type callArg func(c *apiCall) + +// apiCall represents a single call to one of the Clock APIs. +type apiCall struct { Time time.Time Duration time.Duration Tags []string @@ -569,25 +624,91 @@ type Call struct { complete chan struct{} } -func (c *Call) Release() { - c.releases.Done() - <-c.complete +func (a *apiCall) String() string { + switch a.fn { + case clockFunctionNewTimer: + return fmt.Sprintf("NewTimer(%s, %v)", a.Duration, a.Tags) + case clockFunctionAfterFunc: + return fmt.Sprintf("AfterFunc(%s, , %v)", a.Duration, a.Tags) + case clockFunctionTimerStop: + return fmt.Sprintf("Timer.Stop(%v)", a.Tags) + case clockFunctionTimerReset: + return fmt.Sprintf("Timer.Reset(%s, %v)", a.Duration, a.Tags) + case clockFunctionTickerFunc: + return fmt.Sprintf("TickerFunc(, %s, , %s)", a.Duration, a.Tags) + case clockFunctionTickerFuncWait: + return fmt.Sprintf("TickerFunc.Wait(%v)", a.Tags) + case clockFunctionNewTicker: + return fmt.Sprintf("NewTicker(%s, %v)", a.Duration, a.Tags) + case clockFunctionTickerReset: + return fmt.Sprintf("Ticker.Reset(%s, %v)", a.Duration, a.Tags) + case clockFunctionTickerStop: + return fmt.Sprintf("Ticker.Stop(%v)", a.Tags) + case clockFunctionNow: + return fmt.Sprintf("Now(%v)", a.Tags) + case clockFunctionSince: + return fmt.Sprintf("Since(%s, %v)", a.Time, a.Tags) + case clockFunctionUntil: + return fmt.Sprintf("Until(%s, %v)", a.Time, a.Tags) + default: + return fmt.Sprintf("Unknown clockFunction(%d)", a.fn) + } +} + +// Call represents an apiCall that has been trapped. +type Call struct { + Time time.Time + Duration time.Duration + Tags []string + + tb testing.TB + apiCall *apiCall + trap *Trap +} + +// Release the call and wait for it to complete. If the provided context expires before the call completes, it returns +// an error. +// +// IMPORTANT: If a call is trapped by more than one trap, they all must release the call before it can complete, and +// they must do so from different goroutines. +func (c *Call) Release(ctx context.Context) error { + c.apiCall.releases.Done() + select { + case <-ctx.Done(): + return fmt.Errorf("timed out waiting for release; did more than one trap capture the call?: %w", ctx.Err()) + case <-c.apiCall.complete: + // OK + } + c.trap.callReleased() + return nil +} + +// MustRelease releases the call and waits for it to complete. If the provided context expires before the call +// completes, it fails the test. +// +// IMPORTANT: If a call is trapped by more than one trap, they all must release the call before it can complete, and +// they must do so from different goroutines. +func (c *Call) MustRelease(ctx context.Context) { + if err := c.Release(ctx); err != nil { + c.tb.Helper() + c.tb.Fatal(err.Error()) + } } func withTime(t time.Time) callArg { - return func(c *Call) { + return func(c *apiCall) { c.Time = t } } func withDuration(d time.Duration) callArg { - return func(c *Call) { + return func(c *apiCall) { c.Duration = d } } -func newCall(fn clockFunction, tags []string, args ...callArg) *Call { - c := &Call{ +func newCall(fn clockFunction, tags []string, args ...callArg) *apiCall { + c := &apiCall{ fn: fn, Tags: tags, complete: make(chan struct{}), @@ -602,19 +723,23 @@ type Trap struct { fn clockFunction tags []string mock *Mock - calls chan *Call + calls chan *apiCall done chan struct{} + + // mu protects the unreleasedCalls count + mu sync.Mutex + unreleasedCalls int } -func (t *Trap) catch(c *Call) { +func (t *Trap) catch(c *apiCall) { select { case t.calls <- c: case <-t.done: - c.Release() + c.releases.Done() } } -func (t *Trap) matches(c *Call) bool { +func (t *Trap) matches(c *apiCall) bool { if t.fn != c.fn { return false } @@ -629,6 +754,10 @@ func (t *Trap) matches(c *Call) bool { func (t *Trap) Close() { t.mock.mu.Lock() defer t.mock.mu.Unlock() + if t.unreleasedCalls != 0 { + t.mock.tb.Helper() + t.mock.tb.Errorf("trap Closed() with %d unreleased calls", t.unreleasedCalls) + } for i, tr := range t.mock.traps { if t == tr { t.mock.traps = append(t.mock.traps[:i], t.mock.traps[i+1:]...) @@ -637,6 +766,12 @@ func (t *Trap) Close() { close(t.done) } +func (t *Trap) callReleased() { + t.mu.Lock() + defer t.mu.Unlock() + t.unreleasedCalls-- +} + var ErrTrapClosed = errors.New("trap closed") func (t *Trap) Wait(ctx context.Context) (*Call, error) { @@ -645,7 +780,18 @@ func (t *Trap) Wait(ctx context.Context) (*Call, error) { return nil, ctx.Err() case <-t.done: return nil, ErrTrapClosed - case c := <-t.calls: + case a := <-t.calls: + c := &Call{ + Time: a.Time, + Duration: a.Duration, + Tags: a.Tags, + apiCall: a, + trap: t, + tb: t.mock.tb, + } + t.mu.Lock() + defer t.mu.Unlock() + t.unreleasedCalls++ return c, nil } } diff --git a/mock_test.go b/mock_test.go index 8099738..5663762 100644 --- a/mock_test.go +++ b/mock_test.go @@ -24,7 +24,7 @@ func TestTimer_NegativeDuration(t *testing.T) { timers <- mClock.NewTimer(-time.Second) }() c := trap.MustWait(ctx) - c.Release() + c.MustRelease(ctx) // trap returns the actual passed value if c.Duration != -time.Second { t.Fatalf("expected -time.Second, got: %v", c.Duration) @@ -62,7 +62,7 @@ func TestAfterFunc_NegativeDuration(t *testing.T) { }) }() c := trap.MustWait(ctx) - c.Release() + c.MustRelease(ctx) // trap returns the actual passed value if c.Duration != -time.Second { t.Fatalf("expected -time.Second, got: %v", c.Duration) @@ -99,7 +99,7 @@ func TestNewTicker(t *testing.T) { tickers <- mClock.NewTicker(time.Hour, "new") }() c := trapNT.MustWait(ctx) - c.Release() + c.MustRelease(ctx) if c.Duration != time.Hour { t.Fatalf("expected time.Hour, got: %v", c.Duration) } @@ -123,7 +123,7 @@ func TestNewTicker(t *testing.T) { go tkr.Reset(time.Minute, "reset") c = trapReset.MustWait(ctx) mClock.Advance(time.Second).MustWait(ctx) - c.Release() + c.MustRelease(ctx) if c.Duration != time.Minute { t.Fatalf("expected time.Minute, got: %v", c.Duration) } @@ -142,7 +142,7 @@ func TestNewTicker(t *testing.T) { } go tkr.Stop("stop") - trapStop.MustWait(ctx).Release() + trapStop.MustWait(ctx).MustRelease(ctx) mClock.Advance(time.Hour).MustWait(ctx) select { case <-tkr.C: @@ -153,7 +153,7 @@ func TestNewTicker(t *testing.T) { // Resetting after stop go tkr.Reset(time.Minute, "reset") - trapReset.MustWait(ctx).Release() + trapReset.MustWait(ctx).MustRelease(ctx) mClock.Advance(time.Minute).MustWait(ctx) tTime = mClock.Now() select { @@ -319,3 +319,139 @@ func TestTickerFunc_LongCallback(t *testing.T) { } w.MustWait(testCtx) } + +func Test_MultipleTraps(t *testing.T) { + t.Parallel() + testCtx, testCancel := context.WithTimeout(context.Background(), 10*time.Second) + defer testCancel() + mClock := quartz.NewMock(t) + + trap0 := mClock.Trap().Now("0") + defer trap0.Close() + trap1 := mClock.Trap().Now("1") + defer trap1.Close() + + timeCh := make(chan time.Time) + go func() { + timeCh <- mClock.Now("0", "1") + }() + + c0 := trap0.MustWait(testCtx) + mClock.Advance(time.Second) + // the two trapped call instances need to be released on separate goroutines since they each wait for the Now() call + // to return, which is blocked on both releases happening. If you release them on the same goroutine, in either + // order, it will deadlock. + done := make(chan struct{}) + go func() { + defer close(done) + c0.MustRelease(testCtx) + }() + c1 := trap1.MustWait(testCtx) + mClock.Advance(time.Second) + c1.MustRelease(testCtx) + + select { + case <-done: + case <-testCtx.Done(): + t.Fatal("timed out waiting for c0.Release()") + } + + select { + case got := <-timeCh: + end := mClock.Now("end") + if !got.Equal(end) { + t.Fatalf("expected %s got %s", end, got) + } + case <-testCtx.Done(): + t.Fatal("timed out waiting for Now()") + } +} + +func Test_MultipleTrapsDeadlock(t *testing.T) { + t.Parallel() + tRunFail(t, func(t testing.TB) { + testCtx, testCancel := context.WithTimeout(context.Background(), 2*time.Second) + defer testCancel() + mClock := quartz.NewMock(t) + + trap0 := mClock.Trap().Now("0") + defer trap0.Close() + trap1 := mClock.Trap().Now("1") + defer trap1.Close() + + timeCh := make(chan time.Time) + go func() { + timeCh <- mClock.Now("0", "1") + }() + + c0 := trap0.MustWait(testCtx) + c0.MustRelease(testCtx) // deadlocks, test failure + }) +} + +func Test_UnreleasedCalls(t *testing.T) { + t.Parallel() + tRunFail(t, func(t testing.TB) { + testCtx, testCancel := context.WithTimeout(context.Background(), 10*time.Second) + defer testCancel() + mClock := quartz.NewMock(t) + + trap := mClock.Trap().Now() + defer trap.Close() + + go func() { + _ = mClock.Now() + }() + + trap.MustWait(testCtx) // missing release + }) +} + +type captureFailTB struct { + failed bool + testing.TB +} + +func (t *captureFailTB) Errorf(format string, args ...any) { + t.Helper() + t.Logf(format, args...) + t.failed = true +} + +func (t *captureFailTB) Error(args ...any) { + t.Helper() + t.Log(args...) + t.failed = true +} + +func (t *captureFailTB) Fatal(args ...any) { + t.Helper() + t.Log(args...) + t.failed = true +} + +func (t *captureFailTB) Fatalf(format string, args ...any) { + t.Helper() + t.Logf(format, args...) + t.failed = true +} + +func (t *captureFailTB) Fail() { + t.failed = true +} + +func (t *captureFailTB) FailNow() { + t.failed = true +} + +func (t *captureFailTB) Failed() bool { + return t.failed +} + +func tRunFail(t testing.TB, f func(t testing.TB)) { + tb := &captureFailTB{TB: t} + f(tb) + if !tb.Failed() { + t.Fatal("want test to fail") + } +}