@@ -1693,7 +1693,7 @@ extension Tensor where Scalar == Bool {
1693
1693
/// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
1694
1694
@inlinable
1695
1695
public func all( squeezingAxes axes: Int ... ) -> Tensor {
1696
- precondition ( areAxesInRange ( axes) , " All axes must be in the range `[-rank, rank)`. " )
1696
+ ensureValid ( axes: axes)
1697
1697
let axes = axes. map ( Int32 . init)
1698
1698
return _Raw. all ( self , reductionIndices: Tensor < Int32 > ( axes) , keepDims: false )
1699
1699
}
@@ -1704,7 +1704,7 @@ extension Tensor where Scalar == Bool {
1704
1704
/// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
1705
1705
@inlinable
1706
1706
public func any( squeezingAxes axes: Int ... ) -> Tensor {
1707
- precondition ( areAxesInRange ( axes) , " All axes must be in the range `[-rank, rank)`. " )
1707
+ ensureValid ( axes: axes)
1708
1708
let axes = axes. map ( Int32 . init)
1709
1709
return _Raw. any ( self , reductionIndices: Tensor < Int32 > ( axes) , keepDims: false )
1710
1710
}
@@ -1715,7 +1715,7 @@ extension Tensor where Scalar == Bool {
1715
1715
/// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
1716
1716
@inlinable
1717
1717
public func all( alongAxes axes: Int ... ) -> Tensor {
1718
- precondition ( areAxesInRange ( axes) , " All axes must be in the range `[-rank, rank)`. " )
1718
+ ensureValid ( axes: axes)
1719
1719
let axes = axes. map ( Int32 . init)
1720
1720
return _Raw. all ( self , reductionIndices: Tensor < Int32 > ( axes) , keepDims: true )
1721
1721
}
@@ -1726,7 +1726,7 @@ extension Tensor where Scalar == Bool {
1726
1726
/// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
1727
1727
@inlinable
1728
1728
public func any( alongAxes axes: Int ... ) -> Tensor {
1729
- precondition ( areAxesInRange ( axes) , " All axes must be in the range `[-rank, rank)`. " )
1729
+ ensureValid ( axes: axes)
1730
1730
let axes = axes. map ( Int32 . init)
1731
1731
return _Raw. any ( self , reductionIndices: Tensor < Int32 > ( axes) , keepDims: true )
1732
1732
}
@@ -1757,7 +1757,7 @@ extension Tensor where Scalar: Numeric & Comparable {
1757
1757
@inlinable
1758
1758
@differentiable ( wrt: self where Scalar: TensorFlowFloatingPoint)
1759
1759
public func max( squeezingAxes axes: Tensor < Int32 > ) -> Tensor {
1760
- precondition ( areAxesInRange ( axes) , " All axes must be in the range `[-rank, rank)`. " )
1760
+ ensureValid ( axes: axes)
1761
1761
return _Raw. max ( self , reductionIndices: axes, keepDims: false )
1762
1762
}
1763
1763
@@ -1787,7 +1787,7 @@ extension Tensor where Scalar: Numeric & Comparable {
1787
1787
@inlinable
1788
1788
@differentiable ( wrt: self where Scalar: TensorFlowFloatingPoint)
1789
1789
public func min( squeezingAxes axes: Tensor < Int32 > ) -> Tensor {
1790
- precondition ( areAxesInRange ( axes) , " All axes must be in the range `[-rank, rank)`. " )
1790
+ ensureValid ( axes: axes)
1791
1791
return _Raw. min ( self , reductionIndices: axes, keepDims: false )
1792
1792
}
1793
1793
@@ -1817,7 +1817,7 @@ extension Tensor where Scalar: Numeric & Comparable {
1817
1817
/// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
1818
1818
@inlinable
1819
1819
public func argmax( squeezingAxis axis: Int ) -> Tensor < Int32 > {
1820
- precondition ( isAxisInRange ( axis ) , " Axis must be in the range `[-rank, rank)`. " )
1820
+ ensureValid ( axes : [ axis ] )
1821
1821
return _Raw. argMax ( self , dimension: Int64 ( axis) )
1822
1822
}
1823
1823
@@ -1827,7 +1827,7 @@ extension Tensor where Scalar: Numeric & Comparable {
1827
1827
/// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
1828
1828
@inlinable
1829
1829
public func argmin( squeezingAxis axis: Int ) -> Tensor < Int32 > {
1830
- precondition ( isAxisInRange ( axis ) , " Axis must be in the range `[-rank, rank)`. " )
1830
+ ensureValid ( axes : [ axis ] )
1831
1831
return _Raw. argMin ( self , dimension: Tensor < Int32 > ( Int32 ( axis) ) )
1832
1832
}
1833
1833
@@ -1838,7 +1838,7 @@ extension Tensor where Scalar: Numeric & Comparable {
1838
1838
@inlinable
1839
1839
@differentiable ( wrt: self where Scalar: TensorFlowFloatingPoint)
1840
1840
public func min( alongAxes axes: Tensor < Int32 > ) -> Tensor {
1841
- precondition ( areAxesInRange ( axes) , " All axes must be in the range `[-rank, rank)`. " )
1841
+ ensureValid ( axes: axes)
1842
1842
return _Raw. min ( self , reductionIndices: axes, keepDims: true )
1843
1843
}
1844
1844
@@ -1871,7 +1871,7 @@ extension Tensor where Scalar: Numeric & Comparable {
1871
1871
@inlinable
1872
1872
@differentiable ( wrt: self where Scalar: TensorFlowFloatingPoint)
1873
1873
public func max( alongAxes axes: Tensor < Int32 > ) -> Tensor {
1874
- precondition ( areAxesInRange ( axes) , " All axes must be in the range `[-rank, rank)`. " )
1874
+ ensureValid ( axes: axes)
1875
1875
return _Raw. max ( self , reductionIndices: axes, keepDims: true )
1876
1876
}
1877
1877
@@ -2011,7 +2011,7 @@ extension Tensor where Scalar: Numeric {
2011
2011
@inlinable
2012
2012
@differentiable ( wrt: self where Scalar: TensorFlowFloatingPoint)
2013
2013
public func sum( squeezingAxes axes: Tensor < Int32 > ) -> Tensor {
2014
- precondition ( areAxesInRange ( axes) , " All axes must be in the range `[-rank, rank)`. " )
2014
+ ensureValid ( axes: axes)
2015
2015
return _Raw. sum ( self , reductionIndices: axes. scalars. map { Int64 ( $0) } , keepDims: false )
2016
2016
}
2017
2017
@@ -2047,7 +2047,7 @@ extension Tensor where Scalar: Numeric {
2047
2047
@inlinable
2048
2048
@differentiable ( wrt: self where Scalar: TensorFlowFloatingPoint)
2049
2049
public func sum( alongAxes axes: Tensor < Int32 > ) -> Tensor {
2050
- precondition ( areAxesInRange ( axes) , " All axes must be in the range `[-rank, rank)`. " )
2050
+ ensureValid ( axes: axes)
2051
2051
return _Raw. sum ( self , reductionIndices: axes, keepDims: true )
2052
2052
}
2053
2053
@@ -2080,7 +2080,7 @@ extension Tensor where Scalar: Numeric {
2080
2080
@inlinable
2081
2081
@differentiable ( wrt: self where Scalar: TensorFlowFloatingPoint)
2082
2082
public func product( squeezingAxes axes: Tensor < Int32 > ) -> Tensor {
2083
- precondition ( areAxesInRange ( axes) , " All axes must be in the range `[-rank, rank)`. " )
2083
+ ensureValid ( axes: axes)
2084
2084
return _Raw. prod ( self , reductionIndices: axes, keepDims: false )
2085
2085
}
2086
2086
@@ -2118,7 +2118,7 @@ extension Tensor where Scalar: Numeric {
2118
2118
/// - Precondition: Each value in `axes` must be in the range `-rank..<rank`.
2119
2119
@inlinable
2120
2120
public func product( alongAxes axes: Tensor < Int32 > ) -> Tensor {
2121
- precondition ( areAxesInRange ( axes) , " All axes must be in the range `[-rank, rank)`. " )
2121
+ ensureValid ( axes: axes)
2122
2122
return _Raw. prod ( self , reductionIndices: axes, keepDims: true )
2123
2123
}
2124
2124
@@ -2150,7 +2150,7 @@ extension Tensor where Scalar: Numeric {
2150
2150
@inlinable
2151
2151
@differentiable ( wrt: self where Scalar: TensorFlowFloatingPoint)
2152
2152
public func mean( squeezingAxes axes: Tensor < Int32 > ) -> Tensor {
2153
- precondition ( areAxesInRange ( axes) , " All axes must be in the range `[-rank, rank)`. " )
2153
+ ensureValid ( axes: axes)
2154
2154
return _Raw. mean ( self , reductionIndices: axes, keepDims: false )
2155
2155
}
2156
2156
@@ -2187,7 +2187,7 @@ extension Tensor where Scalar: Numeric {
2187
2187
@inlinable
2188
2188
@differentiable ( wrt: self where Scalar: TensorFlowFloatingPoint)
2189
2189
public func mean( alongAxes axes: Tensor < Int32 > ) -> Tensor {
2190
- precondition ( areAxesInRange ( axes) , " All axes must be in the range `[-rank, rank)`. " )
2190
+ ensureValid ( axes: axes)
2191
2191
return _Raw. mean ( self , reductionIndices: axes, keepDims: true )
2192
2192
}
2193
2193
@@ -2222,7 +2222,7 @@ extension Tensor where Scalar: Numeric {
2222
2222
@inlinable
2223
2223
@differentiable ( wrt: self where Scalar: TensorFlowFloatingPoint)
2224
2224
public func variance( squeezingAxes axes: Tensor < Int32 > ) -> Tensor {
2225
- precondition ( areAxesInRange ( axes) , " All axes must be in the range `[-rank, rank)`. " )
2225
+ ensureValid ( axes: axes)
2226
2226
let squaredDiff = squaredDifference ( self , mean ( alongAxes: axes) )
2227
2227
return squaredDiff. mean ( squeezingAxes: axes)
2228
2228
}
@@ -2264,7 +2264,7 @@ extension Tensor where Scalar: Numeric {
2264
2264
@inlinable
2265
2265
@differentiable ( wrt: self where Scalar: TensorFlowFloatingPoint)
2266
2266
public func variance( alongAxes axes: Tensor < Int32 > ) -> Tensor {
2267
- precondition ( areAxesInRange ( axes) , " All axes must be in the range `[-rank, rank)`. " )
2267
+ ensureValid ( axes: axes)
2268
2268
let squaredDiff = squaredDifference ( self , mean ( alongAxes: axes) )
2269
2269
return squaredDiff. mean ( alongAxes: axes)
2270
2270
}
@@ -2362,7 +2362,7 @@ extension Tensor where Scalar: Numeric {
2362
2362
exclusive: Bool = false ,
2363
2363
reverse: Bool = false
2364
2364
) -> Tensor {
2365
- precondition ( isAxisInRange ( axis ) , " Axis must be in the range `[-rank, rank)`. " )
2365
+ ensureValid ( axes : axis )
2366
2366
return _Raw. cumsum ( self , axis: axis, exclusive: exclusive, reverse: reverse)
2367
2367
}
2368
2368
@@ -2437,7 +2437,7 @@ extension Tensor where Scalar: Numeric {
2437
2437
exclusive: Bool = false ,
2438
2438
reverse: Bool = false
2439
2439
) -> Tensor {
2440
- precondition ( isAxisInRange ( axis ) , " Axis must be in the range `[-rank, rank)`. " )
2440
+ ensureValid ( axes : axis )
2441
2441
return _Raw. cumprod ( self , axis: axis, exclusive: exclusive, reverse: reverse)
2442
2442
}
2443
2443
}
@@ -2640,7 +2640,7 @@ extension Tensor where Scalar: TensorFlowFloatingPoint {
2640
2640
@inlinable
2641
2641
@differentiable ( wrt: self )
2642
2642
public func standardDeviation( squeezingAxes axes: Tensor < Int32 > ) -> Tensor {
2643
- precondition ( areAxesInRange ( axes) , " All axes must be in the range `[-rank, rank)`. " )
2643
+ ensureValid ( axes: axes)
2644
2644
return Tensor . sqrt ( variance ( squeezingAxes: axes) )
2645
2645
}
2646
2646
@@ -2652,7 +2652,7 @@ extension Tensor where Scalar: TensorFlowFloatingPoint {
2652
2652
@inlinable
2653
2653
@differentiable ( wrt: self )
2654
2654
public func standardDeviation( squeezingAxes axes: [ Int ] ) -> Tensor {
2655
- precondition ( areAxesInRange ( axes) , " All axes must be in the range `[-rank, rank)`. " )
2655
+ ensureValid ( axes: axes)
2656
2656
return Tensor . sqrt ( variance ( squeezingAxes: axes) )
2657
2657
}
2658
2658
@@ -2686,7 +2686,7 @@ extension Tensor where Scalar: TensorFlowFloatingPoint {
2686
2686
@inlinable
2687
2687
@differentiable ( wrt: self )
2688
2688
public func standardDeviation( alongAxes axes: Tensor < Int32 > ) -> Tensor {
2689
- precondition ( areAxesInRange ( axes) , " All axes must be in the range `[-rank, rank)`. " )
2689
+ ensureValid ( axes: axes)
2690
2690
return Tensor . sqrt ( variance ( alongAxes: axes) )
2691
2691
}
2692
2692
@@ -2711,7 +2711,7 @@ extension Tensor where Scalar: TensorFlowFloatingPoint {
2711
2711
@inlinable
2712
2712
@differentiable ( wrt: self )
2713
2713
public func standardDeviation( alongAxes axes: Int ... ) -> Tensor {
2714
- precondition ( areAxesInRange ( axes) , " All axes must be in the range `[-rank, rank)`. " )
2714
+ ensureValid ( axes: axes)
2715
2715
return Tensor . sqrt ( variance ( alongAxes: axes) )
2716
2716
}
2717
2717
@@ -2726,7 +2726,7 @@ extension Tensor where Scalar: TensorFlowFloatingPoint {
2726
2726
@inlinable
2727
2727
@differentiable ( wrt: self )
2728
2728
public func logSumExp( squeezingAxes axes: Tensor < Int32 > ) -> Tensor {
2729
- precondition ( areAxesInRange ( axes) , " All axes must be in the range `[-rank, rank)`. " )
2729
+ ensureValid ( axes: axes)
2730
2730
let rawMax = max ( alongAxes: axes)
2731
2731
let offset = withoutDerivative ( at: rawMax) { rawMax in
2732
2732
Tensor < Scalar > ( zerosLike: rawMax) . replacing (
@@ -2791,7 +2791,7 @@ extension Tensor where Scalar: TensorFlowFloatingPoint {
2791
2791
@inlinable
2792
2792
@differentiable ( wrt: self )
2793
2793
public func logSumExp( alongAxes axes: Tensor < Int32 > ) -> Tensor {
2794
- precondition ( areAxesInRange ( axes) , " All axes must be in the range `[-rank, rank)`. " )
2794
+ ensureValid ( axes: axes)
2795
2795
let rawMax = max ( alongAxes: axes)
2796
2796
let offset = withoutDerivative ( at: rawMax) { rawMax in
2797
2797
Tensor < Scalar > ( zerosLike: rawMax) . replacing (
@@ -2858,7 +2858,7 @@ extension Tensor where Scalar: TensorFlowFloatingPoint {
2858
2858
@inlinable
2859
2859
@differentiable ( wrt: self )
2860
2860
public func moments( squeezingAxes axes: Tensor < Int32 > ) -> Moments < Scalar > {
2861
- precondition ( areAxesInRange ( axes) , " All axes must be in the range `[-rank, rank)`. " )
2861
+ ensureValid ( axes: axes)
2862
2862
let mean = self . mean ( alongAxes: axes)
2863
2863
let variance = squaredDifference ( self , mean) . mean ( squeezingAxes: axes)
2864
2864
return Moments (
@@ -2876,7 +2876,7 @@ extension Tensor where Scalar: TensorFlowFloatingPoint {
2876
2876
@inlinable
2877
2877
@differentiable ( wrt: self )
2878
2878
public func moments( squeezingAxes axes: [ Int ] ) -> Moments < Scalar > {
2879
- precondition ( areAxesInRange ( axes) , " All axes must be in the range `[-rank, rank)`. " )
2879
+ ensureValid ( axes: axes)
2880
2880
let mean = self . mean ( squeezingAxes: axes)
2881
2881
let variance = squaredDifference ( self , mean) . mean ( squeezingAxes: axes)
2882
2882
return Moments ( mean: mean, variance: variance)
@@ -2909,7 +2909,7 @@ extension Tensor where Scalar: TensorFlowFloatingPoint {
2909
2909
@inlinable
2910
2910
@differentiable ( wrt: self )
2911
2911
public func moments( alongAxes axes: Tensor < Int32 > ) -> Moments < Scalar > {
2912
- precondition ( areAxesInRange ( axes) , " All axes must be in the range `[-rank, rank)`. " )
2912
+ ensureValid ( axes: axes)
2913
2913
let mean = self . mean ( alongAxes: axes)
2914
2914
let variance = squaredDifference ( self , mean) . mean ( alongAxes: axes)
2915
2915
return Moments < Scalar > ( mean: mean, variance: variance)
@@ -2923,7 +2923,7 @@ extension Tensor where Scalar: TensorFlowFloatingPoint {
2923
2923
@inlinable
2924
2924
@differentiable ( wrt: self )
2925
2925
public func moments( alongAxes axes: [ Int ] ) -> Moments < Scalar > {
2926
- precondition ( areAxesInRange ( axes) , " All axes must be in the range `[-rank, rank)`. " )
2926
+ ensureValid ( axes: axes)
2927
2927
let mean = self . mean ( alongAxes: axes)
2928
2928
let variance = squaredDifference ( self , mean) . mean ( alongAxes: axes)
2929
2929
return Moments < Scalar > ( mean: mean, variance: variance)
@@ -3018,3 +3018,37 @@ extension Tensor where Scalar: Numeric {
3018
3018
matmul ( lhs, rhs)
3019
3019
}
3020
3020
}
3021
+
3022
+ //===------------------------------------------------------------------------------------------===//
3023
+ // Precondition helpers.
3024
+ //===------------------------------------------------------------------------------------------===//
3025
+
3026
+ internal extension Tensor {
3027
+ @usableFromInline
3028
+ func ensureValid(
3029
+ axes: Tensor < Int32 > ,
3030
+ function: StaticString = #function,
3031
+ file: StaticString = #file,
3032
+ line: UInt = #line
3033
+ ) {
3034
+ precondition (
3035
+ areAxesInRange ( axes, file: file, line: line) ,
3036
+ " All axes must be in the range `[-rank, rank)` when calling \( function) (rank is: \( rank) , axes: \( axes) ) " ,
3037
+ file: file,
3038
+ line: line)
3039
+ }
3040
+
3041
+ @usableFromInline
3042
+ func ensureValid(
3043
+ axes: [ Int ] ,
3044
+ function: StaticString = #function,
3045
+ file: StaticString = #file,
3046
+ line: UInt = #line
3047
+ ) {
3048
+ precondition (
3049
+ areAxesInRange ( axes) ,
3050
+ " All axes must be in the range `[-rank, rank)` when calling \( function) (rank is: \( rank) , axes: \( axes) ) " ,
3051
+ file: file,
3052
+ line: line)
3053
+ }
3054
+ }
0 commit comments