Skip to content

Commit 0fdca43

Browse files
committed
publicize listen/notify methods, fixes vapor#40
1 parent 1cd5eef commit 0fdca43

6 files changed

+70
-43
lines changed
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,57 @@
11
extension PostgreSQLConnection {
2-
internal func listen(_ channelName: String, handler: @escaping (String) throws -> ()) throws -> Future<Void> {
3-
closeHandlers.append({ conn in
4-
return conn.send([.query(.init(query: "UNLISTEN \"\(channelName)\";"))], onResponse: { _ in })
5-
})
6-
7-
notificationHandlers[channelName] = { message in
8-
try handler(message)
9-
}
10-
return queue.enqueue([.query(.init(query: "LISTEN \"\(channelName)\";"))], onInput: { message in
2+
/// Begins listening for notifications on a channel.
3+
///
4+
/// LISTEN "<channel name>"
5+
///
6+
/// To subscribe to a channel, call `listen(...)` and provide the channel name.
7+
///
8+
/// conn.listen("foo") { message in
9+
/// print(message)
10+
/// return true
11+
/// }
12+
///
13+
/// Once a connection is listening, it may not be used to send further queries until `UNLISTEN` is sent.
14+
/// To unlisten, return `true` in the callback handler. Returning `false` will continue the subscription.
15+
///
16+
/// See `notify(...)` to send messages.
17+
///
18+
/// - parameters:
19+
/// - channelName: String identifier for the channel to subscribe to.
20+
/// - handler: Handles incoming String messages. Returning `true` here will end the subscription
21+
/// sending an `UNLISTEN` command.
22+
/// - returns: A future that signals completion of the `UNLISTEN` command.
23+
public func listen(_ channelName: String, handler: @escaping (String) throws -> (Bool)) -> Future<Void> {
24+
let promise = eventLoop.newPromise(Void.self)
25+
return queue.enqueue([.query(.init(query: "LISTEN \"\(channelName)\";"))]) { message in
1126
switch message {
12-
case let .notificationResponse(notification):
13-
try self.notificationHandlers[notification.channel]?(notification.message)
14-
default:
15-
break
27+
case .close: return false
28+
case .readyForQuery: return false
29+
case .notification(let notif):
30+
if try handler(notif.message) {
31+
self.simpleQuery("UNLISTEN \"\(channelName)\"").transform(to: ()).cascade(promise: promise)
32+
return true
33+
} else {
34+
return false
35+
}
36+
default: throw PostgreSQLError(identifier: "listen", reason: "Unexpected message during listen: \(message).")
1637
}
17-
return false
18-
})
19-
}
20-
21-
internal func notify(_ channelName: String, message: String) throws -> Future<Void> {
22-
return send([.query(.init(query: "NOTIFY \"\(channelName)\", '\(message)';"))]).map(to: Void.self, { _ in })
38+
}.flatMap { promise.futureResult }
2339
}
2440

25-
internal func unlisten(_ channelName: String, unlistenHandler: (() -> Void)? = nil) throws -> Future<Void> {
26-
notificationHandlers.removeValue(forKey: channelName)
27-
return send([.query(.init(query: "UNLISTEN \"\(channelName)\";"))], onResponse: { _ in unlistenHandler?() })
41+
/// Sends a notification to a listening connection. Use in conjunction with `listen(...)`.
42+
///
43+
/// NOTIFY "foo" 'hello'
44+
///
45+
/// A single connection can be used to send notifications to as many channels as desired.
46+
///
47+
/// conn.notify("foo", message: "hello")
48+
///
49+
/// - parameters:
50+
/// - channelName: String identifier for the channel to send to.
51+
/// - message: String message to send to subscribers.
52+
/// - returns: A future that signals completion of the send.
53+
public func notify(_ channelName: String, message: String) -> Future<Void> {
54+
return simpleQuery("NOTIFY \"\(channelName)\", '\(message)'").transform(to: ())
2855
}
2956
}
57+

Sources/PostgreSQL/Connection/PostgreSQLConnection.swift

+2-16
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ public final class PostgreSQLConnection: DatabaseConnection, BasicWorker {
2323
/// The channel
2424
internal let channel: Channel
2525

26+
/// Previously fetched table name cache
2627
internal var tableNameCache: TableNameCache?
2728

2829
/// In-flight `send(...)` futures.
@@ -31,15 +32,6 @@ public final class PostgreSQLConnection: DatabaseConnection, BasicWorker {
3132
/// The current query running, if one exists.
3233
private var pipeline: Future<Void>
3334

34-
/// Block type to be called on close of connection
35-
internal typealias CloseHandler = ((PostgreSQLConnection) -> Future<Void>)
36-
/// Called on close of the connection
37-
internal var closeHandlers = [CloseHandler]()
38-
/// Handler type for Notifications
39-
internal typealias NotificationHandler = (String) throws -> Void
40-
/// Handlers to be stored by channel name
41-
internal var notificationHandlers: [String: NotificationHandler] = [:]
42-
4335
/// Creates a new PostgreSQL client on the provided data source and sink.
4436
init(queue: QueueHandler<PostgreSQLMessage, PostgreSQLMessage>, channel: Channel) {
4537
self.queue = queue
@@ -122,13 +114,7 @@ public final class PostgreSQLConnection: DatabaseConnection, BasicWorker {
122114

123115
/// Executes close handlers before closing.
124116
private func executeCloseHandlersThenClose() -> Future<Void> {
125-
if let beforeClose = closeHandlers.popLast() {
126-
return beforeClose(self).then { _ in
127-
self.executeCloseHandlersThenClose()
128-
}
129-
} else {
130-
return channel.close(mode: .all)
131-
}
117+
return channel.close(mode: .all)
132118
}
133119

134120
/// Called when this class deinitializes.

Sources/PostgreSQL/Message/PostgreSQLMessage+0.swift

+1-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ enum PostgreSQLMessage {
3737
case notice(ErrorResponse)
3838

3939
/// Identifies the message as a notification response.
40-
case notificationResponse(NotificationResponse)
40+
case notification(Notification)
4141

4242
/// Identifies the message as a parameter description.
4343
case parameterDescription(ParameterDescription)

Sources/PostgreSQL/Message/PostgreSQLMessage+NotificationResponse.swift

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
extension PostgreSQLMessage {
2-
struct NotificationResponse {
2+
struct Notification {
33
/// The message coming from PSQL
44
let processID: Int32
55

@@ -13,9 +13,9 @@ extension PostgreSQLMessage {
1313

1414
// MARK: Parse
1515

16-
extension PostgreSQLMessage.NotificationResponse {
16+
extension PostgreSQLMessage.Notification {
1717
/// Parses an instance of this message type from a byte buffer.
18-
static func parse(from buffer: inout ByteBuffer) throws -> PostgreSQLMessage.NotificationResponse {
18+
static func parse(from buffer: inout ByteBuffer) throws -> PostgreSQLMessage.Notification {
1919
guard let processID = buffer.readInteger(as: Int32.self) else {
2020
throw PostgreSQLError.protocol(reason: "Could not read process ID from notification response.")
2121
}

Sources/PostgreSQL/Pipeline/PostgreSQLMessageDecoder.swift

+1-1
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ final class PostgreSQLMessageDecoder: ByteToMessageDecoder {
6969
case .E: message = try .error(.parse(from: &buffer))
7070
case .n: message = .noData
7171
case .N: message = try .notice(.parse(from: &buffer))
72-
case .A: message = try .notificationResponse(.parse(from: &buffer))
72+
case .A: message = try .notification(.parse(from: &buffer))
7373
case .t: message = try .parameterDescription(.parse(from: &buffer))
7474
case .S: message = try .parameterStatus(.parse(from: &buffer))
7575
case .one: message = .parseComplete

Tests/PostgreSQLTests/PostgreSQLConnectionTests.swift

+13
Original file line numberDiff line numberDiff line change
@@ -523,6 +523,19 @@ class PostgreSQLConnectionTests: XCTestCase {
523523
default: XCTFail("invalid row count")
524524
}
525525
}
526+
527+
func testListen() throws {
528+
let conn = try PostgreSQLConnection.makeTest(transport: .cleartext)
529+
let done = conn.listen("foo") { message in
530+
XCTAssertEqual(message, "hi")
531+
return true
532+
}
533+
do {
534+
let conn = try PostgreSQLConnection.makeTest(transport: .cleartext)
535+
_ = try conn.notify("foo", message: "hi").wait()
536+
}
537+
try done.wait()
538+
}
526539

527540
static var allTests = [
528541
("testVersion", testVersion),

0 commit comments

Comments
 (0)