Skip to content

Commit 2b20d60

Browse files
author
Shaun Hubbard
committed
added the ability to listen to multiple channels
1 parent 54e9e04 commit 2b20d60

File tree

4 files changed

+91
-8
lines changed

4 files changed

+91
-8
lines changed

Sources/PostgreSQL/Connection/PostgreSQLConnection+NotifyAndListen.swift

+13-2
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,26 @@
11
import Async
22

3+
34
extension PostgreSQLConnection {
45
/// Note: after calling `listen'` on a connection, it can no longer handle other database operations. Do not try to send other SQL commands through this connection afterwards.
56
/// IAlso, notifications will only be sent for as long as this connection remains open; you are responsible for opening a new connection to listen on when this one closes.
67
public func listen(
78
_ channelName: String,
89
handler: @escaping (String) throws -> ()
910
) throws -> Future<Void> {
10-
beforeClose = { conn in
11+
closeHandlers.append({ conn in
1112
let query = PostgreSQLQuery(query: "UNLISTEN \"\(channelName)\";")
1213
return conn.send([.query(query)], onResponse: { _ in })
14+
})
15+
16+
notificationHandlers[channelName] = { message in
17+
try handler(message)
1318
}
1419
let query = PostgreSQLQuery(query: "LISTEN \"\(channelName)\";")
1520
return queue.enqueue([.query(query)], onInput: { message in
1621
switch message {
1722
case let .notificationResponse(notification):
18-
try handler(notification.message)
23+
try self.notificationHandlers[notification.channel]?(notification.message)
1924
default:
2025
break
2126
}
@@ -28,4 +33,10 @@ extension PostgreSQLConnection {
2833
let query = PostgreSQLQuery(query: "NOTIFY \"\(channelName)\", '\(message)';")
2934
return send([.query(query)]).map(to: Void.self, { _ in })
3035
}
36+
37+
public func unlisten(_ channelName: String, unlistenHandler: (() -> Void)? = nil) throws -> Future<Void> {
38+
notificationHandlers.removeValue(forKey: channelName)
39+
let query = PostgreSQLQuery(query: "UNLISTEN \"\(channelName)\";")
40+
return send([.query(query)], onResponse: { _ in unlistenHandler?() })
41+
}
3142
}

Sources/PostgreSQL/Connection/PostgreSQLConnection.swift

+19-5
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,15 @@ public final class PostgreSQLConnection: DatabaseConnection, BasicWorker {
4444
/// The current query running, if one exists.
4545
private var pipeline: Future<Void>
4646

47+
/// Block type to be called on close of connection
48+
internal typealias CloseHandler = ((PostgreSQLConnection) -> Future<Void>)
49+
/// Called on close of the connection
50+
internal var closeHandlers = [CloseHandler]()
51+
/// Handler type for Notifications
52+
internal typealias NotificationHandler = (String) throws -> Void
53+
/// Handlers to be stored by channel name
54+
internal var notificationHandlers: [String: NotificationHandler] = [:]
55+
4756
/// Creates a new Redis client on the provided data source and sink.
4857
init(queue: QueueHandler<PostgreSQLMessage, PostgreSQLMessage>, channel: Channel) {
4958
self.queue = queue
@@ -184,19 +193,24 @@ public final class PostgreSQLConnection: DatabaseConnection, BasicWorker {
184193
}
185194
}
186195

187-
internal var beforeClose: ((PostgreSQLConnection) -> Future<Void>)?
188196

189197
/// Closes this client.
190198
public func close() {
191-
if let beforeClose = beforeClose {
192-
_ = beforeClose(self).then { _ in
193-
self.channel.close(mode: CloseMode.all)
199+
_ = executeCloseHandlersThenClose()
200+
}
201+
202+
203+
private func executeCloseHandlersThenClose() -> Future<Void> {
204+
if let beforeClose = closeHandlers.popLast() {
205+
return beforeClose(self).then { _ in
206+
self.executeCloseHandlersThenClose()
194207
}
195208
} else {
196-
channel.close(promise: nil)
209+
return channel.close(mode: .all)
197210
}
198211
}
199212

213+
200214
/// Called when this class deinitializes.
201215
deinit {
202216
close()

Sources/PostgreSQL/Message/PostgreSQLNotificationResponse.swift

-1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,5 @@ struct PostgreSQLNotificationResponse: Decodable {
99
_ = try container.decode(Int32.self)
1010
channel = try container.decode(String.self)
1111
message = try container.decode(String.self)
12-
NSLog("Found self \(channel) \(message)")
1312
}
1413
}

Tests/PostgreSQLTests/PostgreSQLConnectionTests.swift

+59
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,63 @@ class PostgreSQLConnectionTests: XCTestCase {
368368
listenConn.close()
369369
}
370370

371+
func testNotifyAndListenOnMultipleChannels() throws {
372+
let completionHandlerExpectation1 = expectation(description: "first completion handler called")
373+
let completionHandlerExpectation2 = expectation(description: "final completion handler called")
374+
let notifyConn = try PostgreSQLConnection.makeTest()
375+
let listenConn = try PostgreSQLConnection.makeTest()
376+
let channelName = "Fooze"
377+
let channelName2 = "Foozalz"
378+
let messageText = "Bar"
379+
let finalMessageText = "Baz"
380+
381+
try listenConn.listen(channelName) { text in
382+
if text == messageText {
383+
completionHandlerExpectation1.fulfill()
384+
}
385+
386+
}.catch({ err in XCTFail("error \(err)") })
387+
388+
try listenConn.listen(channelName2) { text in
389+
if text == finalMessageText {
390+
completionHandlerExpectation2.fulfill()
391+
}
392+
}.catch({ err in XCTFail("error \(err)") })
393+
394+
try notifyConn.notify(channelName, message: messageText).wait()
395+
try notifyConn.notify(channelName2, message: finalMessageText).wait()
396+
397+
waitForExpectations(timeout: defaultTimeout)
398+
notifyConn.close()
399+
listenConn.close()
400+
}
401+
402+
func testUnlisten() throws {
403+
let completionHandlerExpectation = expectation(description: "notify completion handler called")
404+
completionHandlerExpectation.expectedFulfillmentCount = 2
405+
completionHandlerExpectation.assertForOverFulfill = true
406+
407+
let notifyConn = try PostgreSQLConnection.makeTest()
408+
let listenConn = try PostgreSQLConnection.makeTest()
409+
let channelName = "Foozers"
410+
let messageText = "Bar"
411+
412+
try listenConn.listen(channelName) { text in
413+
if text == messageText {
414+
completionHandlerExpectation.fulfill()
415+
}
416+
}.catch({ err in XCTFail("error \(err)") })
417+
418+
try notifyConn.notify(channelName, message: messageText).wait()
419+
try notifyConn.unlisten(channelName, unlistenHandler: {
420+
completionHandlerExpectation.fulfill()
421+
}).wait()
422+
waitForExpectations(timeout: defaultTimeout)
423+
424+
notifyConn.close()
425+
listenConn.close()
426+
}
427+
371428
func testURLParsing() throws {
372429
let databaseURL = "postgres://username:password@hostname.com:5432/database"
373430
let config = try PostgreSQLDatabaseConfig(url: databaseURL)
@@ -388,6 +445,8 @@ class PostgreSQLConnectionTests: XCTestCase {
388445
("testNull", testNull),
389446
("testGH24", testGH24),
390447
("testNotifyAndListen", testNotifyAndListen),
448+
("testNotifyAndListenOnMultipleChannels", testNotifyAndListenOnMultipleChannels),
449+
("testUnlisten", testUnlisten),
391450
("testURLParsing", testURLParsing),
392451
]
393452
}

0 commit comments

Comments
 (0)