Skip to content

Commit ef3a00f

Browse files
authored
Cleanup encoding Startup message (vapor#395)
Further cleanup of message encoding: - Move Startup struct into PostgresFrontendMessageEncoder - Move PSQLMessagePayloadEncodable into tests, since it isn't used in PostgresNIO anymore - Only support the parameters that are actually used in encoding startup messages
1 parent d5c5258 commit ef3a00f

File tree

7 files changed

+98
-134
lines changed

7 files changed

+98
-134
lines changed

Diff for: Sources/PostgresNIO/New/Messages/Startup.swift

-52
This file was deleted.

Diff for: Sources/PostgresNIO/New/PostgresChannelHandler.swift

+1-12
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,7 @@ final class PostgresChannelHandler: ChannelDuplexHandler {
328328
case .wait:
329329
break
330330
case .sendStartupMessage(let authContext):
331-
self.encoder.startup(authContext.toStartupParameters())
331+
self.encoder.startup(user: authContext.username, database: authContext.database)
332332
context.writeAndFlush(self.wrapOutboundOut(self.encoder.flushBuffer()), promise: nil)
333333
case .sendSSLRequest:
334334
self.encoder.ssl()
@@ -793,17 +793,6 @@ extension PostgresChannelHandler: PSQLRowsDataSource {
793793
}
794794
}
795795

796-
extension AuthContext {
797-
func toStartupParameters() -> PostgresFrontendMessage.Startup.Parameters {
798-
PostgresFrontendMessage.Startup.Parameters(
799-
user: self.username,
800-
database: self.database,
801-
options: nil,
802-
replication: .false
803-
)
804-
}
805-
}
806-
807796
private extension Insecure.MD5.Digest {
808797

809798
private static let lowercaseLookup: [UInt8] = [

Diff for: Sources/PostgresNIO/New/PostgresFrontendMessage.swift

+44-4
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,50 @@ enum PostgresFrontendMessage: Equatable {
102102
static let requestCode: Int32 = 80877103
103103
}
104104

105+
struct Startup: Hashable {
106+
static let versionThree: Int32 = 0x00_03_00_00
107+
108+
/// Creates a `Startup` with "3.0" as the protocol version.
109+
static func versionThree(parameters: Parameters) -> Startup {
110+
return .init(protocolVersion: Self.versionThree, parameters: parameters)
111+
}
112+
113+
/// The protocol version number. The most significant 16 bits are the major
114+
/// version number (3 for the protocol described here). The least significant
115+
/// 16 bits are the minor version number (0 for the protocol described here).
116+
var protocolVersion: Int32
117+
118+
/// The protocol version number is followed by one or more pairs of parameter
119+
/// name and value strings. A zero byte is required as a terminator after
120+
/// the last name/value pair. `user` is required, others are optional.
121+
struct Parameters: Hashable {
122+
enum Replication {
123+
case `true`
124+
case `false`
125+
case database
126+
}
127+
128+
/// The database user name to connect as. Required; there is no default.
129+
var user: String
130+
131+
/// The database to connect to. Defaults to the user name.
132+
var database: String?
133+
134+
/// Command-line arguments for the backend. (This is deprecated in favor
135+
/// of setting individual run-time parameters.) Spaces within this string are
136+
/// considered to separate arguments, unless escaped with a
137+
/// backslash (\); write \\ to represent a literal backslash.
138+
var options: String?
139+
140+
/// Used to connect in streaming replication mode, where a small set of
141+
/// replication commands can be issued instead of SQL statements. Value
142+
/// can be true, false, or database, and the default is false.
143+
var replication: Replication
144+
}
145+
146+
var parameters: Parameters
147+
}
148+
105149
case bind(Bind)
106150
case cancel(Cancel)
107151
case close(Close)
@@ -225,7 +269,3 @@ extension PostgresFrontendMessage {
225269
}
226270
}
227271
}
228-
229-
protocol PSQLMessagePayloadEncodable {
230-
func encode(into buffer: inout ByteBuffer)
231-
}

Diff for: Sources/PostgresNIO/New/PostgresFrontendMessageEncoder.swift

+3-19
Original file line numberDiff line numberDiff line change
@@ -13,34 +13,18 @@ struct PostgresFrontendMessageEncoder {
1313
self.buffer = buffer
1414
}
1515

16-
mutating func startup(_ parameters: PostgresFrontendMessage.Startup.Parameters) {
16+
mutating func startup(user: String, database: String?) {
1717
self.clearIfNeeded()
1818
self.encodeLengthPrefixed { buffer in
1919
buffer.writeInteger(PostgresFrontendMessage.Startup.versionThree)
2020
buffer.writeNullTerminatedString("user")
21-
buffer.writeNullTerminatedString(parameters.user)
21+
buffer.writeNullTerminatedString(user)
2222

23-
if let database = parameters.database {
23+
if let database = database {
2424
buffer.writeNullTerminatedString("database")
2525
buffer.writeNullTerminatedString(database)
2626
}
2727

28-
if let options = parameters.options {
29-
buffer.writeNullTerminatedString("options")
30-
buffer.writeNullTerminatedString(options)
31-
}
32-
33-
switch parameters.replication {
34-
case .database:
35-
buffer.writeNullTerminatedString("replication")
36-
buffer.writeNullTerminatedString("replication")
37-
case .true:
38-
buffer.writeNullTerminatedString("replication")
39-
buffer.writeNullTerminatedString("true")
40-
case .false:
41-
break
42-
}
43-
4428
buffer.writeInteger(UInt8(0))
4529
}
4630
}

Diff for: Tests/PostgresNIOTests/New/Extensions/PSQLBackendMessageEncoder.swift

+4
Original file line numberDiff line numberDiff line change
@@ -257,3 +257,7 @@ extension RowDescription: PSQLMessagePayloadEncodable {
257257
}
258258
}
259259
}
260+
261+
protocol PSQLMessagePayloadEncodable {
262+
func encode(into buffer: inout ByteBuffer)
263+
}

Diff for: Tests/PostgresNIOTests/New/Messages/StartupTests.swift

+35-47
Original file line numberDiff line numberDiff line change
@@ -4,56 +4,44 @@ import NIOCore
44

55
class StartupTests: XCTestCase {
66

7-
func testStartupMessage() {
7+
func testStartupMessageWithDatabase() {
88
var encoder = PostgresFrontendMessageEncoder(buffer: .init())
99
var byteBuffer = ByteBuffer()
10-
11-
let replicationValues: [PostgresFrontendMessage.Startup.Parameters.Replication] = [
12-
.`true`,
13-
.`false`,
14-
.database
15-
]
16-
17-
for replication in replicationValues {
18-
let parameters = PostgresFrontendMessage.Startup.Parameters(
19-
user: "test",
20-
database: "abc123",
21-
options: "some options",
22-
replication: replication
23-
)
24-
25-
encoder.startup(parameters)
26-
byteBuffer = encoder.flushBuffer()
27-
28-
let byteBufferLength = Int32(byteBuffer.readableBytes)
29-
XCTAssertEqual(byteBufferLength, byteBuffer.readInteger())
30-
XCTAssertEqual(PostgresFrontendMessage.Startup.versionThree, byteBuffer.readInteger())
31-
XCTAssertEqual(byteBuffer.readNullTerminatedString(), "user")
32-
XCTAssertEqual(byteBuffer.readNullTerminatedString(), "test")
33-
XCTAssertEqual(byteBuffer.readNullTerminatedString(), "database")
34-
XCTAssertEqual(byteBuffer.readNullTerminatedString(), "abc123")
35-
XCTAssertEqual(byteBuffer.readNullTerminatedString(), "options")
36-
XCTAssertEqual(byteBuffer.readNullTerminatedString(), "some options")
37-
if replication != .false {
38-
XCTAssertEqual(byteBuffer.readNullTerminatedString(), "replication")
39-
XCTAssertEqual(byteBuffer.readNullTerminatedString(), replication.stringValue)
40-
}
41-
XCTAssertEqual(byteBuffer.readInteger(), UInt8(0))
42-
43-
XCTAssertEqual(byteBuffer.readableBytes, 0)
44-
}
10+
11+
let user = "test"
12+
let database = "abc123"
13+
14+
encoder.startup(user: user, database: database)
15+
byteBuffer = encoder.flushBuffer()
16+
17+
let byteBufferLength = Int32(byteBuffer.readableBytes)
18+
XCTAssertEqual(byteBufferLength, byteBuffer.readInteger())
19+
XCTAssertEqual(PostgresFrontendMessage.Startup.versionThree, byteBuffer.readInteger())
20+
XCTAssertEqual(byteBuffer.readNullTerminatedString(), "user")
21+
XCTAssertEqual(byteBuffer.readNullTerminatedString(), "test")
22+
XCTAssertEqual(byteBuffer.readNullTerminatedString(), "database")
23+
XCTAssertEqual(byteBuffer.readNullTerminatedString(), "abc123")
24+
XCTAssertEqual(byteBuffer.readInteger(), UInt8(0))
25+
26+
XCTAssertEqual(byteBuffer.readableBytes, 0)
4527
}
46-
}
4728

48-
extension PostgresFrontendMessage.Startup.Parameters.Replication {
49-
var stringValue: String {
50-
switch self {
51-
case .true:
52-
return "true"
53-
case .false:
54-
return "false"
55-
case .database:
56-
return "replication"
57-
}
29+
func testStartupMessageWithoutDatabase() {
30+
var encoder = PostgresFrontendMessageEncoder(buffer: .init())
31+
var byteBuffer = ByteBuffer()
32+
33+
let user = "test"
34+
35+
encoder.startup(user: user, database: nil)
36+
byteBuffer = encoder.flushBuffer()
37+
38+
let byteBufferLength = Int32(byteBuffer.readableBytes)
39+
XCTAssertEqual(byteBufferLength, byteBuffer.readInteger())
40+
XCTAssertEqual(PostgresFrontendMessage.Startup.versionThree, byteBuffer.readInteger())
41+
XCTAssertEqual(byteBuffer.readNullTerminatedString(), "user")
42+
XCTAssertEqual(byteBuffer.readNullTerminatedString(), "test")
43+
XCTAssertEqual(byteBuffer.readInteger(), UInt8(0))
44+
45+
XCTAssertEqual(byteBuffer.readableBytes, 0)
5846
}
5947
}

Diff for: Tests/PostgresNIOTests/New/PostgresChannelHandlerTests.swift

+11
Original file line numberDiff line numberDiff line change
@@ -277,3 +277,14 @@ class TestEventHandler: ChannelInboundHandler {
277277
self.events.append(psqlEvent)
278278
}
279279
}
280+
281+
extension AuthContext {
282+
func toStartupParameters() -> PostgresFrontendMessage.Startup.Parameters {
283+
PostgresFrontendMessage.Startup.Parameters(
284+
user: self.username,
285+
database: self.database,
286+
options: nil,
287+
replication: .false
288+
)
289+
}
290+
}

0 commit comments

Comments
 (0)