// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import TensorFlow
import XCTest

/// Flat fixture: two mutable `Float` stored properties.
struct SimpleKPI: KeyPathIterable, Equatable {
  var w, b: Float
}

/// Flat fixture mixing mutable (`var`) and immutable (`let`) stored properties
/// of three different types.
struct MixedKPI: KeyPathIterable, Equatable {
  // Mutable.
  var string: String
  var float: Float
  // Immutable.
  let int: Int
}

/// Nested fixture: one immutable and one mutable `KeyPathIterable` member.
struct NestedKPI: KeyPathIterable, Equatable {
  // Immutable.
  let simple: SimpleKPI
  // Mutable.
  var mixed: MixedKPI
}

/// Nested fixture that reaches `SimpleKPI` values directly and through an
/// optional, an array and a dictionary. Only `float` and `dictionary` are
/// declared `var`.
struct ComplexNestedKPI: KeyPathIterable, Equatable {
  var float: Float
  let simple: SimpleKPI
  let optional: SimpleKPI?
  let array: [SimpleKPI]
  var dictionary: [String: SimpleKPI]
}

/// Tests for the key-path enumeration properties and methods of
/// `KeyPathIterable`: `allKeyPaths`, `recursivelyAllKeyPaths`, and their typed
/// and writable variants.
final class KeyPathIterableTests: XCTestCase {
  /// A flat struct of `Float`s: every enumeration variant yields `w` and `b`,
  /// and queries for an absent type yield nothing.
  func testSimple() {
    var x = SimpleKPI(w: 1, b: 2)
    XCTAssertEqual([\SimpleKPI.w, \SimpleKPI.b], x.allKeyPaths)
    XCTAssertEqual([\SimpleKPI.w, \SimpleKPI.b], x.allKeyPaths(to: Float.self))
    XCTAssertEqual([\SimpleKPI.w, \SimpleKPI.b], x.allWritableKeyPaths(to: Float.self))
    XCTAssertEqual([\SimpleKPI.w, \SimpleKPI.b], x.recursivelyAllKeyPaths)
    XCTAssertEqual([\SimpleKPI.w, \SimpleKPI.b], x.recursivelyAllKeyPaths(to: Float.self))
    XCTAssertEqual([\SimpleKPI.w, \SimpleKPI.b], x.recursivelyAllWritableKeyPaths(to: Float.self))
    XCTAssertEqual([], x.allKeyPaths(to: Int.self))
    XCTAssertEqual([], x.recursivelyAllKeyPaths(to: Double.self))

    // Mutate recursively all `Float` properties (each one is doubled).
    for kp in x.recursivelyAllWritableKeyPaths(to: Float.self) {
      x[keyPath: kp] += x[keyPath: kp]
    }
    // Check that recursively all `Float` properties have been mutated.
    XCTAssertEqual(SimpleKPI(w: 2, b: 4), x)
  }

  /// A flat struct with mixed types and mutability: typed queries select by
  /// property type, and the `let` property `int` is expected to be absent from
  /// the writable variants.
  func testMixed() {
    var x = MixedKPI(string: "hello", float: .pi, int: 0)
    XCTAssertEqual([\MixedKPI.string, \MixedKPI.float, \MixedKPI.int], x.allKeyPaths)
    XCTAssertEqual([\MixedKPI.string, \MixedKPI.float, \MixedKPI.int], x.recursivelyAllKeyPaths)

    XCTAssertEqual([\MixedKPI.string], x.allKeyPaths(to: String.self))
    XCTAssertEqual([\MixedKPI.string], x.allWritableKeyPaths(to: String.self))
    XCTAssertEqual([\MixedKPI.string], x.recursivelyAllKeyPaths(to: String.self))
    XCTAssertEqual([\MixedKPI.string], x.recursivelyAllWritableKeyPaths(to: String.self))

    XCTAssertEqual([\MixedKPI.float], x.allKeyPaths(to: Float.self))
    XCTAssertEqual([\MixedKPI.float], x.allWritableKeyPaths(to: Float.self))
    XCTAssertEqual([\MixedKPI.float], x.recursivelyAllKeyPaths(to: Float.self))
    XCTAssertEqual([\MixedKPI.float], x.recursivelyAllWritableKeyPaths(to: Float.self))

    XCTAssertEqual([\MixedKPI.int], x.allKeyPaths(to: Int.self))
    XCTAssertEqual([], x.allWritableKeyPaths(to: Int.self))
    XCTAssertEqual([\MixedKPI.int], x.recursivelyAllKeyPaths(to: Int.self))
    XCTAssertEqual([], x.recursivelyAllWritableKeyPaths(to: Int.self))

    // Mutate recursively all `String` properties.
    for kp in x.recursivelyAllWritableKeyPaths(to: String.self) {
      x[keyPath: kp] = x[keyPath: kp] + " world"
    }
    // Check that recursively all `String` properties have been mutated.
    XCTAssertEqual(MixedKPI(string: "hello world", float: .pi, int: 0), x)
  }

