diff --git a/mock.go b/mock.go index 5d81270..f67354e 100644 --- a/mock.go +++ b/mock.go @@ -190,7 +190,7 @@ func (m *Mock) removeEventLocked(e event) { } } -func (m *Mock) matchCallLocked(c *Call) { +func (m *Mock) matchCallLocked(c *apiCall) { var traps []*Trap for _, t := range m.traps { if t.matches(c) { @@ -435,7 +435,7 @@ func (m *Mock) newTrap(fn clockFunction, tags []string) *Trap { fn: fn, tags: tags, mock: m, - calls: make(chan *Call), + calls: make(chan *apiCall), done: make(chan struct{}), } m.traps = append(m.traps, tr) @@ -557,9 +557,10 @@ const ( clockFunctionUntil ) -type callArg func(c *Call) +type callArg func(c *apiCall) -type Call struct { +// apiCall represents a single call to one of the Clock APIs. +type apiCall struct { Time time.Time Duration time.Duration Tags []string @@ -569,25 +570,36 @@ type Call struct { complete chan struct{} } +// Call represents an apiCall that has been trapped. +type Call struct { + Time time.Time + Duration time.Duration + Tags []string + + apiCall *apiCall + trap *Trap +} + func (c *Call) Release() { - c.releases.Done() - <-c.complete + c.apiCall.releases.Done() + <-c.apiCall.complete + c.trap.callReleased() } func withTime(t time.Time) callArg { - return func(c *Call) { + return func(c *apiCall) { c.Time = t } } func withDuration(d time.Duration) callArg { - return func(c *Call) { + return func(c *apiCall) { c.Duration = d } } -func newCall(fn clockFunction, tags []string, args ...callArg) *Call { - c := &Call{ +func newCall(fn clockFunction, tags []string, args ...callArg) *apiCall { + c := &apiCall{ fn: fn, Tags: tags, complete: make(chan struct{}), @@ -602,19 +614,23 @@ type Trap struct { fn clockFunction tags []string mock *Mock - calls chan *Call + calls chan *apiCall done chan struct{} + + // mu protects the unreleasedCalls count + mu sync.Mutex + unreleasedCalls int } -func (t *Trap) catch(c *Call) { +func (t *Trap) catch(c *apiCall) { select { case t.calls <- c: case <-t.done: - c.Release() + c.releases.Done() } } -func (t *Trap) matches(c *Call) bool { +func (t *Trap) matches(c *apiCall) bool { if t.fn != c.fn { return false } @@ -629,6 +645,10 @@ func (t *Trap) matches(c *Call) bool { func (t *Trap) Close() { t.mock.mu.Lock() defer t.mock.mu.Unlock() + if t.unreleasedCalls != 0 { + t.mock.tb.Helper() + t.mock.tb.Errorf("trap Closed() with %d unreleased calls", t.unreleasedCalls) + } for i, tr := range t.mock.traps { if t == tr { t.mock.traps = append(t.mock.traps[:i], t.mock.traps[i+1:]...) @@ -637,6 +657,12 @@ func (t *Trap) Close() { close(t.done) } +func (t *Trap) callReleased() { + t.mu.Lock() + defer t.mu.Unlock() + t.unreleasedCalls-- +} + var ErrTrapClosed = errors.New("trap closed") func (t *Trap) Wait(ctx context.Context) (*Call, error) { @@ -645,7 +671,17 @@ func (t *Trap) Wait(ctx context.Context) (*Call, error) { return nil, ctx.Err() case <-t.done: return nil, ErrTrapClosed - case c := <-t.calls: + case a := <-t.calls: + c := &Call{ + Time: a.Time, + Duration: a.Duration, + Tags: a.Tags, + apiCall: a, + trap: t, + } + t.mu.Lock() + defer t.mu.Unlock() + t.unreleasedCalls++ return c, nil } } diff --git a/mock_test.go b/mock_test.go index 8099738..2d2097a 100644 --- a/mock_test.go +++ b/mock_test.go @@ -319,3 +319,117 @@ func TestTickerFunc_LongCallback(t *testing.T) { } w.MustWait(testCtx) } + +func Test_MultipleTraps(t *testing.T) { + t.Parallel() + testCtx, testCancel := context.WithTimeout(context.Background(), 10*time.Second) + defer testCancel() + mClock := quartz.NewMock(t) + + trap0 := mClock.Trap().Now("0") + defer trap0.Close() + trap1 := mClock.Trap().Now("1") + defer trap1.Close() + + timeCh := make(chan time.Time) + go func() { + timeCh <- mClock.Now("0", "1") + }() + + c0 := trap0.MustWait(testCtx) + mClock.Advance(time.Second) + // the two trapped call instances need to be released on separate goroutines since they each wait for the Now() call + // to return, which is blocked on both releases happening. If you release them on the same goroutine, in either + // order, it will deadlock. + done := make(chan struct{}) + go func() { + defer close(done) + c0.Release() + }() + c1 := trap1.MustWait(testCtx) + mClock.Advance(time.Second) + c1.Release() + + select { + case <-done: + case <-testCtx.Done(): + t.Fatal("timed out waiting for c0.Release()") + } + + select { + case got := <-timeCh: + end := mClock.Now("end") + if !got.Equal(end) { + t.Fatalf("expected %s got %s", end, got) + } + case <-testCtx.Done(): + t.Fatal("timed out waiting for Now()") + } +} + +func Test_UnreleasedCalls(t *testing.T) { + t.Parallel() + tRunFail(t, func(t testing.TB) { + testCtx, testCancel := context.WithTimeout(context.Background(), 10*time.Second) + defer testCancel() + mClock := quartz.NewMock(t) + + trap := mClock.Trap().Now() + defer trap.Close() + + go func() { + _ = mClock.Now() + }() + + trap.MustWait(testCtx) // missing release + }) +} + +type captureFailTB struct { + failed bool + testing.TB +} + +func (t *captureFailTB) Errorf(format string, args ...any) { + t.Helper() + t.Logf(format, args...) + t.failed = true +} + +func (t *captureFailTB) Error(args ...any) { + t.Helper() + t.Log(args...) + t.failed = true +} + +func (t *captureFailTB) Fatal(args ...any) { + t.Helper() + t.Log(args...) + t.failed = true +} + +func (t *captureFailTB) Fatalf(format string, args ...any) { + t.Helper() + t.Logf(format, args...) + t.failed = true +} + +func (t *captureFailTB) Fail() { + t.failed = true +} + +func (t *captureFailTB) FailNow() { + t.failed = true +} + +func (t *captureFailTB) Failed() bool { + return t.failed +} + +func tRunFail(t testing.TB, f func(t testing.TB)) { + tb := &captureFailTB{TB: t} + f(tb) + if !tb.Failed() { + t.Fatal("want test to fail") + } +}