@@ -177,6 +177,8 @@ type Association struct {
177177 cumulativeTSNAckPoint uint32
178178 advancedPeerTSNAckPoint uint32
179179 useForwardTSN bool
180+ useZeroChecksum bool
181+ requestZeroChecksum bool
180182
181183 // Congestion control parameters
182184 maxReceiveBufferSize uint32
@@ -233,6 +235,7 @@ type Config struct {
233235 NetConn net.Conn
234236 MaxReceiveBufferSize uint32
235237 MaxMessageSize uint32
238+ EnableZeroChecksum bool
236239 LoggerFactory logging.LoggerFactory
237240}
238241
@@ -320,6 +323,7 @@ func createAssociation(config Config) *Association {
320323 handshakeCompletedCh : make (chan error ),
321324 cumulativeTSNAckPoint : tsn - 1 ,
322325 advancedPeerTSNAckPoint : tsn - 1 ,
326+ requestZeroChecksum : config .EnableZeroChecksum ,
323327 silentError : ErrSilentlyDiscard ,
324328 stats : & associationStats {},
325329 log : config .LoggerFactory .NewLogger ("sctp" ),
@@ -362,6 +366,11 @@ func (a *Association) init(isClient bool) {
362366 init .initiateTag = a .myVerificationTag
363367 init .advertisedReceiverWindowCredit = a .maxReceiveBufferSize
364368 setSupportedExtensions (& init .chunkInitCommon )
369+
370+ if a .requestZeroChecksum {
371+ init .params = append (init .params , & paramZeroChecksumAcceptable {edmid : dtlsErrorDetectionMethod })
372+ }
373+
365374 a .storedInit = init
366375
367376 err := a .sendInit ()
@@ -618,10 +627,45 @@ func (a *Association) unregisterStream(s *Stream, err error) {
618627 s .readNotifier .Broadcast ()
619628}
620629
630+ func chunkMandatoryChecksum (cc []chunk ) bool {
631+ for _ , c := range cc {
632+ switch c .(type ) {
633+ case * chunkInit , * chunkInitAck , * chunkCookieEcho :
634+ return true
635+ }
636+ }
637+ return false
638+ }
639+
640+ func (a * Association ) marshalPacket (p * packet ) ([]byte , error ) {
641+ return p .marshal (! a .useZeroChecksum || chunkMandatoryChecksum (p .chunks ))
642+ }
643+
644+ func (a * Association ) unmarshalPacket (raw []byte ) (* packet , error ) {
645+ p := & packet {}
646+ if ! a .useZeroChecksum {
647+ if err := p .unmarshal (true , raw ); err != nil {
648+ return nil , err
649+ }
650+ return p , nil
651+ }
652+
653+ if err := p .unmarshal (false , raw ); err != nil {
654+ return nil , err
655+ }
656+ if chunkMandatoryChecksum (p .chunks ) {
657+ if err := p .unmarshal (true , raw ); err != nil {
658+ return nil , err
659+ }
660+ }
661+
662+ return p , nil
663+ }
664+
621665// handleInbound parses incoming raw packets
622666func (a * Association ) handleInbound (raw []byte ) error {
623- p := & packet {}
624- if err := p . unmarshal ( raw ); err != nil {
667+ p , err := a . unmarshalPacket ( raw )
668+ if err != nil {
625669 a .log .Warnf ("[%s] unable to parse SCTP packet %s" , a .name , err )
626670 return nil
627671 }
@@ -647,7 +691,7 @@ func (a *Association) handleInbound(raw []byte) error {
647691// The caller should hold the lock
648692func (a * Association ) gatherDataPacketsToRetransmit (rawPackets [][]byte ) [][]byte {
649693 for _ , p := range a .getDataPacketsToRetransmit () {
650- raw , err := p . marshal ( )
694+ raw , err := a . marshalPacket ( p )
651695 if err != nil {
652696 a .log .Warnf ("[%s] failed to serialize a DATA packet to be retransmitted" , a .name )
653697 continue
@@ -668,7 +712,7 @@ func (a *Association) gatherOutboundDataAndReconfigPackets(rawPackets [][]byte)
668712 a .log .Tracef ("[%s] T3-rtx timer start (pt1)" , a .name )
669713 a .t3RTX .start (a .rtoMgr .getRTO ())
670714 for _ , p := range a .bundleDataChunksIntoPackets (chunks ) {
671- raw , err := p . marshal ( )
715+ raw , err := a . marshalPacket ( p )
672716 if err != nil {
673717 a .log .Warnf ("[%s] failed to serialize a DATA packet" , a .name )
674718 continue
@@ -683,7 +727,7 @@ func (a *Association) gatherOutboundDataAndReconfigPackets(rawPackets [][]byte)
683727 a .log .Debugf ("[%s] retransmit %d RECONFIG chunk(s)" , a .name , len (a .reconfigs ))
684728 for _ , c := range a .reconfigs {
685729 p := a .createPacket ([]chunk {c })
686- raw , err := p . marshal ( )
730+ raw , err := a . marshalPacket ( p )
687731 if err != nil {
688732 a .log .Warnf ("[%s] failed to serialize a RECONFIG packet to be retransmitted" , a .name )
689733 } else {
@@ -706,7 +750,7 @@ func (a *Association) gatherOutboundDataAndReconfigPackets(rawPackets [][]byte)
706750 a .log .Debugf ("[%s] sending RECONFIG: rsn=%d tsn=%d streams=%v" ,
707751 a .name , rsn , a .myNextTSN - 1 , sisToReset )
708752 p := a .createPacket ([]chunk {c })
709- raw , err := p . marshal ( )
753+ raw , err := a . marshalPacket ( p )
710754 if err != nil {
711755 a .log .Warnf ("[%s] failed to serialize a RECONFIG packet to be transmitted" , a .name )
712756 } else {
@@ -769,7 +813,7 @@ func (a *Association) gatherOutboundFastRetransmissionPackets(rawPackets [][]byt
769813 }
770814
771815 if len (toFastRetrans ) > 0 {
772- raw , err := a .createPacket (toFastRetrans ). marshal ( )
816+ raw , err := a .marshalPacket ( a . createPacket (toFastRetrans ))
773817 if err != nil {
774818 a .log .Warnf ("[%s] failed to serialize a DATA packet to be fast-retransmitted" , a .name )
775819 } else {
@@ -787,7 +831,7 @@ func (a *Association) gatherOutboundSackPackets(rawPackets [][]byte) [][]byte {
787831 a .ackState = ackStateIdle
788832 sack := a .createSelectiveAckChunk ()
789833 a .log .Debugf ("[%s] sending SACK: %s" , a .name , sack )
790- raw , err := a .createPacket ([]chunk {sack }). marshal ( )
834+ raw , err := a .marshalPacket ( a . createPacket ([]chunk {sack }))
791835 if err != nil {
792836 a .log .Warnf ("[%s] failed to serialize a SACK packet" , a .name )
793837 } else {
@@ -804,7 +848,7 @@ func (a *Association) gatherOutboundForwardTSNPackets(rawPackets [][]byte) [][]b
804848 a .willSendForwardTSN = false
805849 if sna32GT (a .advancedPeerTSNAckPoint , a .cumulativeTSNAckPoint ) {
806850 fwdtsn := a .createForwardTSN ()
807- raw , err := a .createPacket ([]chunk {fwdtsn }). marshal ( )
851+ raw , err := a .marshalPacket ( a . createPacket ([]chunk {fwdtsn }))
808852 if err != nil {
809853 a .log .Warnf ("[%s] failed to serialize a Forward TSN packet" , a .name )
810854 } else {
@@ -827,7 +871,7 @@ func (a *Association) gatherOutboundShutdownPackets(rawPackets [][]byte) ([][]by
827871 cumulativeTSNAck : a .cumulativeTSNAckPoint ,
828872 }
829873
830- raw , err := a .createPacket ([]chunk {shutdown }). marshal ( )
874+ raw , err := a .marshalPacket ( a . createPacket ([]chunk {shutdown }))
831875 if err != nil {
832876 a .log .Warnf ("[%s] failed to serialize a Shutdown packet" , a .name )
833877 } else {
@@ -839,7 +883,7 @@ func (a *Association) gatherOutboundShutdownPackets(rawPackets [][]byte) ([][]by
839883
840884 shutdownAck := & chunkShutdownAck {}
841885
842- raw , err := a .createPacket ([]chunk {shutdownAck }). marshal ( )
886+ raw , err := a .marshalPacket ( a . createPacket ([]chunk {shutdownAck }))
843887 if err != nil {
844888 a .log .Warnf ("[%s] failed to serialize a ShutdownAck packet" , a .name )
845889 } else {
@@ -851,7 +895,7 @@ func (a *Association) gatherOutboundShutdownPackets(rawPackets [][]byte) ([][]by
851895
852896 shutdownComplete := & chunkShutdownComplete {}
853897
854- raw , err := a .createPacket ([]chunk {shutdownComplete }). marshal ( )
898+ raw , err := a .marshalPacket ( a . createPacket ([]chunk {shutdownComplete }))
855899 if err != nil {
856900 a .log .Warnf ("[%s] failed to serialize a ShutdownComplete packet" , a .name )
857901 } else {
@@ -875,7 +919,7 @@ func (a *Association) gatherAbortPacket() ([]byte, error) {
875919 abort .errorCauses = []errorCause {cause }
876920 }
877921
878- raw , err := a .createPacket ([]chunk {abort }). marshal ( )
922+ raw , err := a .marshalPacket ( a . createPacket ([]chunk {abort }))
879923
880924 return raw , err
881925}
@@ -900,7 +944,7 @@ func (a *Association) gatherOutbound() ([][]byte, bool) {
900944
901945 if a .controlQueue .size () > 0 {
902946 for _ , p := range a .controlQueue .popAll () {
903- raw , err := p . marshal ( )
947+ raw , err := a . marshalPacket ( p )
904948 if err != nil {
905949 a .log .Warnf ("[%s] failed to serialize a control packet" , a .name )
906950 continue
@@ -1092,6 +1136,7 @@ func (a *Association) handleInit(p *packet, i *chunkInit) ([]*packet, error) {
10921136 // subtracting one from it.
10931137 a .peerLastTSN = i .initialTSN - 1
10941138
1139+ peerHasZeroChecksum := false
10951140 for _ , param := range i .params {
10961141 switch v := param .(type ) { // nolint:gocritic
10971142 case * paramSupportedExtensions :
@@ -1101,8 +1146,11 @@ func (a *Association) handleInit(p *packet, i *chunkInit) ([]*packet, error) {
11011146 a .useForwardTSN = true
11021147 }
11031148 }
1149+ case * paramZeroChecksumAcceptable :
1150+ peerHasZeroChecksum = v .edmid == dtlsErrorDetectionMethod
11041151 }
11051152 }
1153+
11061154 if ! a .useForwardTSN {
11071155 a .log .Warnf ("[%s] not using ForwardTSN (on init)" , a .name )
11081156 }
@@ -1129,6 +1177,12 @@ func (a *Association) handleInit(p *packet, i *chunkInit) ([]*packet, error) {
11291177
11301178 initAck .params = []param {a .myCookie }
11311179
1180+ if peerHasZeroChecksum {
1181+ initAck .params = append (initAck .params , & paramZeroChecksumAcceptable {edmid : dtlsErrorDetectionMethod })
1182+ a .useZeroChecksum = true
1183+ }
1184+ a .log .Debugf ("[%s] useZeroChecksum=%t (on init)" , a .name , a .useZeroChecksum )
1185+
11321186 setSupportedExtensions (& initAck .chunkInitCommon )
11331187
11341188 outbound .chunks = []chunk {initAck }
@@ -1186,8 +1240,13 @@ func (a *Association) handleInitAck(p *packet, i *chunkInitAck) error {
11861240 a .useForwardTSN = true
11871241 }
11881242 }
1243+ case * paramZeroChecksumAcceptable :
1244+ a .useZeroChecksum = v .edmid == dtlsErrorDetectionMethod
11891245 }
11901246 }
1247+
1248+ a .log .Debugf ("[%s] useZeroChecksum=%t (on initAck)" , a .name , a .useZeroChecksum )
1249+
11911250 if ! a .useForwardTSN {
11921251 a .log .Warnf ("[%s] not using ForwardTSN (on initAck)" , a .name )
11931252 }
0 commit comments