From a8f9a620d82ca8781c4ac87b2a3b084d70c111b6 Mon Sep 17 00:00:00 2001 From: Garrett Date: Tue, 11 Jan 2022 19:06:54 +0000 Subject: [PATCH] Add NewConnectorListener --- notify.go | 36 ++++++++++++++++++++++++++++++++---- 1 file changed, 32 insertions(+), 4 deletions(-) diff --git a/notify.go b/notify.go index 5c421fdb8..37953ff2d 100644 --- a/notify.go +++ b/notify.go @@ -119,11 +119,11 @@ type ListenerConn struct { // NewListenerConn creates a new ListenerConn. Use NewListener instead. func NewListenerConn(name string, notificationChan chan<- *Notification) (*ListenerConn, error) { - return newDialListenerConn(defaultDialer{}, name, notificationChan) + return newDialListenerConn(defaultDialer{}, nil, name, notificationChan) } -func newDialListenerConn(d Dialer, name string, c chan<- *Notification) (*ListenerConn, error) { - cn, err := DialOpen(d, name) +func newDialListenerConn(d Dialer, connector driver.Connector, name string, c chan<- *Notification) (*ListenerConn, error) { + cn, err := getConn(connector, d, name) if err != nil { return nil, err } @@ -140,6 +140,15 @@ func newDialListenerConn(d Dialer, name string, c chan<- *Notification) (*Listen return l, nil } +// getConn uses driver.Connector if provided and falls back to Dialer otherwise. +func getConn(c driver.Connector, d Dialer, name string) (driver.Conn, error) { + if c != nil { + return c.Connect(context.Background()) + } + + return DialOpen(d, name) +} + // We can only allow one goroutine at a time to be running a query on the // connection for various reasons, so the goroutine sending on the connection // must be holding senderLock. @@ -470,6 +479,7 @@ type Listener struct { maxReconnectInterval time.Duration dialer Dialer eventCallback EventCallbackType + connector driver.Connector lock sync.Mutex isClosed bool @@ -502,19 +512,37 @@ func NewListener(name string, return NewDialListener(defaultDialer{}, name, minReconnectInterval, maxReconnectInterval, eventCallback) } +// NewConnectorListener is like NewListener but it takes a driver.Connector. +func NewConnectorListener(c driver.Connector, + name string, + minReconnectInterval time.Duration, + maxReconnectInterval time.Duration, + eventCallback EventCallbackType) *Listener { + return listener(c, defaultDialer{}, name, minReconnectInterval, maxReconnectInterval, eventCallback) +} + // NewDialListener is like NewListener but it takes a Dialer. func NewDialListener(d Dialer, name string, minReconnectInterval time.Duration, maxReconnectInterval time.Duration, eventCallback EventCallbackType) *Listener { + return listener(nil, d, name, minReconnectInterval, maxReconnectInterval, eventCallback) +} +func listener(c driver.Connector, + d Dialer, + name string, + minReconnectInterval time.Duration, + maxReconnectInterval time.Duration, + eventCallback EventCallbackType) *Listener { l := &Listener{ name: name, minReconnectInterval: minReconnectInterval, maxReconnectInterval: maxReconnectInterval, dialer: d, eventCallback: eventCallback, + connector: c, channels: make(map[string]struct{}), @@ -749,7 +777,7 @@ func (l *Listener) closed() bool { func (l *Listener) connect() error { notificationChan := make(chan *Notification, 32) - cn, err := newDialListenerConn(l.dialer, l.name, notificationChan) + cn, err := newDialListenerConn(l.dialer, l.connector, l.name, notificationChan) if err != nil { return err }