diff --git a/Foundation/NSSet.swift b/Foundation/NSSet.swift index 7175bd99a4..88c0487640 100644 --- a/Foundation/NSSet.swift +++ b/Foundation/NSSet.swift @@ -196,10 +196,16 @@ open class NSSet : NSObject, NSCopying, NSMutableCopying, NSSecureCoding, NSCodi open override func isEqual(_ value: Any?) -> Bool { switch value { + case let other as NSSet: + // Check that this isn't a subclass — if both self and other are subclasses, this would otherwise turn into an infinite loop if other.isEqual(_:) calls super. + if (type(of: self) == NSSet.self || type(of: self) == NSMutableSet.self) && + (type(of: other) != NSSet.self && type(of: other) != NSMutableSet.self) { + return other.isEqual(self) // This ensures NSCountedSet overriding this method is respected no matter which side of the equality it appears on. + } else { + return isEqual(to: Set._unconditionallyBridgeFromObjectiveC(other)) + } case let other as Set: return isEqual(to: other) - case let other as NSSet: - return isEqual(to: Set._unconditionallyBridgeFromObjectiveC(other)) default: return false } @@ -452,7 +458,9 @@ open class NSMutableSet : NSSet { /**************** Counted Set ****************/ open class NSCountedSet : NSMutableSet { - internal var _table: Dictionary + // Note: in 5.0 and earlier, _table contained the object's exact count. + // In 5.1 and earlier, it contains the count minus one. This allows us to have a quick 'is this set just like a regular NSSet' flag (if this table is empty, then all objects in it exist at most once in it.) + internal var _table: [NSObject: Int] = [:] public required init(capacity numItems: Int) { _table = Dictionary() @@ -466,13 +474,7 @@ open class NSCountedSet : NSMutableSet { public convenience init(array: [Any]) { self.init(capacity: array.count) for object in array { - let value = __SwiftValue.store(object) - if let count = _table[value] { - _table[value] = count + 1 - } else { - _table[value] = 1 - _storage.insert(value) - } + add(__SwiftValue.store(object)) } } @@ -480,7 +482,85 @@ open class NSCountedSet : NSMutableSet { self.init(array: Array(set)) } - public required convenience init?(coder: NSCoder) { NSUnimplemented() } + private enum NSCodingKeys { + static let maximumAllowedCount = UInt.max >> 4 + static let countKey = "NS.count" + static func objectKey(atIndex index: Int64) -> String { return "NS.object\(index)" } + static func objectCountKey(atIndex index: Int64) -> String { return "NS.count\(index)" } + } + + public required convenience init?(coder: NSCoder) { + func fail(_ message: String) { + coder.failWithError(NSError(domain: NSCocoaErrorDomain, code: NSCoderReadCorruptError, userInfo: [NSLocalizedDescriptionKey: message])) + } + + guard coder.allowsKeyedCoding else { + fail("NSCountedSet requires keyed coding to be archived.") + return nil + } + + let count = coder.decodeInt64(forKey: NSCodingKeys.countKey) + guard count >= 0, UInt(count) <= NSCodingKeys.maximumAllowedCount else { + fail("cannot decode set with \(count) elements in this version") + return nil + } + + var objects: [(object: Any, count: Int64)] = [] + + for i in 0 ..< count { + let objectKey = NSCodingKeys.objectKey(atIndex: i) + let countKey = NSCodingKeys.objectCountKey(atIndex: i) + + guard coder.containsValue(forKey: objectKey) && coder.containsValue(forKey: countKey) else { + fail("Mismatch in count stored (\(count)) vs. count present (\(i))") + return nil + } + + guard let object = coder.decodeObject(forKey: objectKey) else { + fail("Decode failure at index \(i) - item nil") + return nil + } + + let itemCount = coder.decodeInt64(forKey: countKey) + guard itemCount > 0 else { + fail("Decode failure at index \(i) - itemCount zero") + return nil + } + + guard UInt(itemCount) <= NSCodingKeys.maximumAllowedCount else { + fail("Cannot store \(itemCount) instances of item \(object) in this version") + return nil + } + + objects.append((object, itemCount)) + } + + self.init() + for value in objects { + for _ in 0 ..< value.count { + add(value.object) + } + } + } + + open override func encode(with coder: NSCoder) { + func fail(_ message: String) { + coder.failWithError(NSError(domain: NSCocoaErrorDomain, code: NSCoderReadCorruptError, userInfo: [NSLocalizedDescriptionKey: message])) + } + + guard coder.allowsKeyedCoding else { + fail("NSCountedSet requires keyed coding to be archived.") + return + } + + coder.encode(Int64(self.count), forKey: NSCodingKeys.countKey) + var index: Int64 = 0 + for object in self { + coder.encode(object, forKey: NSCodingKeys.objectKey(atIndex: index)) + coder.encode(Int64(count(for: object)), forKey: NSCodingKeys.objectCountKey(atIndex: index)) + index += 1 + } + } open override func copy(with zone: NSZone? = nil) -> Any { if type(of: self) === NSCountedSet.self { @@ -507,10 +587,13 @@ open class NSCountedSet : NSMutableSet { NSRequiresConcreteImplementation() } let value = __SwiftValue.store(object) - guard let count = _table[value] else { + if let count = _table[value] { + return count + 1 + } else if _storage.contains(value) { + return 1 + } else { return 0 } - return count } open override func add(_ object: Any) { @@ -518,10 +601,9 @@ open class NSCountedSet : NSMutableSet { NSRequiresConcreteImplementation() } let value = __SwiftValue.store(object) - if let count = _table[value] { - _table[value] = count + 1 + if _storage.contains(value) { + _table[value, default: 0] += 1 } else { - _table[value] = 1 _storage.insert(value) } } @@ -531,14 +613,11 @@ open class NSCountedSet : NSMutableSet { NSRequiresConcreteImplementation() } let value = __SwiftValue.store(object) - guard let count = _table[value] else { - return - } - - if count > 1 { - _table[value] = count - 1 - } else { - _table[value] = nil + if let count = _table[value] { + precondition(count > 0) + _table[value] = count == 1 ? nil : count - 1 + } else if _storage.contains(value) { + _table.removeValue(forKey: value) _storage.remove(value) } } @@ -551,6 +630,27 @@ open class NSCountedSet : NSMutableSet { forEach(remove) } } + + open override func isEqual(_ value: Any?) -> Bool { + if let countedSet = value as? NSCountedSet { + guard count == countedSet.count else { return false } + for object in self { + if !countedSet.contains(object) || count(for: object) != countedSet.count(for: object) { + return false + } + } + return true + } + + if _table.isEmpty { + return super.isEqual(value) + } else { + return false + } + } + + // The hash of a NSSet in s-c-f is its count, which is the same among equal NSCountedSets as well, + // so just using the superclass's implementation works fine. } extension NSSet : _StructTypeBridgeable { diff --git a/TestFoundation/FixtureValues.swift b/TestFoundation/FixtureValues.swift index 2472c0ca99..beb5bc520b 100644 --- a/TestFoundation/FixtureValues.swift +++ b/TestFoundation/FixtureValues.swift @@ -220,6 +220,28 @@ enum Fixtures { return NSMutableSet() } + // ===== NSCountedSet ===== + + static let countedSetOfNumbersAppearingOnce = TypedFixture("NSCountedSet-NumbersAppearingOnce") { + let numbers = [1, 2, 3, 4, 5].map { NSNumber(value: $0) } + return NSCountedSet(array: numbers) + } + + static let countedSetOfNumbersAppearingSeveralTimes = TypedFixture("NSCountedSet-NumbersAppearingSeveralTimes") { + let numbers = [1, 2, 3, 4, 5].map { NSNumber(value: $0) } + let set = NSCountedSet() + for _ in 0 ..< 5 { + for number in numbers { + set.add(number) + } + } + return set + } + + static let countedSetEmpty = TypedFixture("NSCountedSet-Empty") { + return NSCountedSet() + } + // ===== NSCharacterSet, NSMutableCharacterSet ===== static let characterSetEmpty = TypedFixture("NSCharacterSet-Empty") { @@ -288,6 +310,9 @@ enum Fixtures { AnyFixture(Fixtures.setEmpty), AnyFixture(Fixtures.mutableSetOfNumbers), AnyFixture(Fixtures.mutableSetEmpty), + AnyFixture(Fixtures.countedSetOfNumbersAppearingOnce), + AnyFixture(Fixtures.countedSetOfNumbersAppearingSeveralTimes), + AnyFixture(Fixtures.countedSetEmpty), AnyFixture(Fixtures.characterSetEmpty), AnyFixture(Fixtures.characterSetRange), AnyFixture(Fixtures.characterSetString), diff --git a/TestFoundation/Fixtures/macOS-10.14/NSCountedSet-Empty.archive b/TestFoundation/Fixtures/macOS-10.14/NSCountedSet-Empty.archive new file mode 100644 index 0000000000..80e7a0d0b1 Binary files /dev/null and b/TestFoundation/Fixtures/macOS-10.14/NSCountedSet-Empty.archive differ diff --git a/TestFoundation/Fixtures/macOS-10.14/NSCountedSet-NumbersAppearingOnce.archive b/TestFoundation/Fixtures/macOS-10.14/NSCountedSet-NumbersAppearingOnce.archive new file mode 100644 index 0000000000..cf620e5d87 Binary files /dev/null and b/TestFoundation/Fixtures/macOS-10.14/NSCountedSet-NumbersAppearingOnce.archive differ diff --git a/TestFoundation/Fixtures/macOS-10.14/NSCountedSet-NumbersAppearingSeveralTimes.archive b/TestFoundation/Fixtures/macOS-10.14/NSCountedSet-NumbersAppearingSeveralTimes.archive new file mode 100644 index 0000000000..5868b433bf Binary files /dev/null and b/TestFoundation/Fixtures/macOS-10.14/NSCountedSet-NumbersAppearingSeveralTimes.archive differ diff --git a/TestFoundation/TestNSSet.swift b/TestFoundation/TestNSSet.swift index b6fb8b77a3..b409dabe16 100644 --- a/TestFoundation/TestNSSet.swift +++ b/TestFoundation/TestNSSet.swift @@ -260,6 +260,12 @@ class TestNSSet : XCTestCase { Fixtures.mutableSetEmpty, ] + let countedSetFixtures = [ + Fixtures.countedSetOfNumbersAppearingOnce, + Fixtures.countedSetOfNumbersAppearingSeveralTimes, + Fixtures.countedSetEmpty, + ] + func test_codingRoundtrip() throws { for fixture in setFixtures { try fixture.assertValueRoundtripsInCoder() @@ -267,6 +273,9 @@ class TestNSSet : XCTestCase { for fixture in mutableSetFixtures { try fixture.assertValueRoundtripsInCoder() } + for fixture in countedSetFixtures { + try fixture.assertValueRoundtripsInCoder() + } } func test_loadedValuesMatch() throws { @@ -276,6 +285,9 @@ class TestNSSet : XCTestCase { for fixture in mutableSetFixtures { try fixture.assertLoadedValuesMatch() } + for fixture in countedSetFixtures { + try fixture.assertLoadedValuesMatch() + } } static var allTests: [(String, (TestNSSet) -> () throws -> Void)] {