forked from vapor/postgres-nio
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathPostgresConnection+Database.swift
127 lines (111 loc) · 4.16 KB
/
PostgresConnection+Database.swift
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import Logging
extension PostgresConnection: PostgresDatabase {
    /// Logs `request`, hands it to the channel wrapped in a `PostgresRequestContext`,
    /// and flushes.
    ///
    /// - Parameters:
    ///   - request: The request to send over this connection.
    ///   - logger: Logger that `request` describes itself to before it is written.
    /// - Returns: The future of the promise stored in the `PostgresRequestContext`.
    ///   It is also failed if the channel write itself fails.
    public func send(
        _ request: PostgresRequest,
        logger: Logger
    ) -> EventLoopFuture<Void> {
        request.log(to: logger)

        let completion = self.channel.eventLoop.makePromise(of: Void.self)
        let wrapped = PostgresRequestContext(delegate: request, promise: completion)

        // A failed write must surface to the caller, otherwise the returned future would never complete.
        let writeFuture = self.channel.write(wrapped)
        writeFuture.cascadeFailure(to: completion)
        self.channel.flush()

        return completion.futureResult
    }

    /// Runs `closure` with this connection itself; a single connection needs no checkout.
    public func withConnection<T>(_ closure: (PostgresConnection) -> EventLoopFuture<T>) -> EventLoopFuture<T> {
        return closure(self)
    }
}
/// Bundles a `PostgresRequest` with the promise used to report its outcome and the
/// last error recorded while the request was in flight.
final class PostgresRequestContext {
    /// The request whose messages drive this exchange.
    let delegate: PostgresRequest

    /// Completed (succeeded or failed) once the request is finished.
    let promise: EventLoopPromise<Void>

    /// Most recent error recorded for this request; starts out empty.
    var lastError: Error? = nil

    init(delegate: PostgresRequest, promise: EventLoopPromise<Void>) {
        self.promise = promise
        self.delegate = delegate
    }
}
/// Duplex handler that starts outbound `PostgresRequestContext`s and dispatches each
/// inbound `PostgresMessage` to the oldest outstanding request.
///
/// Requests are kept in a FIFO queue. A request is removed, and its promise completed,
/// once its delegate's `respond(to:)` returns `nil`.
final class PostgresRequestHandler: ChannelDuplexHandler {
    typealias InboundIn = PostgresMessage
    typealias OutboundIn = PostgresRequestContext
    typealias OutboundOut = PostgresMessage

    /// Requests that have been started but not yet finished, oldest first.
    private var queue: [PostgresRequestContext]
    let logger: Logger

    public init(logger: Logger) {
        self.queue = []
        self.logger = logger
    }

    /// Routes one inbound message to the request at the head of the queue.
    ///
    /// - Throws: Errors from decoding error/notice messages or from the delegate's
    ///   `respond(to:)`.
    private func _channelRead(context: ChannelHandlerContext, data: NIOAny) throws {
        let message = self.unwrapInboundIn(data)
        guard let request = self.queue.first else {
            // No request is waiting for a response; discard the packet.
            return
        }

        switch message.identifier {
        case .error:
            let error = try PostgresMessage.Error(message: message)
            self.logger.error("\(error)")
            // Remembered so the request fails, rather than succeeds, once it finishes.
            request.lastError = PostgresError.server(error)
        case .notice:
            let notice = try PostgresMessage.Error(message: message)
            self.logger.notice("\(notice)")
        default:
            break
        }

        if let responses = try request.delegate.respond(to: message) {
            for response in responses {
                context.write(self.wrapOutboundOut(response), promise: nil)
            }
            context.flush()
        } else {
            // `nil` means the request is finished. Dequeue before completing the promise
            // so callbacks that start new requests see a consistent queue.
            self.queue.removeFirst()
            if let error = request.lastError {
                request.promise.fail(error)
            } else {
                request.promise.succeed(())
            }
        }
    }

    func channelRead(context: ChannelHandlerContext, data: NIOAny) {
        do {
            try self._channelRead(context: context, data: data)
        } catch {
            self.errorCaught(context: context, error: error)
        }

        // Regardless of error, also pass the message downstream; this is necessary for PostgresNotificationHandler (which is appended at the end) to receive notifications
        context.fireChannelRead(data)
    }

    /// Starts the request and writes its initial messages.
    ///
    /// The request is enqueued only after `start()` succeeds. If `start()` throws, nothing
    /// was sent, so the server will never answer it. Leaving it in the queue would
    /// misroute the next request's responses to it.
    func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
        let request = self.unwrapOutboundIn(data)

        let messages: [PostgresMessage]
        do {
            messages = try request.delegate.start()
        } catch {
            // Fail the request directly too: the write promise may be nil.
            request.promise.fail(error)
            promise?.fail(error)
            self.errorCaught(context: context, error: error)
            return
        }

        self.queue.append(request)
        self.write(context: context, items: messages, promise: promise)
        context.flush()
    }

    /// Sends `Terminate`, closes the channel, and fails every outstanding request.
    func close(context: ChannelHandlerContext, mode: CloseMode, promise: EventLoopPromise<Void>?) {
        // Best effort: a serialization failure here must not prevent the close.
        if let terminate = try? PostgresMessage.Terminate().message() {
            context.write(self.wrapOutboundOut(terminate), promise: nil)
        }
        context.close(mode: mode, promise: promise)
        self.failPendingRequests(with: PostgresError.connectionClosed)
    }

    /// Fails every outstanding request when the channel goes inactive, for example
    /// when the server drops the connection without `close` being called.
    func channelInactive(context: ChannelHandlerContext) {
        self.failPendingRequests(with: PostgresError.connectionClosed)
        context.fireChannelInactive()
    }

    /// Fails every queued request with `error` and empties the queue.
    private func failPendingRequests(with error: Error) {
        // Detach the queue first. Promise callbacks run synchronously and may
        // re-enter `write`; such new entries must not be wiped afterwards.
        let pending = self.queue
        self.queue.removeAll()
        for request in pending {
            request.promise.fail(error)
        }
    }
}
extension ChannelInboundHandler {
    /// Writes each of `items` through `context`, in order.
    ///
    /// Only the write of the final item carries `promise`; every earlier write is
    /// issued with no promise. When `items` is empty, `promise` is succeeded right away.
    func write(context: ChannelHandlerContext, items: [OutboundOut], promise: EventLoopPromise<Void>?) {
        guard let final = items.last else {
            promise?.succeed(())
            return
        }

        items.dropLast().forEach { item in
            context.write(self.wrapOutboundOut(item), promise: nil)
        }
        context.write(self.wrapOutboundOut(final), promise: promise)
    }
}