Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit c0d1520

Browse files
authored
Keypath reflection default implementation. (#1140)
1 parent 1300d8a commit c0d1520

File tree

5 files changed

+408
-0
lines changed

5 files changed

+408
-0
lines changed

Sources/TensorFlow/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ add_library(TensorFlow SHARED
3636
Core/TensorShape.swift
3737
Core/Threading.swift
3838
Core/Utilities.swift
39+
Core/KeyPathIterable.swift
3940
Core/EuclideanDifferentiable.swift
4041
Core/VectorProtocol.swift
4142
Core/PointwiseMultiplicative.swift
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors
3+
// Licensed under Apache License v2.0 with Runtime Library Exceptions.
4+
//
5+
// Licensed under the Apache License, Version 2.0 (the "License");
6+
// you may not use this file except in compliance with the License.
7+
// You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing, software
12+
// distributed under the License is distributed on an "AS IS" BASIS,
13+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
// See the License for the specific language governing permissions and
15+
// limitations under the License.
16+
17+
//===----------------------------------------------------------------------===//
18+
// KeyPathIterable
19+
//===----------------------------------------------------------------------===//
20+
21+
import _Differentiation
22+
23+
#if TENSORFLOW_USE_STANDARD_TOOLCHAIN
24+
@_spi(Reflection) import Swift
25+
26+
/// An implementation detail of `KeyPathIterable`; do not use this protocol
27+
/// directly.
28+
public protocol _KeyPathIterableBase {
29+
var _allKeyPathsTypeErased: [AnyKeyPath] { get }
30+
var _recursivelyAllKeyPathsTypeErased: [AnyKeyPath] { get }
31+
}
32+
33+
/// A type whose values provides custom key paths to properties or elements.
34+
public protocol KeyPathIterable: _KeyPathIterableBase {
35+
/// A type that can represent a collection of all key paths of this type.
36+
associatedtype AllKeyPaths: Collection
37+
where AllKeyPaths.Element == PartialKeyPath<Self>
38+
39+
/// A collection of all custom key paths of this value.
40+
var allKeyPaths: AllKeyPaths { get }
41+
}
42+
43+
public extension KeyPathIterable {
44+
var allKeyPaths: [PartialKeyPath<Self>] {
45+
var out = [PartialKeyPath<Self>]()
46+
_forEachFieldWithKeyPath(of: Self.self, options: .ignoreUnknown) { name, kp in
47+
out.append(kp)
48+
return true
49+
}
50+
return out
51+
}
52+
}
53+
54+
public extension KeyPathIterable {
55+
/// An array of all custom key paths of this value and any custom key paths
56+
/// nested within each of what this value's key paths refers to.
57+
var recursivelyAllKeyPaths: [PartialKeyPath<Self>] {
58+
var result: [PartialKeyPath<Self>] = []
59+
for kp in allKeyPaths {
60+
result.append(kp)
61+
if let nested = self[keyPath: kp] as? _KeyPathIterableBase {
62+
for nkp in nested._recursivelyAllKeyPathsTypeErased {
63+
result.append(kp.appending(path: nkp)!)
64+
}
65+
}
66+
}
67+
return result
68+
}
69+
}
70+
71+
public extension KeyPathIterable {
72+
var _allKeyPathsTypeErased: [AnyKeyPath] {
73+
return allKeyPaths.map { $0 as AnyKeyPath }
74+
}
75+
var _recursivelyAllKeyPathsTypeErased: [AnyKeyPath] {
76+
return recursivelyAllKeyPaths.map { $0 as AnyKeyPath }
77+
}
78+
}
79+
80+
public extension KeyPathIterable {
81+
/// Returns an array of all custom key paths of this value, to the specified
82+
/// type.
83+
func allKeyPaths<T>(to _: T.Type) -> [KeyPath<Self, T>] {
84+
return allKeyPaths.compactMap { $0 as? KeyPath<Self, T> }
85+
}
86+
87+
/// Returns an array of all custom key paths of this value and any custom key
88+
/// paths nested within each of what this value's key paths refers to, to
89+
/// the specified type.
90+
func recursivelyAllKeyPaths<T>(to _: T.Type) -> [KeyPath<Self, T>] {
91+
return recursivelyAllKeyPaths.compactMap { $0 as? KeyPath<Self, T> }
92+
}
93+
94+
/// Returns an array of all custom writable key paths of this value, to the
95+
/// specified type.
96+
func allWritableKeyPaths<T>(to _: T.Type) -> [WritableKeyPath<Self, T>] {
97+
return allKeyPaths(to: T.self)
98+
.compactMap { $0 as? WritableKeyPath<Self, T> }
99+
}
100+
101+
/// Returns an array of all custom writable key paths of this value and any
102+
/// custom writable key paths nested within each of what this value's key
103+
/// paths refers to, to the specified type.
104+
func recursivelyAllWritableKeyPaths<T>(
105+
to _: T.Type
106+
) -> [WritableKeyPath<Self, T>] {
107+
return recursivelyAllKeyPaths(to: T.self)
108+
.compactMap { $0 as? WritableKeyPath<Self, T> }
109+
}
110+
}
111+
112+
//===----------------------------------------------------------------------===//
113+
// Collection conformances
114+
//===----------------------------------------------------------------------===//
115+
116+
extension Array: KeyPathIterable {
117+
public typealias AllKeyPaths = [PartialKeyPath<Array>]
118+
public var allKeyPaths: [PartialKeyPath<Array>] {
119+
return indices.map { \Array[$0] }
120+
}
121+
}
122+
123+
// TODO(TF-938): Remove this conformance after removing
124+
// `Element: Differentiable` requirement.
125+
//
126+
// Currently necessary to avoid error:
127+
//
128+
// error: conditional conformance of type 'Array<Element>.DifferentiableView'
129+
// to protocol 'KeyPathIterable' does not imply conformance to inherited
130+
// protocol '_KeyPathIterableBase'.
131+
extension Array.DifferentiableView: _KeyPathIterableBase
132+
where Element: Differentiable {}
133+
134+
// TODO(TF-938): Remove `Element: Differentiable` requirement.
135+
extension Array.DifferentiableView: KeyPathIterable
136+
where Element: Differentiable {
137+
public typealias AllKeyPaths = [PartialKeyPath<Array.DifferentiableView>]
138+
public var allKeyPaths: [PartialKeyPath<Array.DifferentiableView>] {
139+
return [\Array.DifferentiableView.base]
140+
}
141+
}
142+
143+
extension Dictionary: KeyPathIterable {
144+
public typealias AllKeyPaths = [PartialKeyPath<Dictionary>]
145+
public var allKeyPaths: [PartialKeyPath<Dictionary>] {
146+
// Note: `Dictionary.subscript(_: Key)` returns `Value?` and can be used to
147+
// form `WritableKeyPath<Self, Value>` key paths.
148+
// Force-unwrapping the result is necessary.
149+
return keys.map { \Dictionary[$0]! }
150+
}
151+
}
152+
153+
extension Optional: KeyPathIterable {
154+
public typealias AllKeyPaths = [PartialKeyPath<Self>]
155+
156+
public var allKeyPaths: [PartialKeyPath<Self>] {
157+
if self != nil {
158+
return [\.!]
159+
}
160+
return []
161+
}
162+
}
163+
164+
extension Optional.TangentVector: KeyPathIterable {
165+
public typealias AllKeyPaths = [PartialKeyPath<Self>]
166+
167+
public var allKeyPaths: [PartialKeyPath<Self>] {
168+
if value != nil {
169+
return [\Self.value!]
170+
}
171+
return []
172+
}
173+
}
174+
#endif

Tests/TensorFlowTests/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ add_library(TensorFlowTests
55
FreezableTests.swift
66
Helpers.swift
77
InitializerTests.swift
8+
KeyPathIterableTests.swift
89
LayerTests.swift
910
LazyTensorEvaluationTests.swift
1011
LazyTensorExplicitTraceTests.swift

0 commit comments

Comments
 (0)