@@ -10,7 +10,7 @@ protocol PSQLChannelHandlerNotificationDelegate: AnyObject {
1010final class PSQLChannelHandler : ChannelDuplexHandler {
1111 typealias OutboundIn = PSQLTask
1212 typealias InboundIn = ByteBuffer
13- typealias OutboundOut = PSQLFrontendMessage
13+ typealias OutboundOut = ByteBuffer
1414
1515 private let logger : Logger
1616 private var state : ConnectionStateMachine {
@@ -25,32 +25,33 @@ final class PSQLChannelHandler: ChannelDuplexHandler {
2525 private var handlerContext : ChannelHandlerContext !
2626 private var rowStream : PSQLRowStream ?
2727 private var decoder : NIOSingleStepByteToMessageProcessor < PSQLBackendMessageDecoder >
28- private let authentificationConfiguration : PSQLConnection . Configuration . Authentication ?
28+ private var encoder : BufferedMessageEncoder < PSQLFrontendMessageEncoder > !
29+ private let configuration : PSQLConnection . Configuration
2930 private let configureSSLCallback : ( ( Channel ) throws -> Void ) ?
3031
3132 /// this delegate should only be accessed on the connections `EventLoop`
3233 weak var notificationDelegate : PSQLChannelHandlerNotificationDelegate ?
3334
34- init ( authentification : PSQLConnection . Configuration . Authentication ? ,
35+ init ( configuration : PSQLConnection . Configuration ,
3536 logger: Logger ,
3637 configureSSLCallback: ( ( Channel ) throws -> Void ) ? )
3738 {
3839 self . state = ConnectionStateMachine ( )
39- self . authentificationConfiguration = authentification
40+ self . configuration = configuration
4041 self . configureSSLCallback = configureSSLCallback
4142 self . logger = logger
4243 self . decoder = NIOSingleStepByteToMessageProcessor ( PSQLBackendMessageDecoder ( ) )
4344 }
4445
4546 #if DEBUG
4647 /// for testing purposes only
47- init ( authentification : PSQLConnection . Configuration . Authentication ? ,
48+ init ( configuration : PSQLConnection . Configuration ,
4849 state: ConnectionStateMachine = . init( . initialized) ,
4950 logger: Logger = . psqlNoOpLogger,
5051 configureSSLCallback: ( ( Channel ) throws -> Void ) ? )
5152 {
5253 self . state = state
53- self . authentificationConfiguration = authentification
54+ self . configuration = configuration
5455 self . configureSSLCallback = configureSSLCallback
5556 self . logger = logger
5657 self . decoder = NIOSingleStepByteToMessageProcessor ( PSQLBackendMessageDecoder ( ) )
@@ -61,6 +62,11 @@ final class PSQLChannelHandler: ChannelDuplexHandler {
6162
6263 func handlerAdded( context: ChannelHandlerContext ) {
6364 self . handlerContext = context
65+ self . encoder = BufferedMessageEncoder (
66+ buffer: context. channel. allocator. buffer ( capacity: 256 ) ,
67+ encoder: PSQLFrontendMessageEncoder ( jsonEncoder: self . configuration. coders. jsonEncoder)
68+ )
69+
6470 if context. channel. isActive {
6571 self . connected ( context: context)
6672 }
@@ -222,15 +228,19 @@ final class PSQLChannelHandler: ChannelDuplexHandler {
222228 case . wait:
223229 break
224230 case . sendStartupMessage( let authContext) :
225- context. writeAndFlush ( . startup( . versionThree( parameters: authContext. toStartupParameters ( ) ) ) , promise: nil )
231+ try ! self . encoder. encode ( . startup( . versionThree( parameters: authContext. toStartupParameters ( ) ) ) )
232+ context. writeAndFlush ( self . wrapOutboundOut ( self . encoder. flush ( ) !) , promise: nil )
226233 case . sendSSLRequest:
227- context. writeAndFlush ( . sslRequest( . init( ) ) , promise: nil )
234+ try ! self . encoder. encode ( . sslRequest( . init( ) ) )
235+ context. writeAndFlush ( self . wrapOutboundOut ( self . encoder. flush ( ) !) , promise: nil )
228236 case . sendPasswordMessage( let mode, let authContext) :
229237 self . sendPasswordMessage ( mode: mode, authContext: authContext, context: context)
230238 case . sendSaslInitialResponse( let name, let initialResponse) :
231- context. writeAndFlush ( . saslInitialResponse( . init( saslMechanism: name, initialData: initialResponse) ) )
239+ try ! self . encoder. encode ( . saslInitialResponse( . init( saslMechanism: name, initialData: initialResponse) ) )
240+ context. writeAndFlush ( self . wrapOutboundOut ( self . encoder. flush ( ) !) , promise: nil )
232241 case . sendSaslResponse( let bytes) :
233- context. writeAndFlush ( . saslResponse( . init( data: bytes) ) )
242+ try ! self . encoder. encode ( . saslResponse( . init( data: bytes) ) )
243+ context. writeAndFlush ( self . wrapOutboundOut ( self . encoder. flush ( ) !) , promise: nil )
234244 case . closeConnectionAndCleanup( let cleanupContext) :
235245 self . closeConnectionAndCleanup ( cleanupContext, context: context)
236246 case . fireChannelInactive:
@@ -277,7 +287,7 @@ final class PSQLChannelHandler: ChannelDuplexHandler {
277287 case . provideAuthenticationContext:
278288 context. fireUserInboundEventTriggered ( PSQLEvent . readyForStartup)
279289
280- if let authentication = self . authentificationConfiguration {
290+ if let authentication = self . configuration . authentication {
281291 let authContext = AuthContext (
282292 username: authentication. username,
283293 password: authentication. password,
@@ -293,7 +303,8 @@ final class PSQLChannelHandler: ChannelDuplexHandler {
293303 // The normal, graceful termination procedure is that the frontend sends a Terminate
294304 // message and immediately closes the connection. On receipt of this message, the
295305 // backend closes the connection and terminates.
296- context. write ( . terminate, promise: nil )
306+ try ! self . encoder. encode ( . terminate)
307+ context. writeAndFlush ( self . wrapOutboundOut ( self . encoder. flush ( ) !) , promise: nil )
297308 }
298309 context. close ( mode: . all, promise: promise)
299310 case . succeedPreparedStatementCreation( let preparedContext, with: let rowDescription) :
@@ -357,22 +368,26 @@ final class PSQLChannelHandler: ChannelDuplexHandler {
357368 hash2. append ( salt. 3 )
358369 let hash = " md5 " + Insecure. MD5. hash ( data: hash2) . hexdigest ( )
359370
360- context. writeAndFlush ( . password( . init( value: hash) ) , promise: nil )
371+ try ! self . encoder. encode ( . password( . init( value: hash) ) )
372+ context. writeAndFlush ( self . wrapOutboundOut ( self . encoder. flush ( ) !) , promise: nil )
373+
361374 case . cleartext:
362- context. writeAndFlush ( . password( . init( value: authContext. password ?? " " ) ) , promise: nil )
375+ try ! self . encoder. encode ( . password( . init( value: authContext. password ?? " " ) ) )
376+ context. writeAndFlush ( self . wrapOutboundOut ( self . encoder. flush ( ) !) , promise: nil )
363377 }
364378 }
365379
366380 private func sendCloseAndSyncMessage( _ sendClose: CloseTarget , context: ChannelHandlerContext ) {
367381 switch sendClose {
368382 case . preparedStatement( let name) :
369- context. write ( . close( . preparedStatement( name) ) , promise: nil )
370- context. write ( . sync, promise: nil )
371- context. flush ( )
383+ try ! self . encoder. encode ( . close( . preparedStatement( name) ) )
384+ try ! self . encoder. encode ( . sync)
385+ context. writeAndFlush ( self . wrapOutboundOut ( self . encoder. flush ( ) !) , promise: nil )
386+
372387 case . portal( let name) :
373- context . write ( . close( . portal( name) ) , promise : nil )
374- context . write ( . sync, promise : nil )
375- context. flush ( )
388+ try ! self . encoder . encode ( . close( . portal( name) ) )
389+ try ! self . encoder . encode ( . sync)
390+ context. writeAndFlush ( self . wrapOutboundOut ( self . encoder . flush ( ) ! ) , promise : nil )
376391 }
377392 }
378393
@@ -387,10 +402,16 @@ final class PSQLChannelHandler: ChannelDuplexHandler {
387402 query: query,
388403 parameters: [ ] )
389404
390- context. write ( . parse( parse) , promise: nil )
391- context. write ( . describe( . preparedStatement( statementName) ) , promise: nil )
392- context. write ( . sync, promise: nil )
393- context. flush ( )
405+
406+ do {
407+ try self . encoder. encode ( . parse( parse) )
408+ try self . encoder. encode ( . describe( . preparedStatement( statementName) ) )
409+ try self . encoder. encode ( . sync)
410+ context. writeAndFlush ( self . wrapOutboundOut ( self . encoder. flush ( ) !) , promise: nil )
411+ } catch {
412+ let action = self . state. errorHappened ( . channel( underlying: error) )
413+ self . run ( action, with: context)
414+ }
394415 }
395416
396417 private func sendBindExecuteAndSyncMessage(
@@ -403,10 +424,15 @@ final class PSQLChannelHandler: ChannelDuplexHandler {
403424 preparedStatementName: statementName,
404425 parameters: binds)
405426
406- context. write ( . bind( bind) , promise: nil )
407- context. write ( . execute( . init( portalName: " " ) ) , promise: nil )
408- context. write ( . sync, promise: nil )
409- context. flush ( )
427+ do {
428+ try self . encoder. encode ( . bind( bind) )
429+ try self . encoder. encode ( . execute( . init( portalName: " " ) ) )
430+ try self . encoder. encode ( . sync)
431+ context. writeAndFlush ( self . wrapOutboundOut ( self . encoder. flush ( ) !) , promise: nil )
432+ } catch {
433+ let action = self . state. errorHappened ( . channel( underlying: error) )
434+ self . run ( action, with: context)
435+ }
410436 }
411437
412438 private func sendParseDescribeBindExecuteAndSyncMessage(
@@ -424,12 +450,17 @@ final class PSQLChannelHandler: ChannelDuplexHandler {
424450 preparedStatementName: unnamedStatementName,
425451 parameters: binds)
426452
427- context. write ( wrapOutboundOut ( . parse( parse) ) , promise: nil )
428- context. write ( wrapOutboundOut ( . describe( . preparedStatement( " " ) ) ) , promise: nil )
429- context. write ( wrapOutboundOut ( . bind( bind) ) , promise: nil )
430- context. write ( wrapOutboundOut ( . execute( . init( portalName: " " ) ) ) , promise: nil )
431- context. write ( wrapOutboundOut ( . sync) , promise: nil )
432- context. flush ( )
453+ do {
454+ try self . encoder. encode ( . parse( parse) )
455+ try self . encoder. encode ( . describe( . preparedStatement( " " ) ) )
456+ try self . encoder. encode ( . bind( bind) )
457+ try self . encoder. encode ( . execute( . init( portalName: " " ) ) )
458+ try self . encoder. encode ( . sync)
459+ context. writeAndFlush ( self . wrapOutboundOut ( self . encoder. flush ( ) !) , promise: nil )
460+ } catch {
461+ let action = self . state. errorHappened ( . channel( underlying: error) )
462+ self . run ( action, with: context)
463+ }
433464 }
434465
435466 private func succeedQueryWithRowStream(
@@ -503,16 +534,6 @@ extension PSQLChannelHandler: PSQLRowsDataSource {
503534 }
504535}
505536
506- extension ChannelHandlerContext {
507- func write( _ psqlMessage: PSQLFrontendMessage , promise: EventLoopPromise < Void > ? = nil ) {
508- self . write ( NIOAny ( psqlMessage) , promise: promise)
509- }
510-
511- func writeAndFlush( _ psqlMessage: PSQLFrontendMessage , promise: EventLoopPromise < Void > ? = nil ) {
512- self . writeAndFlush ( NIOAny ( psqlMessage) , promise: promise)
513- }
514- }
515-
516537extension PSQLConnection . Configuration . Authentication {
517538 func toAuthContext( ) -> AuthContext {
518539 AuthContext (
0 commit comments