@@ -196,10 +196,16 @@ open class NSSet : NSObject, NSCopying, NSMutableCopying, NSSecureCoding, NSCodi
196
196
197
197
open override func isEqual( _ value: Any ? ) -> Bool {
198
198
switch value {
199
+ case let other as NSSet :
200
+ // Check that this isn't a subclass — if both self and other are subclasses, this would otherwise turn into an infinite loop if other.isEqual(_:) calls super.
201
+ if ( type ( of: self ) == NSSet . self || type ( of: self ) == NSMutableSet . self) &&
202
+ ( type ( of: other) != NSSet . self && type ( of: other) != NSMutableSet . self) {
203
+ return other. isEqual ( self ) // This ensures NSCountedSet overriding this method is respected no matter which side of the equality it appears on.
204
+ } else {
205
+ return isEqual ( to: Set . _unconditionallyBridgeFromObjectiveC ( other) )
206
+ }
199
207
case let other as Set < AnyHashable > :
200
208
return isEqual ( to: other)
201
- case let other as NSSet :
202
- return isEqual ( to: Set . _unconditionallyBridgeFromObjectiveC ( other) )
203
209
default :
204
210
return false
205
211
}
@@ -452,7 +458,9 @@ open class NSMutableSet : NSSet {
452
458
453
459
/**************** Counted Set ****************/
454
460
open class NSCountedSet : NSMutableSet {
455
- internal var _table : Dictionary < NSObject , Int >
461
+ // Note: in 5.0 and earlier, _table contained the object's exact count.
462
+ // In 5.1 and earlier, it contains the count minus one. This allows us to have a quick 'is this set just like a regular NSSet' flag (if this table is empty, then all objects in it exist at most once in it.)
463
+ internal var _table : [ NSObject : Int ] = [ : ]
456
464
457
465
public required init ( capacity numItems: Int ) {
458
466
_table = Dictionary < NSObject , Int > ( )
@@ -466,21 +474,93 @@ open class NSCountedSet : NSMutableSet {
466
474
public convenience init ( array: [ Any ] ) {
467
475
self . init ( capacity: array. count)
468
476
for object in array {
469
- let value = __SwiftValue. store ( object)
470
- if let count = _table [ value] {
471
- _table [ value] = count + 1
472
- } else {
473
- _table [ value] = 1
474
- _storage. insert ( value)
475
- }
477
+ add ( __SwiftValue. store ( object) )
476
478
}
477
479
}
478
480
479
481
public convenience init ( set: Set < AnyHashable > ) {
480
482
self . init ( array: Array ( set) )
481
483
}
482
484
483
- public required convenience init ? ( coder: NSCoder ) { NSUnimplemented ( ) }
485
+ private enum NSCodingKeys {
486
+ static let maximumAllowedCount = UInt . max >> 4
487
+ static let countKey = " NS.count "
488
+ static func objectKey( atIndex index: Int64 ) -> String { return " NS.object \( index) " }
489
+ static func objectCountKey( atIndex index: Int64 ) -> String { return " NS.count \( index) " }
490
+ }
491
+
492
+ public required convenience init ? ( coder: NSCoder ) {
493
+ func fail( _ message: String ) {
494
+ coder. failWithError ( NSError ( domain: NSCocoaErrorDomain, code: NSCoderReadCorruptError, userInfo: [ NSLocalizedDescriptionKey: message] ) )
495
+ }
496
+
497
+ guard coder. allowsKeyedCoding else {
498
+ fail ( " NSCountedSet requires keyed coding to be archived. " )
499
+ return nil
500
+ }
501
+
502
+ let count = coder. decodeInt64 ( forKey: NSCodingKeys . countKey)
503
+ guard count >= 0 , UInt ( count) <= NSCodingKeys . maximumAllowedCount else {
504
+ fail ( " cannot decode set with \( count) elements in this version " )
505
+ return nil
506
+ }
507
+
508
+ var objects : [ ( object: Any , count: Int64 ) ] = [ ]
509
+
510
+ for i in 0 ..< count {
511
+ let objectKey = NSCodingKeys . objectKey ( atIndex: i)
512
+ let countKey = NSCodingKeys . objectCountKey ( atIndex: i)
513
+
514
+ guard coder. containsValue ( forKey: objectKey) && coder. containsValue ( forKey: countKey) else {
515
+ fail ( " Mismatch in count stored ( \( count) ) vs. count present ( \( i) ) " )
516
+ return nil
517
+ }
518
+
519
+ guard let object = coder. decodeObject ( forKey: objectKey) else {
520
+ fail ( " Decode failure at index \( i) - item nil " )
521
+ return nil
522
+ }
523
+
524
+ let itemCount = coder. decodeInt64 ( forKey: countKey)
525
+ guard itemCount > 0 else {
526
+ fail ( " Decode failure at index \( i) - itemCount zero " )
527
+ return nil
528
+ }
529
+
530
+ guard UInt ( itemCount) <= NSCodingKeys . maximumAllowedCount else {
531
+ fail ( " Cannot store \( itemCount) instances of item \( object) in this version " )
532
+ return nil
533
+ }
534
+
535
+ objects. append ( ( object, itemCount) )
536
+ }
537
+
538
+ self . init ( )
539
+ for value in objects {
540
+ for _ in 0 ..< value. count {
541
+ add ( value. object)
542
+ }
543
+ }
544
+ }
545
+
546
+ open override func encode( with coder: NSCoder ) {
547
+ func fail( _ message: String ) {
548
+ coder. failWithError ( NSError ( domain: NSCocoaErrorDomain, code: NSCoderReadCorruptError, userInfo: [ NSLocalizedDescriptionKey: message] ) )
549
+ }
550
+
551
+ guard coder. allowsKeyedCoding else {
552
+ fail ( " NSCountedSet requires keyed coding to be archived. " )
553
+ return
554
+ }
555
+
556
+ coder. encode ( Int64 ( self . count) , forKey: NSCodingKeys . countKey)
557
+ var index : Int64 = 0
558
+ for object in self {
559
+ coder. encode ( object, forKey: NSCodingKeys . objectKey ( atIndex: index) )
560
+ coder. encode ( Int64 ( count ( for: object) ) , forKey: NSCodingKeys . objectCountKey ( atIndex: index) )
561
+ index += 1
562
+ }
563
+ }
484
564
485
565
open override func copy( with zone: NSZone ? = nil ) -> Any {
486
566
if type ( of: self ) === NSCountedSet . self {
@@ -507,21 +587,23 @@ open class NSCountedSet : NSMutableSet {
507
587
NSRequiresConcreteImplementation ( )
508
588
}
509
589
let value = __SwiftValue. store ( object)
510
- guard let count = _table [ value] else {
590
+ if let count = _table [ value] {
591
+ return count + 1
592
+ } else if _storage. contains ( value) {
593
+ return 1
594
+ } else {
511
595
return 0
512
596
}
513
- return count
514
597
}
515
598
516
599
open override func add( _ object: Any ) {
517
600
guard type ( of: self ) === NSCountedSet . self else {
518
601
NSRequiresConcreteImplementation ( )
519
602
}
520
603
let value = __SwiftValue. store ( object)
521
- if let count = _table [ value] {
522
- _table [ value] = count + 1
604
+ if _storage . contains ( value) {
605
+ _table [ value, default : 0 ] += 1
523
606
} else {
524
- _table [ value] = 1
525
607
_storage. insert ( value)
526
608
}
527
609
}
@@ -531,14 +613,11 @@ open class NSCountedSet : NSMutableSet {
531
613
NSRequiresConcreteImplementation ( )
532
614
}
533
615
let value = __SwiftValue. store ( object)
534
- guard let count = _table [ value] else {
535
- return
536
- }
537
-
538
- if count > 1 {
539
- _table [ value] = count - 1
540
- } else {
541
- _table [ value] = nil
616
+ if let count = _table [ value] {
617
+ precondition ( count > 0 )
618
+ _table [ value] = count == 1 ? nil : count - 1
619
+ } else if _storage. contains ( value) {
620
+ _table. removeValue ( forKey: value)
542
621
_storage. remove ( value)
543
622
}
544
623
}
@@ -551,6 +630,27 @@ open class NSCountedSet : NSMutableSet {
551
630
forEach ( remove)
552
631
}
553
632
}
633
+
634
+ open override func isEqual( _ value: Any ? ) -> Bool {
635
+ if let countedSet = value as? NSCountedSet {
636
+ guard count == countedSet. count else { return false }
637
+ for object in self {
638
+ if !countedSet. contains ( object) || count ( for: object) != countedSet. count ( for: object) {
639
+ return false
640
+ }
641
+ }
642
+ return true
643
+ }
644
+
645
+ if _table. isEmpty {
646
+ return super. isEqual ( value)
647
+ } else {
648
+ return false
649
+ }
650
+ }
651
+
652
+ // The hash of a NSSet in s-c-f is its count, which is the same among equal NSCountedSets as well,
653
+ // so just using the superclass's implementation works fine.
554
654
}
555
655
556
656
extension NSSet : _StructTypeBridgeable {
0 commit comments