  /// One level of nesting: the non-recursive queries see only the two direct
  /// members, while the recursive queries also descend into them.
  func testSimpleNested() {
    var x = NestedKPI(
      simple: SimpleKPI(w: 1, b: 2),
      mixed: MixedKPI(string: "foo", float: 3, int: 0))

    XCTAssertEqual([\NestedKPI.simple, \NestedKPI.mixed], x.allKeyPaths)
    XCTAssertEqual(
      [
        \NestedKPI.simple,
        \NestedKPI.simple.w,
        \NestedKPI.simple.b,
        \NestedKPI.mixed,
        \NestedKPI.mixed.string,
        \NestedKPI.mixed.float,
        \NestedKPI.mixed.int,
      ],
      x.recursivelyAllKeyPaths)

    // Neither direct member is a `Float`, `Int` or `String`.
    XCTAssertEqual([], x.allKeyPaths(to: Float.self))
    XCTAssertEqual([], x.allKeyPaths(to: Int.self))
    XCTAssertEqual([], x.allKeyPaths(to: String.self))

    XCTAssertEqual([], x.allWritableKeyPaths(to: Float.self))
    XCTAssertEqual([], x.allWritableKeyPaths(to: Int.self))
    XCTAssertEqual([], x.allWritableKeyPaths(to: String.self))

    XCTAssertEqual(
      [\NestedKPI.simple.w, \NestedKPI.simple.b, \NestedKPI.mixed.float],
      x.recursivelyAllKeyPaths(to: Float.self))
    XCTAssertEqual([\NestedKPI.mixed.int], x.recursivelyAllKeyPaths(to: Int.self))
    XCTAssertEqual([\NestedKPI.mixed.string], x.recursivelyAllKeyPaths(to: String.self))

    // `simple` is a `let`, so nothing reached through it is expected to be writable.
    XCTAssertEqual([\NestedKPI.mixed.float], x.recursivelyAllWritableKeyPaths(to: Float.self))
    XCTAssertEqual([], x.recursivelyAllWritableKeyPaths(to: Int.self))
    XCTAssertEqual([\NestedKPI.mixed.string], x.recursivelyAllWritableKeyPaths(to: String.self))

    XCTAssertEqual([], x.recursivelyAllKeyPaths(to: Double.self))

    // Mutate recursively all writable `Float` properties.
    for kp in x.recursivelyAllWritableKeyPaths(to: Float.self) {
      x[keyPath: kp] *= 100
    }
    // Check that only the writable `Float` property (`mixed.float`) has been
    // mutated; the values inside the immutable `simple` are unchanged.
    let expected = NestedKPI(
      simple: SimpleKPI(w: 1, b: 2),
      mixed: MixedKPI(string: "foo", float: 300, int: 0))
    XCTAssertEqual(expected, x)
  }

