Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 12e281d

Browse files
authored
Improve precondition error messages. (#673)
This PR improves error messages regarding ranks and shapes for a number of operations on Tensor, especially regarding mis-specified axes.
1 parent fcde599 commit 12e281d

File tree

2 files changed

+80
-34
lines changed

2 files changed

+80
-34
lines changed

Sources/TensorFlow/Operators/Basic.swift

+16-4
Original file line numberDiff line numberDiff line change
@@ -1334,8 +1334,12 @@ extension Tensor {
13341334

13351335
/// Returns `true` if the given scalar tensor is in the range `[-rank, rank)`.
13361336
@usableFromInline
1337-
internal func isAxisInRange(_ axis: Tensor<Int32>) -> Bool {
1338-
precondition(axis.rank == 0, "Axis must have rank 0.")
1337+
internal func isAxisInRange(
1338+
_ axis: Tensor<Int32>,
1339+
file: StaticString = #file,
1340+
line: UInt = #line
1341+
) -> Bool {
1342+
precondition(axis.rank == 0, "Axis must have rank 0.", file: file, line: line)
13391343
return areAxesInRange(axis.scalars)
13401344
}
13411345

@@ -1347,8 +1351,16 @@ extension Tensor {
13471351

13481352
/// Returns `true` if all scalars of the given 1-D tensor are in the range `[-rank, rank)`.
13491353
@usableFromInline
1350-
internal func areAxesInRange(_ axes: Tensor<Int32>) -> Bool {
1351-
precondition(axes.rank == 1, "Axes must have rank 1.")
1354+
internal func areAxesInRange(
1355+
_ axes: Tensor<Int32>,
1356+
file: StaticString = #file,
1357+
line: UInt = #line
1358+
) -> Bool {
1359+
precondition(
1360+
axes.rank < 2,
1361+
"Axes must have rank 0 or rank 1; axes has rank \(axes.rank) with values \(axes.scalars).",
1362+
file: file,
1363+
line: line)
13521364
return areAxesInRange(axes.scalars)
13531365
}
13541366
}

Sources/TensorFlow/Operators/Math.swift

+64-30
Original file line numberDiff line numberDiff line change
@@ -1693,7 +1693,7 @@ extension Tensor where Scalar == Bool {
16931693
/// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
16941694
@inlinable
16951695
public func all(squeezingAxes axes: Int...) -> Tensor {
1696-
precondition(areAxesInRange(axes), "All axes must be in the range `[-rank, rank)`.")
1696+
ensureValid(axes: axes)
16971697
let axes = axes.map(Int32.init)
16981698
return _Raw.all(self, reductionIndices: Tensor<Int32>(axes), keepDims: false)
16991699
}
@@ -1704,7 +1704,7 @@ extension Tensor where Scalar == Bool {
17041704
/// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
17051705
@inlinable
17061706
public func any(squeezingAxes axes: Int...) -> Tensor {
1707-
precondition(areAxesInRange(axes), "All axes must be in the range `[-rank, rank)`.")
1707+
ensureValid(axes: axes)
17081708
let axes = axes.map(Int32.init)
17091709
return _Raw.any(self, reductionIndices: Tensor<Int32>(axes), keepDims: false)
17101710
}
@@ -1715,7 +1715,7 @@ extension Tensor where Scalar == Bool {
17151715
/// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
17161716
@inlinable
17171717
public func all(alongAxes axes: Int...) -> Tensor {
1718-
precondition(areAxesInRange(axes), "All axes must be in the range `[-rank, rank)`.")
1718+
ensureValid(axes: axes)
17191719
let axes = axes.map(Int32.init)
17201720
return _Raw.all(self, reductionIndices: Tensor<Int32>(axes), keepDims: true)
17211721
}
@@ -1726,7 +1726,7 @@ extension Tensor where Scalar == Bool {
17261726
/// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
17271727
@inlinable
17281728
public func any(alongAxes axes: Int...) -> Tensor {
1729-
precondition(areAxesInRange(axes), "All axes must be in the range `[-rank, rank)`.")
1729+
ensureValid(axes: axes)
17301730
let axes = axes.map(Int32.init)
17311731
return _Raw.any(self, reductionIndices: Tensor<Int32>(axes), keepDims: true)
17321732
}
@@ -1757,7 +1757,7 @@ extension Tensor where Scalar: Numeric & Comparable {
17571757
@inlinable
17581758
@differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
17591759
public func max(squeezingAxes axes: Tensor<Int32>) -> Tensor {
1760-
precondition(areAxesInRange(axes), "All axes must be in the range `[-rank, rank)`.")
1760+
ensureValid(axes: axes)
17611761
return _Raw.max(self, reductionIndices: axes, keepDims: false)
17621762
}
17631763

@@ -1787,7 +1787,7 @@ extension Tensor where Scalar: Numeric & Comparable {
17871787
@inlinable
17881788
@differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
17891789
public func min(squeezingAxes axes: Tensor<Int32>) -> Tensor {
1790-
precondition(areAxesInRange(axes), "All axes must be in the range `[-rank, rank)`.")
1790+
ensureValid(axes: axes)
17911791
return _Raw.min(self, reductionIndices: axes, keepDims: false)
17921792
}
17931793

@@ -1817,7 +1817,7 @@ extension Tensor where Scalar: Numeric & Comparable {
18171817
/// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
18181818
@inlinable
18191819
public func argmax(squeezingAxis axis: Int) -> Tensor<Int32> {
1820-
precondition(isAxisInRange(axis), "Axis must be in the range `[-rank, rank)`.")
1820+
ensureValid(axes: [axis])
18211821
return _Raw.argMax(self, dimension: Int64(axis))
18221822
}
18231823

@@ -1827,7 +1827,7 @@ extension Tensor where Scalar: Numeric & Comparable {
18271827
/// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
18281828
@inlinable
18291829
public func argmin(squeezingAxis axis: Int) -> Tensor<Int32> {
1830-
precondition(isAxisInRange(axis), "Axis must be in the range `[-rank, rank)`.")
1830+
ensureValid(axes: [axis])
18311831
return _Raw.argMin(self, dimension: Tensor<Int32>(Int32(axis)))
18321832
}
18331833

@@ -1838,7 +1838,7 @@ extension Tensor where Scalar: Numeric & Comparable {
18381838
@inlinable
18391839
@differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
18401840
public func min(alongAxes axes: Tensor<Int32>) -> Tensor {
1841-
precondition(areAxesInRange(axes), "All axes must be in the range `[-rank, rank)`.")
1841+
ensureValid(axes: axes)
18421842
return _Raw.min(self, reductionIndices: axes, keepDims: true)
18431843
}
18441844

@@ -1871,7 +1871,7 @@ extension Tensor where Scalar: Numeric & Comparable {
18711871
@inlinable
18721872
@differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
18731873
public func max(alongAxes axes: Tensor<Int32>) -> Tensor {
1874-
precondition(areAxesInRange(axes), "All axes must be in the range `[-rank, rank)`.")
1874+
ensureValid(axes: axes)
18751875
return _Raw.max(self, reductionIndices: axes, keepDims: true)
18761876
}
18771877

@@ -2011,7 +2011,7 @@ extension Tensor where Scalar: Numeric {
20112011
@inlinable
20122012
@differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
20132013
public func sum(squeezingAxes axes: Tensor<Int32>) -> Tensor {
2014-
precondition(areAxesInRange(axes), "All axes must be in the range `[-rank, rank)`.")
2014+
ensureValid(axes: axes)
20152015
return _Raw.sum(self, reductionIndices: axes.scalars.map { Int64($0) }, keepDims: false)
20162016
}
20172017

@@ -2047,7 +2047,7 @@ extension Tensor where Scalar: Numeric {
20472047
@inlinable
20482048
@differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
20492049
public func sum(alongAxes axes: Tensor<Int32>) -> Tensor {
2050-
precondition(areAxesInRange(axes), "All axes must be in the range `[-rank, rank)`.")
2050+
ensureValid(axes: axes)
20512051
return _Raw.sum(self, reductionIndices: axes, keepDims: true)
20522052
}
20532053

@@ -2080,7 +2080,7 @@ extension Tensor where Scalar: Numeric {
20802080
@inlinable
20812081
@differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
20822082
public func product(squeezingAxes axes: Tensor<Int32>) -> Tensor {
2083-
precondition(areAxesInRange(axes), "All axes must be in the range `[-rank, rank)`.")
2083+
ensureValid(axes: axes)
20842084
return _Raw.prod(self, reductionIndices: axes, keepDims: false)
20852085
}
20862086

@@ -2118,7 +2118,7 @@ extension Tensor where Scalar: Numeric {
21182118
/// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
21192119
@inlinable
21202120
public func product(alongAxes axes: Tensor<Int32>) -> Tensor {
2121-
precondition(areAxesInRange(axes), "All axes must be in the range `[-rank, rank)`.")
2121+
ensureValid(axes: axes)
21222122
return _Raw.prod(self, reductionIndices: axes, keepDims: true)
21232123
}
21242124

@@ -2150,7 +2150,7 @@ extension Tensor where Scalar: Numeric {
21502150
@inlinable
21512151
@differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
21522152
public func mean(squeezingAxes axes: Tensor<Int32>) -> Tensor {
2153-
precondition(areAxesInRange(axes), "All axes must be in the range `[-rank, rank)`.")
2153+
ensureValid(axes: axes)
21542154
return _Raw.mean(self, reductionIndices: axes, keepDims: false)
21552155
}
21562156

@@ -2187,7 +2187,7 @@ extension Tensor where Scalar: Numeric {
21872187
@inlinable
21882188
@differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
21892189
public func mean(alongAxes axes: Tensor<Int32>) -> Tensor {
2190-
precondition(areAxesInRange(axes), "All axes must be in the range `[-rank, rank)`.")
2190+
ensureValid(axes: axes)
21912191
return _Raw.mean(self, reductionIndices: axes, keepDims: true)
21922192
}
21932193

@@ -2222,7 +2222,7 @@ extension Tensor where Scalar: Numeric {
22222222
@inlinable
22232223
@differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
22242224
public func variance(squeezingAxes axes: Tensor<Int32>) -> Tensor {
2225-
precondition(areAxesInRange(axes), "All axes must be in the range `[-rank, rank)`.")
2225+
ensureValid(axes: axes)
22262226
let squaredDiff = squaredDifference(self, mean(alongAxes: axes))
22272227
return squaredDiff.mean(squeezingAxes: axes)
22282228
}
@@ -2264,7 +2264,7 @@ extension Tensor where Scalar: Numeric {
22642264
@inlinable
22652265
@differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
22662266
public func variance(alongAxes axes: Tensor<Int32>) -> Tensor {
2267-
precondition(areAxesInRange(axes), "All axes must be in the range `[-rank, rank)`.")
2267+
ensureValid(axes: axes)
22682268
let squaredDiff = squaredDifference(self, mean(alongAxes: axes))
22692269
return squaredDiff.mean(alongAxes: axes)
22702270
}
@@ -2362,7 +2362,7 @@ extension Tensor where Scalar: Numeric {
23622362
exclusive: Bool = false,
23632363
reverse: Bool = false
23642364
) -> Tensor {
2365-
precondition(isAxisInRange(axis), "Axis must be in the range `[-rank, rank)`.")
2365+
ensureValid(axes: axis)
23662366
return _Raw.cumsum(self, axis: axis, exclusive: exclusive, reverse: reverse)
23672367
}
23682368

@@ -2437,7 +2437,7 @@ extension Tensor where Scalar: Numeric {
24372437
exclusive: Bool = false,
24382438
reverse: Bool = false
24392439
) -> Tensor {
2440-
precondition(isAxisInRange(axis), "Axis must be in the range `[-rank, rank)`.")
2440+
ensureValid(axes: axis)
24412441
return _Raw.cumprod(self, axis: axis, exclusive: exclusive, reverse: reverse)
24422442
}
24432443
}
@@ -2640,7 +2640,7 @@ extension Tensor where Scalar: TensorFlowFloatingPoint {
26402640
@inlinable
26412641
@differentiable(wrt: self)
26422642
public func standardDeviation(squeezingAxes axes: Tensor<Int32>) -> Tensor {
2643-
precondition(areAxesInRange(axes), "All axes must be in the range `[-rank, rank)`.")
2643+
ensureValid(axes: axes)
26442644
return Tensor.sqrt(variance(squeezingAxes: axes))
26452645
}
26462646

@@ -2652,7 +2652,7 @@ extension Tensor where Scalar: TensorFlowFloatingPoint {
26522652
@inlinable
26532653
@differentiable(wrt: self)
26542654
public func standardDeviation(squeezingAxes axes: [Int]) -> Tensor {
2655-
precondition(areAxesInRange(axes), "All axes must be in the range `[-rank, rank)`.")
2655+
ensureValid(axes: axes)
26562656
return Tensor.sqrt(variance(squeezingAxes: axes))
26572657
}
26582658

@@ -2686,7 +2686,7 @@ extension Tensor where Scalar: TensorFlowFloatingPoint {
26862686
@inlinable
26872687
@differentiable(wrt: self)
26882688
public func standardDeviation(alongAxes axes: Tensor<Int32>) -> Tensor {
2689-
precondition(areAxesInRange(axes), "All axes must be in the range `[-rank, rank)`.")
2689+
ensureValid(axes: axes)
26902690
return Tensor.sqrt(variance(alongAxes: axes))
26912691
}
26922692

@@ -2711,7 +2711,7 @@ extension Tensor where Scalar: TensorFlowFloatingPoint {
27112711
@inlinable
27122712
@differentiable(wrt: self)
27132713
public func standardDeviation(alongAxes axes: Int...) -> Tensor {
2714-
precondition(areAxesInRange(axes), "All axes must be in the range `[-rank, rank)`.")
2714+
ensureValid(axes: axes)
27152715
return Tensor.sqrt(variance(alongAxes: axes))
27162716
}
27172717

@@ -2726,7 +2726,7 @@ extension Tensor where Scalar: TensorFlowFloatingPoint {
27262726
@inlinable
27272727
@differentiable(wrt: self)
27282728
public func logSumExp(squeezingAxes axes: Tensor<Int32>) -> Tensor {
2729-
precondition(areAxesInRange(axes), "All axes must be in the range `[-rank, rank)`.")
2729+
ensureValid(axes: axes)
27302730
let rawMax = max(alongAxes: axes)
27312731
let offset = withoutDerivative(at: rawMax) { rawMax in
27322732
Tensor<Scalar>(zerosLike: rawMax).replacing(
@@ -2791,7 +2791,7 @@ extension Tensor where Scalar: TensorFlowFloatingPoint {
27912791
@inlinable
27922792
@differentiable(wrt: self)
27932793
public func logSumExp(alongAxes axes: Tensor<Int32>) -> Tensor {
2794-
precondition(areAxesInRange(axes), "All axes must be in the range `[-rank, rank)`.")
2794+
ensureValid(axes: axes)
27952795
let rawMax = max(alongAxes: axes)
27962796
let offset = withoutDerivative(at: rawMax) { rawMax in
27972797
Tensor<Scalar>(zerosLike: rawMax).replacing(
@@ -2858,7 +2858,7 @@ extension Tensor where Scalar: TensorFlowFloatingPoint {
28582858
@inlinable
28592859
@differentiable(wrt: self)
28602860
public func moments(squeezingAxes axes: Tensor<Int32>) -> Moments<Scalar> {
2861-
precondition(areAxesInRange(axes), "All axes must be in the range `[-rank, rank)`.")
2861+
ensureValid(axes: axes)
28622862
let mean = self.mean(alongAxes: axes)
28632863
let variance = squaredDifference(self, mean).mean(squeezingAxes: axes)
28642864
return Moments(
@@ -2876,7 +2876,7 @@ extension Tensor where Scalar: TensorFlowFloatingPoint {
28762876
@inlinable
28772877
@differentiable(wrt: self)
28782878
public func moments(squeezingAxes axes: [Int]) -> Moments<Scalar> {
2879-
precondition(areAxesInRange(axes), "All axes must be in the range `[-rank, rank)`.")
2879+
ensureValid(axes: axes)
28802880
let mean = self.mean(squeezingAxes: axes)
28812881
let variance = squaredDifference(self, mean).mean(squeezingAxes: axes)
28822882
return Moments(mean: mean, variance: variance)
@@ -2909,7 +2909,7 @@ extension Tensor where Scalar: TensorFlowFloatingPoint {
29092909
@inlinable
29102910
@differentiable(wrt: self)
29112911
public func moments(alongAxes axes: Tensor<Int32>) -> Moments<Scalar> {
2912-
precondition(areAxesInRange(axes), "All axes must be in the range `[-rank, rank)`.")
2912+
ensureValid(axes: axes)
29132913
let mean = self.mean(alongAxes: axes)
29142914
let variance = squaredDifference(self, mean).mean(alongAxes: axes)
29152915
return Moments<Scalar>(mean: mean, variance: variance)
@@ -2923,7 +2923,7 @@ extension Tensor where Scalar: TensorFlowFloatingPoint {
29232923
@inlinable
29242924
@differentiable(wrt: self)
29252925
public func moments(alongAxes axes: [Int]) -> Moments<Scalar> {
2926-
precondition(areAxesInRange(axes), "All axes must be in the range `[-rank, rank)`.")
2926+
ensureValid(axes: axes)
29272927
let mean = self.mean(alongAxes: axes)
29282928
let variance = squaredDifference(self, mean).mean(alongAxes: axes)
29292929
return Moments<Scalar>(mean: mean, variance: variance)
@@ -3018,3 +3018,37 @@ extension Tensor where Scalar: Numeric {
30183018
matmul(lhs, rhs)
30193019
}
30203020
}
3021+
3022+
//===------------------------------------------------------------------------------------------===//
3023+
// Precondition helpers.
3024+
//===------------------------------------------------------------------------------------------===//
3025+
3026+
internal extension Tensor {
3027+
@usableFromInline
3028+
func ensureValid(
3029+
axes: Tensor<Int32>,
3030+
function: StaticString = #function,
3031+
file: StaticString = #file,
3032+
line: UInt = #line
3033+
) {
3034+
precondition(
3035+
areAxesInRange(axes, file: file, line: line),
3036+
"All axes must be in the range `[-rank, rank)` when calling \(function) (rank is: \(rank), axes: \(axes))",
3037+
file: file,
3038+
line: line)
3039+
}
3040+
3041+
@usableFromInline
3042+
func ensureValid(
3043+
axes: [Int],
3044+
function: StaticString = #function,
3045+
file: StaticString = #file,
3046+
line: UInt = #line
3047+
) {
3048+
precondition(
3049+
areAxesInRange(axes),
3050+
"All axes must be in the range `[-rank, rank)` when calling \(function) (rank is: \(rank), axes: \(axes))",
3051+
file: file,
3052+
line: line)
3053+
}
3054+
}

0 commit comments

Comments
 (0)