Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parity: NSCoding stragglers: NSCountedSet #2493

Merged
merged 1 commit into from
Aug 28, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 124 additions & 24 deletions Foundation/NSSet.swift
Original file line number Diff line number Diff line change
Expand Up @@ -196,10 +196,16 @@ open class NSSet : NSObject, NSCopying, NSMutableCopying, NSSecureCoding, NSCodi

open override func isEqual(_ value: Any?) -> Bool {
switch value {
case let other as NSSet:
// Check that this isn't a subclass — if both self and other are subclasses, this would otherwise turn into an infinite loop if other.isEqual(_:) calls super.
if (type(of: self) == NSSet.self || type(of: self) == NSMutableSet.self) &&
(type(of: other) != NSSet.self && type(of: other) != NSMutableSet.self) {
return other.isEqual(self) // This ensures NSCountedSet overriding this method is respected no matter which side of the equality it appears on.
} else {
return isEqual(to: Set._unconditionallyBridgeFromObjectiveC(other))
}
case let other as Set<AnyHashable>:
return isEqual(to: other)
case let other as NSSet:
return isEqual(to: Set._unconditionallyBridgeFromObjectiveC(other))
default:
return false
}
Expand Down Expand Up @@ -452,7 +458,9 @@ open class NSMutableSet : NSSet {

/**************** Counted Set ****************/
open class NSCountedSet : NSMutableSet {
internal var _table: Dictionary<NSObject, Int>
// Note: in 5.0 and earlier, _table contained the object's exact count.
// In 5.1 and earlier, it contains the count minus one. This allows us to have a quick 'is this set just like a regular NSSet' flag (if this table is empty, then all objects in it exist at most once in it.)
internal var _table: [NSObject: Int] = [:]

public required init(capacity numItems: Int) {
_table = Dictionary<NSObject, Int>()
Expand All @@ -466,21 +474,93 @@ open class NSCountedSet : NSMutableSet {
public convenience init(array: [Any]) {
self.init(capacity: array.count)
for object in array {
let value = __SwiftValue.store(object)
if let count = _table[value] {
_table[value] = count + 1
} else {
_table[value] = 1
_storage.insert(value)
}
add(__SwiftValue.store(object))
}
}

public convenience init(set: Set<AnyHashable>) {
self.init(array: Array(set))
}

public required convenience init?(coder: NSCoder) { NSUnimplemented() }
private enum NSCodingKeys {
static let maximumAllowedCount = UInt.max >> 4
static let countKey = "NS.count"
static func objectKey(atIndex index: Int64) -> String { return "NS.object\(index)" }
static func objectCountKey(atIndex index: Int64) -> String { return "NS.count\(index)" }
}

public required convenience init?(coder: NSCoder) {
func fail(_ message: String) {
coder.failWithError(NSError(domain: NSCocoaErrorDomain, code: NSCoderReadCorruptError, userInfo: [NSLocalizedDescriptionKey: message]))
}

guard coder.allowsKeyedCoding else {
fail("NSCountedSet requires keyed coding to be archived.")
return nil
}

let count = coder.decodeInt64(forKey: NSCodingKeys.countKey)
guard count >= 0, UInt(count) <= NSCodingKeys.maximumAllowedCount else {
fail("cannot decode set with \(count) elements in this version")
return nil
}

var objects: [(object: Any, count: Int64)] = []

for i in 0 ..< count {
let objectKey = NSCodingKeys.objectKey(atIndex: i)
let countKey = NSCodingKeys.objectCountKey(atIndex: i)

guard coder.containsValue(forKey: objectKey) && coder.containsValue(forKey: countKey) else {
fail("Mismatch in count stored (\(count)) vs. count present (\(i))")
return nil
}

guard let object = coder.decodeObject(forKey: objectKey) else {
fail("Decode failure at index \(i) - item nil")
return nil
}

let itemCount = coder.decodeInt64(forKey: countKey)
guard itemCount > 0 else {
fail("Decode failure at index \(i) - itemCount zero")
return nil
}

guard UInt(itemCount) <= NSCodingKeys.maximumAllowedCount else {
fail("Cannot store \(itemCount) instances of item \(object) in this version")
return nil
}

objects.append((object, itemCount))
}

self.init()
for value in objects {
for _ in 0 ..< value.count {
add(value.object)
}
}
}

open override func encode(with coder: NSCoder) {
func fail(_ message: String) {
coder.failWithError(NSError(domain: NSCocoaErrorDomain, code: NSCoderReadCorruptError, userInfo: [NSLocalizedDescriptionKey: message]))
}

guard coder.allowsKeyedCoding else {
fail("NSCountedSet requires keyed coding to be archived.")
return
}

coder.encode(Int64(self.count), forKey: NSCodingKeys.countKey)
var index: Int64 = 0
for object in self {
coder.encode(object, forKey: NSCodingKeys.objectKey(atIndex: index))
coder.encode(Int64(count(for: object)), forKey: NSCodingKeys.objectCountKey(atIndex: index))
index += 1
}
}

open override func copy(with zone: NSZone? = nil) -> Any {
if type(of: self) === NSCountedSet.self {
Expand All @@ -507,21 +587,23 @@ open class NSCountedSet : NSMutableSet {
NSRequiresConcreteImplementation()
}
let value = __SwiftValue.store(object)
guard let count = _table[value] else {
if let count = _table[value] {
return count + 1
} else if _storage.contains(value) {
return 1
} else {
return 0
}
return count
}

open override func add(_ object: Any) {
guard type(of: self) === NSCountedSet.self else {
NSRequiresConcreteImplementation()
}
let value = __SwiftValue.store(object)
if let count = _table[value] {
_table[value] = count + 1
if _storage.contains(value) {
_table[value, default: 0] += 1
} else {
_table[value] = 1
_storage.insert(value)
}
}
Expand All @@ -531,14 +613,11 @@ open class NSCountedSet : NSMutableSet {
NSRequiresConcreteImplementation()
}
let value = __SwiftValue.store(object)
guard let count = _table[value] else {
return
}

if count > 1 {
_table[value] = count - 1
} else {
_table[value] = nil
if let count = _table[value] {
precondition(count > 0)
_table[value] = count == 1 ? nil : count - 1
} else if _storage.contains(value) {
_table.removeValue(forKey: value)
_storage.remove(value)
}
}
Expand All @@ -551,6 +630,27 @@ open class NSCountedSet : NSMutableSet {
forEach(remove)
}
}

open override func isEqual(_ value: Any?) -> Bool {
if let countedSet = value as? NSCountedSet {
guard count == countedSet.count else { return false }
for object in self {
if !countedSet.contains(object) || count(for: object) != countedSet.count(for: object) {
return false
}
}
return true
}

if _table.isEmpty {
return super.isEqual(value)
} else {
return false
}
}

// The hash of a NSSet in s-c-f is its count, which is the same among equal NSCountedSets as well,
// so just using the superclass's implementation works fine.
}

extension NSSet : _StructTypeBridgeable {
Expand Down
25 changes: 25 additions & 0 deletions TestFoundation/FixtureValues.swift
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,28 @@ enum Fixtures {
return NSMutableSet()
}

// ===== NSCountedSet =====

static let countedSetOfNumbersAppearingOnce = TypedFixture<NSCountedSet>("NSCountedSet-NumbersAppearingOnce") {
let numbers = [1, 2, 3, 4, 5].map { NSNumber(value: $0) }
return NSCountedSet(array: numbers)
}

static let countedSetOfNumbersAppearingSeveralTimes = TypedFixture<NSCountedSet>("NSCountedSet-NumbersAppearingSeveralTimes") {
let numbers = [1, 2, 3, 4, 5].map { NSNumber(value: $0) }
let set = NSCountedSet()
for _ in 0 ..< 5 {
for number in numbers {
set.add(number)
}
}
return set
}

static let countedSetEmpty = TypedFixture<NSCountedSet>("NSCountedSet-Empty") {
return NSCountedSet()
}

// ===== NSCharacterSet, NSMutableCharacterSet =====

static let characterSetEmpty = TypedFixture<NSCharacterSet>("NSCharacterSet-Empty") {
Expand Down Expand Up @@ -288,6 +310,9 @@ enum Fixtures {
AnyFixture(Fixtures.setEmpty),
AnyFixture(Fixtures.mutableSetOfNumbers),
AnyFixture(Fixtures.mutableSetEmpty),
AnyFixture(Fixtures.countedSetOfNumbersAppearingOnce),
AnyFixture(Fixtures.countedSetOfNumbersAppearingSeveralTimes),
AnyFixture(Fixtures.countedSetEmpty),
AnyFixture(Fixtures.characterSetEmpty),
AnyFixture(Fixtures.characterSetRange),
AnyFixture(Fixtures.characterSetString),
Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
12 changes: 12 additions & 0 deletions TestFoundation/TestNSSet.swift
Original file line number Diff line number Diff line change
Expand Up @@ -260,13 +260,22 @@ class TestNSSet : XCTestCase {
Fixtures.mutableSetEmpty,
]

let countedSetFixtures = [
Fixtures.countedSetOfNumbersAppearingOnce,
Fixtures.countedSetOfNumbersAppearingSeveralTimes,
Fixtures.countedSetEmpty,
]

func test_codingRoundtrip() throws {
for fixture in setFixtures {
try fixture.assertValueRoundtripsInCoder()
}
for fixture in mutableSetFixtures {
try fixture.assertValueRoundtripsInCoder()
}
for fixture in countedSetFixtures {
try fixture.assertValueRoundtripsInCoder()
}
}

func test_loadedValuesMatch() throws {
Expand All @@ -276,6 +285,9 @@ class TestNSSet : XCTestCase {
for fixture in mutableSetFixtures {
try fixture.assertLoadedValuesMatch()
}
for fixture in countedSetFixtures {
try fixture.assertLoadedValuesMatch()
}
}

static var allTests: [(String, (TestNSSet) -> () throws -> Void)] {
Expand Down