  /// Nesting through an optional, an array and a dictionary: checks the key
  /// paths enumerated for each container and that only `var`-reachable `Float`s
  /// are writable.
  func testComplexNested() {
    var x = ComplexNestedKPI(
      float: 1,
      simple: SimpleKPI(w: 3, b: 4),
      optional: SimpleKPI(w: 5, b: 6),
      array: [SimpleKPI(w: 5, b: 6), SimpleKPI(w: 7, b: 8)],
      dictionary: [
        "foo": SimpleKPI(w: 1, b: 2),
        "bar": SimpleKPI(w: 3, b: 4),
      ])

    XCTAssertEqual(
      [
        \ComplexNestedKPI.float,
        \ComplexNestedKPI.simple,
        \ComplexNestedKPI.optional,
        \ComplexNestedKPI.array,
        \ComplexNestedKPI.dictionary,
      ],
      x.allKeyPaths)

    // The expected lists below follow the order in which this dictionary
    // reports its keys instead of hard-coding "foo"/"bar", because
    // `Dictionary` does not guarantee a particular ordering.
    let keys = Array(x.dictionary.keys)
    let key1 = keys[0]
    let key2 = keys[1]

    XCTAssertEqual(
      [
        \ComplexNestedKPI.float,
        \ComplexNestedKPI.simple,
        \ComplexNestedKPI.simple.w,
        \ComplexNestedKPI.simple.b,
        \ComplexNestedKPI.optional,
        \ComplexNestedKPI.optional!,
        \ComplexNestedKPI.optional!.w,
        \ComplexNestedKPI.optional!.b,
        \ComplexNestedKPI.array,
        \ComplexNestedKPI.array[0],
        \ComplexNestedKPI.array[0].w,
        \ComplexNestedKPI.array[0].b,
        \ComplexNestedKPI.array[1],
        \ComplexNestedKPI.array[1].w,
        \ComplexNestedKPI.array[1].b,
        \ComplexNestedKPI.dictionary,
        \ComplexNestedKPI.dictionary[key1]!,
        \ComplexNestedKPI.dictionary[key1]!.w,
        \ComplexNestedKPI.dictionary[key1]!.b,
        \ComplexNestedKPI.dictionary[key2]!,
        \ComplexNestedKPI.dictionary[key2]!.w,
        \ComplexNestedKPI.dictionary[key2]!.b,
      ],
      x.recursivelyAllKeyPaths)

    XCTAssertEqual(
      [
        \ComplexNestedKPI.float,
        \ComplexNestedKPI.simple.w,
        \ComplexNestedKPI.simple.b,
        \ComplexNestedKPI.optional!.w,
        \ComplexNestedKPI.optional!.b,
        \ComplexNestedKPI.array[0].w,
        \ComplexNestedKPI.array[0].b,
        \ComplexNestedKPI.array[1].w,
        \ComplexNestedKPI.array[1].b,
        \ComplexNestedKPI.dictionary[key1]!.w,
        \ComplexNestedKPI.dictionary[key1]!.b,
        \ComplexNestedKPI.dictionary[key2]!.w,
        \ComplexNestedKPI.dictionary[key2]!.b,
      ],
      x.recursivelyAllKeyPaths(to: Float.self))

    // Only `float` and the contents of `dictionary` are reachable through
    // `var` properties.
    XCTAssertEqual(
      [
        \ComplexNestedKPI.float,
        \ComplexNestedKPI.dictionary[key1]!.w,
        \ComplexNestedKPI.dictionary[key1]!.b,
        \ComplexNestedKPI.dictionary[key2]!.w,
        \ComplexNestedKPI.dictionary[key2]!.b,
      ],
      x.recursivelyAllWritableKeyPaths(to: Float.self))

    // Mutate recursively all writable `Float` properties.
    for kp in x.recursivelyAllWritableKeyPaths(to: Float.self) {
      x[keyPath: kp] += 1
    }
    // Check that exactly the writable `Float` properties have been mutated:
    // `simple`, `optional` and `array` are `let`s and keep their values.
    let expected = ComplexNestedKPI(
      float: 2,
      simple: SimpleKPI(w: 3, b: 4),
      optional: SimpleKPI(w: 5, b: 6),
      array: [SimpleKPI(w: 5, b: 6), SimpleKPI(w: 7, b: 8)],
      dictionary: [
        "foo": SimpleKPI(w: 2, b: 3),
        "bar": SimpleKPI(w: 4, b: 5),
      ])
    XCTAssertEqual(expected, x)
  }

  /// Test manifest listing every test method of this class by name.
  static var allTests = [
    ("testSimple", testSimple),
    ("testMixed", testMixed),
    ("testSimpleNested", testSimpleNested),
    ("testComplexNested", testComplexNested),
  ]
}