Skip to content

Commit 4030f27

Browse files
authored
Merge pull request #2493 from millenomi/nscountedset-nscoding
2 parents d99abed + d98e7b5 commit 4030f27

6 files changed

+161
-24
lines changed

Foundation/NSSet.swift

+124-24
Original file line numberDiff line numberDiff line change
@@ -196,10 +196,16 @@ open class NSSet : NSObject, NSCopying, NSMutableCopying, NSSecureCoding, NSCodi
196196

197197
open override func isEqual(_ value: Any?) -> Bool {
198198
switch value {
199+
case let other as NSSet:
200+
// Check that this isn't a subclass — if both self and other are subclasses, this would otherwise turn into an infinite loop if other.isEqual(_:) calls super.
201+
if (type(of: self) == NSSet.self || type(of: self) == NSMutableSet.self) &&
202+
(type(of: other) != NSSet.self && type(of: other) != NSMutableSet.self) {
203+
return other.isEqual(self) // This ensures NSCountedSet overriding this method is respected no matter which side of the equality it appears on.
204+
} else {
205+
return isEqual(to: Set._unconditionallyBridgeFromObjectiveC(other))
206+
}
199207
case let other as Set<AnyHashable>:
200208
return isEqual(to: other)
201-
case let other as NSSet:
202-
return isEqual(to: Set._unconditionallyBridgeFromObjectiveC(other))
203209
default:
204210
return false
205211
}
@@ -452,7 +458,9 @@ open class NSMutableSet : NSSet {
452458

453459
/**************** Counted Set ****************/
454460
open class NSCountedSet : NSMutableSet {
455-
internal var _table: Dictionary<NSObject, Int>
461+
// Note: in 5.0 and earlier, _table contained the object's exact count.
462+
// In 5.1 and earlier, it contains the count minus one. This allows us to have a quick 'is this set just like a regular NSSet' flag (if this table is empty, then all objects in it exist at most once in it.)
463+
internal var _table: [NSObject: Int] = [:]
456464

457465
public required init(capacity numItems: Int) {
458466
_table = Dictionary<NSObject, Int>()
@@ -466,21 +474,93 @@ open class NSCountedSet : NSMutableSet {
466474
public convenience init(array: [Any]) {
467475
self.init(capacity: array.count)
468476
for object in array {
469-
let value = __SwiftValue.store(object)
470-
if let count = _table[value] {
471-
_table[value] = count + 1
472-
} else {
473-
_table[value] = 1
474-
_storage.insert(value)
475-
}
477+
add(__SwiftValue.store(object))
476478
}
477479
}
478480

479481
public convenience init(set: Set<AnyHashable>) {
480482
self.init(array: Array(set))
481483
}
482484

483-
public required convenience init?(coder: NSCoder) { NSUnimplemented() }
485+
private enum NSCodingKeys {
486+
static let maximumAllowedCount = UInt.max >> 4
487+
static let countKey = "NS.count"
488+
static func objectKey(atIndex index: Int64) -> String { return "NS.object\(index)" }
489+
static func objectCountKey(atIndex index: Int64) -> String { return "NS.count\(index)" }
490+
}
491+
492+
public required convenience init?(coder: NSCoder) {
493+
func fail(_ message: String) {
494+
coder.failWithError(NSError(domain: NSCocoaErrorDomain, code: NSCoderReadCorruptError, userInfo: [NSLocalizedDescriptionKey: message]))
495+
}
496+
497+
guard coder.allowsKeyedCoding else {
498+
fail("NSCountedSet requires keyed coding to be archived.")
499+
return nil
500+
}
501+
502+
let count = coder.decodeInt64(forKey: NSCodingKeys.countKey)
503+
guard count >= 0, UInt(count) <= NSCodingKeys.maximumAllowedCount else {
504+
fail("cannot decode set with \(count) elements in this version")
505+
return nil
506+
}
507+
508+
var objects: [(object: Any, count: Int64)] = []
509+
510+
for i in 0 ..< count {
511+
let objectKey = NSCodingKeys.objectKey(atIndex: i)
512+
let countKey = NSCodingKeys.objectCountKey(atIndex: i)
513+
514+
guard coder.containsValue(forKey: objectKey) && coder.containsValue(forKey: countKey) else {
515+
fail("Mismatch in count stored (\(count)) vs. count present (\(i))")
516+
return nil
517+
}
518+
519+
guard let object = coder.decodeObject(forKey: objectKey) else {
520+
fail("Decode failure at index \(i) - item nil")
521+
return nil
522+
}
523+
524+
let itemCount = coder.decodeInt64(forKey: countKey)
525+
guard itemCount > 0 else {
526+
fail("Decode failure at index \(i) - itemCount zero")
527+
return nil
528+
}
529+
530+
guard UInt(itemCount) <= NSCodingKeys.maximumAllowedCount else {
531+
fail("Cannot store \(itemCount) instances of item \(object) in this version")
532+
return nil
533+
}
534+
535+
objects.append((object, itemCount))
536+
}
537+
538+
self.init()
539+
for value in objects {
540+
for _ in 0 ..< value.count {
541+
add(value.object)
542+
}
543+
}
544+
}
545+
546+
open override func encode(with coder: NSCoder) {
547+
func fail(_ message: String) {
548+
coder.failWithError(NSError(domain: NSCocoaErrorDomain, code: NSCoderReadCorruptError, userInfo: [NSLocalizedDescriptionKey: message]))
549+
}
550+
551+
guard coder.allowsKeyedCoding else {
552+
fail("NSCountedSet requires keyed coding to be archived.")
553+
return
554+
}
555+
556+
coder.encode(Int64(self.count), forKey: NSCodingKeys.countKey)
557+
var index: Int64 = 0
558+
for object in self {
559+
coder.encode(object, forKey: NSCodingKeys.objectKey(atIndex: index))
560+
coder.encode(Int64(count(for: object)), forKey: NSCodingKeys.objectCountKey(atIndex: index))
561+
index += 1
562+
}
563+
}
484564

485565
open override func copy(with zone: NSZone? = nil) -> Any {
486566
if type(of: self) === NSCountedSet.self {
@@ -507,21 +587,23 @@ open class NSCountedSet : NSMutableSet {
507587
NSRequiresConcreteImplementation()
508588
}
509589
let value = __SwiftValue.store(object)
510-
guard let count = _table[value] else {
590+
if let count = _table[value] {
591+
return count + 1
592+
} else if _storage.contains(value) {
593+
return 1
594+
} else {
511595
return 0
512596
}
513-
return count
514597
}
515598

516599
open override func add(_ object: Any) {
517600
guard type(of: self) === NSCountedSet.self else {
518601
NSRequiresConcreteImplementation()
519602
}
520603
let value = __SwiftValue.store(object)
521-
if let count = _table[value] {
522-
_table[value] = count + 1
604+
if _storage.contains(value) {
605+
_table[value, default: 0] += 1
523606
} else {
524-
_table[value] = 1
525607
_storage.insert(value)
526608
}
527609
}
@@ -531,14 +613,11 @@ open class NSCountedSet : NSMutableSet {
531613
NSRequiresConcreteImplementation()
532614
}
533615
let value = __SwiftValue.store(object)
534-
guard let count = _table[value] else {
535-
return
536-
}
537-
538-
if count > 1 {
539-
_table[value] = count - 1
540-
} else {
541-
_table[value] = nil
616+
if let count = _table[value] {
617+
precondition(count > 0)
618+
_table[value] = count == 1 ? nil : count - 1
619+
} else if _storage.contains(value) {
620+
_table.removeValue(forKey: value)
542621
_storage.remove(value)
543622
}
544623
}
@@ -551,6 +630,27 @@ open class NSCountedSet : NSMutableSet {
551630
forEach(remove)
552631
}
553632
}
633+
634+
open override func isEqual(_ value: Any?) -> Bool {
635+
if let countedSet = value as? NSCountedSet {
636+
guard count == countedSet.count else { return false }
637+
for object in self {
638+
if !countedSet.contains(object) || count(for: object) != countedSet.count(for: object) {
639+
return false
640+
}
641+
}
642+
return true
643+
}
644+
645+
if _table.isEmpty {
646+
return super.isEqual(value)
647+
} else {
648+
return false
649+
}
650+
}
651+
652+
// The hash of a NSSet in s-c-f is its count, which is the same among equal NSCountedSets as well,
653+
// so just using the superclass's implementation works fine.
554654
}
555655

556656
extension NSSet : _StructTypeBridgeable {

TestFoundation/FixtureValues.swift

+25
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,28 @@ enum Fixtures {
236236
return NSMutableSet()
237237
}
238238

239+
// ===== NSCountedSet =====
240+
241+
static let countedSetOfNumbersAppearingOnce = TypedFixture<NSCountedSet>("NSCountedSet-NumbersAppearingOnce") {
242+
let numbers = [1, 2, 3, 4, 5].map { NSNumber(value: $0) }
243+
return NSCountedSet(array: numbers)
244+
}
245+
246+
static let countedSetOfNumbersAppearingSeveralTimes = TypedFixture<NSCountedSet>("NSCountedSet-NumbersAppearingSeveralTimes") {
247+
let numbers = [1, 2, 3, 4, 5].map { NSNumber(value: $0) }
248+
let set = NSCountedSet()
249+
for _ in 0 ..< 5 {
250+
for number in numbers {
251+
set.add(number)
252+
}
253+
}
254+
return set
255+
}
256+
257+
static let countedSetEmpty = TypedFixture<NSCountedSet>("NSCountedSet-Empty") {
258+
return NSCountedSet()
259+
}
260+
239261
// ===== NSCharacterSet, NSMutableCharacterSet =====
240262

241263
static let characterSetEmpty = TypedFixture<NSCharacterSet>("NSCharacterSet-Empty") {
@@ -307,6 +329,9 @@ enum Fixtures {
307329
AnyFixture(Fixtures.setEmpty),
308330
AnyFixture(Fixtures.mutableSetOfNumbers),
309331
AnyFixture(Fixtures.mutableSetEmpty),
332+
AnyFixture(Fixtures.countedSetOfNumbersAppearingOnce),
333+
AnyFixture(Fixtures.countedSetOfNumbersAppearingSeveralTimes),
334+
AnyFixture(Fixtures.countedSetEmpty),
310335
AnyFixture(Fixtures.characterSetEmpty),
311336
AnyFixture(Fixtures.characterSetRange),
312337
AnyFixture(Fixtures.characterSetString),
Binary file not shown.
Binary file not shown.

TestFoundation/TestNSSet.swift

+12
Original file line numberDiff line numberDiff line change
@@ -260,13 +260,22 @@ class TestNSSet : XCTestCase {
260260
Fixtures.mutableSetEmpty,
261261
]
262262

263+
let countedSetFixtures = [
264+
Fixtures.countedSetOfNumbersAppearingOnce,
265+
Fixtures.countedSetOfNumbersAppearingSeveralTimes,
266+
Fixtures.countedSetEmpty,
267+
]
268+
263269
func test_codingRoundtrip() throws {
264270
for fixture in setFixtures {
265271
try fixture.assertValueRoundtripsInCoder()
266272
}
267273
for fixture in mutableSetFixtures {
268274
try fixture.assertValueRoundtripsInCoder()
269275
}
276+
for fixture in countedSetFixtures {
277+
try fixture.assertValueRoundtripsInCoder()
278+
}
270279
}
271280

272281
func test_loadedValuesMatch() throws {
@@ -276,6 +285,9 @@ class TestNSSet : XCTestCase {
276285
for fixture in mutableSetFixtures {
277286
try fixture.assertLoadedValuesMatch()
278287
}
288+
for fixture in countedSetFixtures {
289+
try fixture.assertLoadedValuesMatch()
290+
}
279291
}
280292

281293
static var allTests: [(String, (TestNSSet) -> () throws -> Void)] {

0 commit comments

Comments
 (0)