Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 8a910b4

Browse files
authored
TensorFlow: annotate functions using _forEachFieldWithKeyPath (#1162)
Because the standard toolchain support requires the newest runtime in order to have access to `_forEachFieldWithKeyPath` and the runtime is bundled into the OS on macOS, we need to annotate the functions with availability.
1 parent 2886790 commit 8a910b4

6 files changed

+52
-2
lines changed

Sources/TensorFlow/Core/ElementaryFunctions.swift

+7
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,18 @@
1313
// limitations under the License.
1414

1515
#if TENSORFLOW_USE_STANDARD_TOOLCHAIN
16+
1617
import Numerics
1718
@_spi(Reflection) import Swift
19+
1820
extension ElementaryFunctions {
1921
internal static func visitChildren(
2022
_ body: (PartialKeyPath<Self>, ElementaryFunctionsVisit.Type) -> Void
2123
) {
24+
guard #available(macOS 9999, *) else {
25+
fatalError("\(#function) is unavailable")
26+
}
27+
2228
if !_forEachFieldWithKeyPath(
2329
of: Self.self,
2430
body: { name, kp in
@@ -164,4 +170,5 @@ extension ElementaryFunctions {
164170
public static func root(_ x: Self, _ n: Int) -> Self { .init(mapped: Functor_root(n: n), x) }
165171
public static func pow(_ x: Self, _ y: Self) -> Self { .init(mapped: Functor_pow2(), x, y) }
166172
}
173+
167174
#endif

Sources/TensorFlow/Core/EuclideanDifferentiable.swift

+10-2
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,14 @@
1515
import _Differentiation
1616

1717
#if TENSORFLOW_USE_STANDARD_TOOLCHAIN
18+
1819
@_spi(Reflection) import Swift
1920

2021
func listFields<Root>(of type: Root.Type) -> [(String, PartialKeyPath<Root>)] {
22+
guard #available(macOS 9999, *) else {
23+
fatalError("\(#function) is unavailable")
24+
}
25+
2126
var out = [(String, PartialKeyPath<Root>)]()
2227
_forEachFieldWithKeyPath(of: type, options: .ignoreUnknown) { name, kp in
2328
out.append((String(validatingUTF8: name)!, kp))
@@ -27,8 +32,11 @@ func listFields<Root>(of type: Root.Type) -> [(String, PartialKeyPath<Root>)] {
2732
}
2833

2934
extension Differentiable {
30-
static var differentiableFields: [(String, PartialKeyPath<Self>, PartialKeyPath<TangentVector>)]
31-
{
35+
static var differentiableFields: [(String, PartialKeyPath<Self>, PartialKeyPath<TangentVector>)] {
36+
guard #available(macOS 9999, *) else {
37+
fatalError("\(#function) is unavailable")
38+
}
39+
3240
let tangentFields = listFields(of: TangentVector.self)
3341
var i = 0
3442
var out = [(String, PartialKeyPath<Self>, PartialKeyPath<TangentVector>)]()

Sources/TensorFlow/Core/KeyPathIterable.swift

+6
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import _Differentiation
2222

2323
#if TENSORFLOW_USE_STANDARD_TOOLCHAIN
24+
2425
@_spi(Reflection) import Swift
2526

2627
/// An implementation detail of `KeyPathIterable`; do not use this protocol
@@ -42,6 +43,10 @@ public protocol KeyPathIterable: _KeyPathIterableBase {
4243

4344
public extension KeyPathIterable {
4445
var allKeyPaths: [PartialKeyPath<Self>] {
46+
guard #available(macOS 9999, *) else {
47+
fatalError("\(#function) is unavailable")
48+
}
49+
4550
var out = [PartialKeyPath<Self>]()
4651
_forEachFieldWithKeyPath(of: Self.self, options: .ignoreUnknown) { name, kp in
4752
out.append(kp)
@@ -171,4 +176,5 @@ extension Optional.TangentVector: KeyPathIterable {
171176
return []
172177
}
173178
}
179+
174180
#endif

Sources/TensorFlow/Core/PointwiseMultiplicative.swift

+6
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import _Differentiation
1616

1717
#if TENSORFLOW_USE_STANDARD_TOOLCHAIN
18+
1819
@_spi(Reflection) import Swift
1920

2021
infix operator .*: MultiplicationPrecedence
@@ -114,6 +115,10 @@ extension PointwiseMultiplicative {
114115
internal static func visitChildren(
115116
_ body: (PartialKeyPath<Self>, _PointwiseMultiplicative.Type) -> Void
116117
) {
118+
guard #available(macOS 9999, *) else {
119+
fatalError("\(#function) is unavailable")
120+
}
121+
117122
if !_forEachFieldWithKeyPath(
118123
of: Self.self,
119124
body: { name, kp in
@@ -134,4 +139,5 @@ extension PointwiseMultiplicative {
134139
extension Array.DifferentiableView: _PointwiseMultiplicative
135140
where Element: Differentiable & PointwiseMultiplicative {}
136141
extension Tensor: _PointwiseMultiplicative where Scalar: Numeric {}
142+
137143
#endif

Sources/TensorFlow/Core/TensorGroup.swift

+17
Original file line numberDiff line numberDiff line change
@@ -340,9 +340,14 @@ extension Array: TensorArrayProtocol where Element: TensorGroup {
340340
}
341341

342342
#if TENSORFLOW_USE_STANDARD_TOOLCHAIN
343+
343344
@_spi(Reflection) import Swift
344345

345346
func reflectionInit<T>(type: T.Type, body: (inout T, PartialKeyPath<T>) -> Void) -> T {
347+
guard #available(macOS 9999, *) else {
348+
fatalError("\(#function) is unavailable")
349+
}
350+
346351
let x = UnsafeMutablePointer<T>.allocate(capacity: 1)
347352
defer { x.deallocate() }
348353
if !_forEachFieldWithKeyPath(of: type, body: { name, kp in
@@ -356,6 +361,10 @@ func reflectionInit<T>(type: T.Type, body: (inout T, PartialKeyPath<T>) -> Void)
356361

357362
extension TensorGroup {
358363
public static var _typeList: [TensorDataType] {
364+
guard #available(macOS 9999, *) else {
365+
fatalError("\(#function) is unavailable")
366+
}
367+
359368
var out = [TensorDataType]()
360369
if !(_forEachFieldWithKeyPath(of: Self.self) { name, kp in
361370
guard let valueType = type(of: kp).valueType as? TensorGroup.Type else { return false }
@@ -366,6 +375,7 @@ extension TensorGroup {
366375
}
367376
return out
368377
}
378+
369379
public static func initialize<Root>(
370380
_ base: inout Root, _ kp: PartialKeyPath<Root>,
371381
_owning tensorHandles: UnsafePointer<CTensorHandle>?
@@ -377,6 +387,7 @@ extension TensorGroup {
377387
v.initialize(to: .init(_owning: tensorHandles))
378388
}
379389
}
390+
380391
public init(_owning tensorHandles: UnsafePointer<CTensorHandle>?) {
381392
var i = 0
382393
self = reflectionInit(type: Self.self) { base, kp in
@@ -387,7 +398,12 @@ extension TensorGroup {
387398
i += Int(valueType._tensorHandleCount)
388399
}
389400
}
401+
390402
public func _unpackTensorHandles(into address: UnsafeMutablePointer<CTensorHandle>?) {
403+
guard #available(macOS 9999, *) else {
404+
fatalError("\(#function) is unavailable")
405+
}
406+
391407
var i = 0
392408
if !_forEachFieldWithKeyPath(of: Self.self, body: { name, kp in
393409
guard let x = self[keyPath: kp] as? TensorGroup else { return false }
@@ -399,4 +415,5 @@ extension TensorGroup {
399415
}
400416
}
401417
}
418+
402419
#endif

Sources/TensorFlow/Core/VectorProtocol.swift

+6
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import _Differentiation
1616

1717
#if TENSORFLOW_USE_STANDARD_TOOLCHAIN
18+
1819
@_spi(Reflection) import Swift
1920

2021
/// Implementation detail for reflection.
@@ -38,6 +39,10 @@ extension VectorProtocol {
3839
internal static func visitChildren(
3940
_ body: (PartialKeyPath<Self>, _VectorProtocol.Type) -> Void
4041
) {
42+
guard #available(macOS 9999, *) else {
43+
fatalError("\(#function) is unavailable")
44+
}
45+
4146
if !_forEachFieldWithKeyPath(
4247
of: Self.self,
4348
body: { name, kp in
@@ -127,4 +132,5 @@ extension VectorProtocol {
127132
extension Tensor: _VectorProtocol where Scalar: TensorFlowFloatingPoint {}
128133
extension Array.DifferentiableView: _VectorProtocol
129134
where Element: Differentiable & VectorProtocol {}
135+
130136
#endif

0 commit comments

Comments
 (0)