Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 1300d8a

Browse files
authored
Reflection based PointwiseMultiplicative and VectorProtocol. (#1139)
1 parent 63a43a4 commit 1300d8a

File tree

4 files changed

+273
-0
lines changed

4 files changed

+273
-0
lines changed

Sources/TensorFlow/CMakeLists.txt

+2
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ add_library(TensorFlow SHARED
3737
Core/Threading.swift
3838
Core/Utilities.swift
3939
Core/EuclideanDifferentiable.swift
40+
Core/VectorProtocol.swift
41+
Core/PointwiseMultiplicative.swift
4042
Core/ElementaryFunctions.swift
4143

4244
Epochs/Algorithms.swift

Sources/TensorFlow/Core/EuclideanDifferentiable.swift

+6
Original file line numberDiff line numberDiff line change
@@ -110,4 +110,10 @@ where Element: EuclideanDifferentiable {
110110
out = Array.DifferentiableView.TangentVector(self.base.map { $0.differentiableVectorView })
111111
}
112112
}
113+
extension RNNCellInput: _EuclideanDifferentiable
114+
where Input: EuclideanDifferentiable, State: EuclideanDifferentiable {}
115+
extension RNNCellOutput: _EuclideanDifferentiable
116+
where Output: EuclideanDifferentiable, State: EuclideanDifferentiable {}
117+
extension Tensor: _EuclideanDifferentiable where Scalar: TensorFlowFloatingPoint {}
118+
113119
#endif
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import _Differentiation
16+
17+
#if TENSORFLOW_USE_STANDARD_TOOLCHAIN
18+
@_spi(Reflection) import Swift
19+
20+
infix operator .*: MultiplicationPrecedence
21+
infix operator .*=: AssignmentPrecedence
22+
23+
/// Implementation detail of the reflection default implementation.
24+
///
25+
/// Contains versions of functions in PointwiseMultiplicative that
26+
/// operate over key paths and modify a child of `Root` in-place.
27+
/// The key paths must all be WritableKeyPath<Root, Self>. This is a workaround
28+
/// to simulate having Self constraints.
29+
public protocol _PointwiseMultiplicative {
30+
/// lhs[keyPath: kp] .*= rhs[keyPath: kp]
31+
static func _pointwiseMult<Root>(_ lhs: inout Root, _ rhs: Root, _ kp: PartialKeyPath<Root>)
32+
/// out[keyPath: kp] = Self.one
33+
static func _setOne<Root>(_ out: inout Root, _ kp: PartialKeyPath<Root>)
34+
/// out[keyPath: kp] = out[keyPath: kp].reciprocal
35+
static func _setReciprocal<Root>(_ out: inout Root, _ kp: PartialKeyPath<Root>)
36+
}
37+
38+
public protocol PointwiseMultiplicative: _PointwiseMultiplicative & AdditiveArithmetic {
39+
/// The one value.
40+
///
41+
/// One is the identity element for multiplication. For any value,
42+
/// `x .* .one == x` and `.one .* x == x`.
43+
static var one: Self { get }
44+
45+
/// The multiplicative inverse of self.
46+
///
47+
/// For any value, `x .* x.reciprocal == .one` and
48+
/// `x.reciprocal .* x == .one`.
49+
var reciprocal: Self { get }
50+
51+
/// Multiplies two values and produces their product.
52+
///
53+
/// - Parameters:
54+
/// - lhs: The first value to multiply.
55+
/// - rhs: The second value to multiply.
56+
static func .* (lhs: Self, rhs: Self) -> Self
57+
58+
/// Multiplies two values and produces their product.
59+
///
60+
/// - Parameters:
61+
/// - lhs: The first value to multiply.
62+
/// - rhs: The second value to multiply.
63+
static func .*= (lhs: inout Self, rhs: Self)
64+
}
65+
66+
extension PointwiseMultiplicative {
67+
public static func .*= (lhs: inout Self, rhs: Self) {
68+
lhs = lhs .* rhs
69+
}
70+
}
71+
72+
extension PointwiseMultiplicative
73+
where Self: ExpressibleByIntegerLiteral {
74+
public static var one: Self {
75+
return 1
76+
}
77+
}
78+
79+
extension PointwiseMultiplicative {
80+
public static var one: Self {
81+
var out = self.zero
82+
visitChildren { kp, t in t._setOne(&out, kp) }
83+
return out
84+
}
85+
public var reciprocal: Self {
86+
var out = self
87+
Self.visitChildren { kp, t in t._setReciprocal(&out, kp) }
88+
return out
89+
}
90+
public static func .* (lhs: Self, rhs: Self) -> Self {
91+
var out = lhs
92+
visitChildren { kp, t in
93+
t._pointwiseMult(&out, rhs, kp)
94+
}
95+
return out
96+
}
97+
public static func _pointwiseMult<Root>(
98+
_ lhs: inout Root, _ rhs: Root, _ kp: PartialKeyPath<Root>
99+
) {
100+
let kp = kp as! WritableKeyPath<Root, Self>
101+
lhs[keyPath: kp] .*= rhs[keyPath: kp]
102+
}
103+
public static func _setOne<Root>(_ out: inout Root, _ kp: PartialKeyPath<Root>) {
104+
let kp = kp as! WritableKeyPath<Root, Self>
105+
out[keyPath: kp] = Self.one
106+
}
107+
public static func _setReciprocal<Root>(_ out: inout Root, _ kp: PartialKeyPath<Root>) {
108+
let kp = kp as! WritableKeyPath<Root, Self>
109+
out[keyPath: kp] = out[keyPath: kp].reciprocal
110+
}
111+
}
112+
113+
extension PointwiseMultiplicative {
114+
internal static func visitChildren(
115+
_ body: (PartialKeyPath<Self>, _PointwiseMultiplicative.Type) -> Void
116+
) {
117+
if !_forEachFieldWithKeyPath(
118+
of: Self.self,
119+
body: { name, kp in
120+
let valueType = type(of: kp).valueType
121+
guard let valueType = valueType as? _PointwiseMultiplicative.Type else {
122+
fatalError("not PointwiseMultiplicative: \(valueType)")
123+
}
124+
body(kp, valueType)
125+
return true
126+
})
127+
{
128+
fatalError(
129+
"Unreflectable member of \(Self.self) while implementing PointwiseMultiplicative.")
130+
}
131+
}
132+
}
133+
134+
extension Array.DifferentiableView: _PointwiseMultiplicative
135+
where Element: Differentiable & PointwiseMultiplicative {}
136+
extension Tensor: _PointwiseMultiplicative where Scalar: Numeric {}
137+
#endif
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
import _Differentiation
16+
17+
#if TENSORFLOW_USE_STANDARD_TOOLCHAIN
18+
@_spi(Reflection) import Swift
19+
20+
/// Implementation detail for reflection.
21+
///
22+
/// This should contain the methods of `VectorProtocol`
23+
/// that do not require Self constraints.
24+
public protocol _VectorProtocol {
25+
typealias VectorSpaceScalar = Float
26+
27+
/// Adds the specified scalar to `self`.
28+
mutating func add(_ x: VectorSpaceScalar)
29+
30+
/// Subtracts the specified scalar to `self`.
31+
mutating func subtract(_ x: VectorSpaceScalar)
32+
33+
/// Scales `self` by the specified scalar.
34+
mutating func scale(by scalar: VectorSpaceScalar)
35+
}
36+
37+
extension VectorProtocol {
38+
internal static func visitChildren(
39+
_ body: (PartialKeyPath<Self>, _VectorProtocol.Type) -> Void
40+
) {
41+
if !_forEachFieldWithKeyPath(
42+
of: Self.self,
43+
body: { name, kp in
44+
let valueType = type(of: kp).valueType
45+
guard let valueType = valueType as? _VectorProtocol.Type else {
46+
fatalError("not VectorProtocol: \(valueType)")
47+
}
48+
body(kp, valueType)
49+
return true
50+
})
51+
{
52+
fatalError("Unreflectable member of \(Self.self) while implementing VectorProtocol.")
53+
}
54+
}
55+
}
56+
57+
extension _VectorProtocol {
58+
static func add<Root>(_ v: inout Root, _ kp: PartialKeyPath<Root>, _ x: VectorSpaceScalar) {
59+
v[keyPath: (kp as! WritableKeyPath<Root, Self>)].add(x)
60+
}
61+
static func subtract<Root>(_ v: inout Root, _ kp: PartialKeyPath<Root>, _ x: VectorSpaceScalar)
62+
{
63+
v[keyPath: (kp as! WritableKeyPath<Root, Self>)].subtract(x)
64+
}
65+
static func scale<Root>(
66+
_ v: inout Root, _ kp: PartialKeyPath<Root>, by scalar: VectorSpaceScalar
67+
) {
68+
v[keyPath: (kp as! WritableKeyPath<Root, Self>)].scale(by: scalar)
69+
}
70+
}
71+
72+
/// A type that represents an unranked vector space. Values of this type are
73+
/// elements in this vector space and have either no shape or a static shape.
74+
public protocol VectorProtocol: _VectorProtocol & AdditiveArithmetic {
75+
/// The type of scalars in the vector space.
76+
associatedtype VectorSpaceScalar = Float
77+
78+
func adding(_ x: VectorSpaceScalar) -> Self
79+
80+
mutating func add(_ x: VectorSpaceScalar)
81+
82+
func subtracting(_ x: VectorSpaceScalar) -> Self
83+
84+
mutating func subtract(_ x: VectorSpaceScalar)
85+
86+
/// Returns `self` multiplied by the given scalar.
87+
func scaled(by scalar: VectorSpaceScalar) -> Self
88+
89+
/// Multiplies `self` by the given scalar.
90+
mutating func scale(by scalar: VectorSpaceScalar)
91+
}
92+
93+
extension VectorProtocol {
94+
public mutating func add(_ x: VectorSpaceScalar) {
95+
self = adding(x)
96+
}
97+
98+
public mutating func subtract(_ x: VectorSpaceScalar) {
99+
self = subtracting(x)
100+
}
101+
102+
public mutating func scale(by scalar: VectorSpaceScalar) {
103+
self = scaled(by: scalar)
104+
}
105+
}
106+
107+
extension VectorProtocol {
108+
public func adding(_ x: VectorSpaceScalar) -> Self {
109+
var out = self
110+
Self.visitChildren { kp, t in t.add(&out, kp, x) }
111+
return out
112+
}
113+
public func subtracting(_ x: VectorSpaceScalar) -> Self {
114+
var out = self
115+
Self.visitChildren { kp, t in t.subtract(&out, kp, x) }
116+
return out
117+
}
118+
public func scaled(by scalar: VectorSpaceScalar) -> Self {
119+
var out = self
120+
Self.visitChildren { kp, t in t.scale(&out, kp, by: scalar) }
121+
return out
122+
}
123+
}
124+
125+
extension Tensor: _VectorProtocol where Scalar: TensorFlowFloatingPoint {}
126+
extension Array.DifferentiableView: _VectorProtocol
127+
where Element: Differentiable & VectorProtocol {}
128+
#endif

0 commit comments

Comments
 (0)