Skip to content

Commit 4c7059f

Browse files
authored
[stdlib][SR-13883] Avoid advancing past representable bounds when striding (#34860)
* [stdlib][SR-13883] Avoid advancing past representable bounds when striding. * [stdlib] Expand a test and add a comment to ensure correct floating-point stride bounds checking. * [stdlib][NFC] Clarify a comment in a test. * [stdlib][NFC] Adjust copyright notices, clarify comments, delete '-swift-version=3' for tests. * [stdlib] Add implementations for fixed-width integer strides for performance. * [stdlib] Document `Strideable._step` and modify overflow checking behavior of `Stride*Iterator`. * [stdlib] Address reviewer comments, postpone documentation changes * [stdlib][NFC] Update documentation for '_step(after:from:by:)' * [stdlib][NFC] Use 'nil' instead of an arbitrary value for integer striding '_step' index
1 parent 031ffc0 commit 4c7059f

File tree

3 files changed

+123
-9
lines changed

3 files changed

+123
-9
lines changed

stdlib/public/core/Stride.swift

+96-7
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
//
33
// This source file is part of the Swift.org open source project
44
//
5-
// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors
5+
// Copyright (c) 2014 - 2021 Apple Inc. and the Swift project authors
66
// Licensed under Apache License v2.0 with Runtime Library Exception
77
//
88
// See https://swift.org/LICENSE.txt for license information
@@ -134,15 +134,68 @@ public protocol Strideable: Comparable {
134134
/// If this type's `Stride` type conforms to `BinaryInteger`, then for a
135135
/// value `x`, a distance `n`, and a value `y = x.advanced(by: n)`,
136136
/// `x.distance(to: y) == n`. Using this method with types that have a
137-
/// noninteger `Stride` may result in an approximation.
137+
/// noninteger `Stride` may result in an approximation. If the result of
138+
/// advancing by `n` is not representable as a value of this type, then a
139+
/// runtime error may occur.
138140
///
139141
/// - Parameter n: The distance to advance this value.
140142
/// - Returns: A value that is offset from this value by `n`.
141143
///
142144
/// - Complexity: O(1)
143145
func advanced(by n: Stride) -> Self
144146

145-
/// `_step` is an implementation detail of Strideable; do not use it directly.
147+
/// Returns the next result of striding by a specified distance.
148+
///
149+
/// This method is an implementation detail of `Strideable`; do not call it
150+
/// directly.
151+
///
152+
/// While striding, `_step(after:from:by:)` is called at each step to
153+
/// determine the next result. At the first step, the value of `current` is
154+
/// `(index: 0, value: start)`. At each subsequent step, the value of
155+
/// `current` is the result returned by this method in the immediately
156+
/// preceding step.
157+
///
158+
/// If the result of advancing by a given `distance` is not representable as a
159+
/// value of this type, then a runtime error may occur.
160+
///
161+
/// Implementing `_step(after:from:by:)` to Customize Striding Behavior
162+
/// ===================================================================
163+
///
164+
/// The default implementation of this method calls `advanced(by:)` to offset
165+
/// `current.value` by a specified `distance`. No attempt is made to count the
166+
/// number of prior steps, and the result's `index` is always `nil`.
167+
///
168+
/// To avoid incurring runtime errors that arise from advancing past
169+
/// representable bounds, a conforming type can signal that the result of
170+
/// advancing by a given `distance` is not representable by using `Int.min` as
171+
/// a sentinel value for the result's `index`. In that case, the result's
172+
/// `value` must be either the minimum representable value of this type if
173+
/// `distance` is less than zero or the maximum representable value of this
174+
/// type otherwise. Fixed-width integer types make use of arithmetic
175+
/// operations reporting overflow to implement this customization.
176+
///
177+
/// A conforming type may use any positive value for the result's `index` as
178+
/// an opaque state that is private to that type. For example, floating-point
179+
/// types increment `index` with each step so that the corresponding `value`
180+
/// can be computed by multiplying the number of steps by the specified
181+
/// `distance`. Serially calling `advanced(by:)` would accumulate
182+
/// floating-point rounding error at each step, which is avoided by this
183+
/// customization.
184+
///
185+
/// - Parameters:
186+
/// - current: The result returned by this method in the immediately
187+
/// preceding step while striding, or `(index: 0, value: start)` if there
188+
/// have been no preceding steps.
189+
/// - start: The starting value used for the striding sequence.
190+
/// - distance: The amount to step by with each iteration of the striding
191+
/// sequence.
192+
/// - Returns: A tuple of `index` and `value`; `index` may be `nil`, any
193+
/// positive value as an opaque state private to the conforming type, or
194+
/// `Int.min` to signal that the notional result of advancing by `distance`
195+
/// is unrepresentable, and `value` is the next result after `current.value`
196+
/// while striding from `start` by `distance`.
197+
///
198+
/// - Complexity: O(1)
146199
static func _step(
147200
after current: (index: Int?, value: Self),
148201
from start: Self, by distance: Self.Stride
@@ -171,6 +224,39 @@ extension Strideable {
171224
}
172225
}
173226

227+
extension Strideable where Self: FixedWidthInteger & SignedInteger {
228+
@_alwaysEmitIntoClient
229+
public static func _step(
230+
after current: (index: Int?, value: Self),
231+
from start: Self, by distance: Self.Stride
232+
) -> (index: Int?, value: Self) {
233+
let value = current.value
234+
let (partialValue, overflow) =
235+
Self.bitWidth >= Self.Stride.bitWidth ||
236+
(value < (0 as Self)) == (distance < (0 as Self.Stride))
237+
? value.addingReportingOverflow(Self(distance))
238+
: (Self(Self.Stride(value) + distance), false)
239+
return overflow
240+
? (.min, distance < (0 as Self.Stride) ? .min : .max)
241+
: (nil, partialValue)
242+
}
243+
}
244+
245+
extension Strideable where Self: FixedWidthInteger & UnsignedInteger {
246+
@_alwaysEmitIntoClient
247+
public static func _step(
248+
after current: (index: Int?, value: Self),
249+
from start: Self, by distance: Self.Stride
250+
) -> (index: Int?, value: Self) {
251+
let (partialValue, overflow) = distance < (0 as Self.Stride)
252+
? current.value.subtractingReportingOverflow(Self(-distance))
253+
: current.value.addingReportingOverflow(Self(distance))
254+
return overflow
255+
? (.min, distance < (0 as Self.Stride) ? .min : .max)
256+
: (nil, partialValue)
257+
}
258+
}
259+
174260
extension Strideable where Stride: FloatingPoint {
175261
@inlinable // protocol-only
176262
public static func _step(
@@ -439,10 +525,13 @@ extension StrideThroughIterator: IteratorProtocol {
439525
public mutating func next() -> Element? {
440526
let result = _current.value
441527
if _stride > 0 ? result >= _end : result <= _end {
442-
// This check is needed because if we just changed the above operators
443-
// to > and <, respectively, we might advance current past the end
444-
// and throw it out of bounds (e.g. above Int.max) unnecessarily.
445-
if result == _end && !_didReturnEnd {
528+
// Note the `>=` and `<=` operators above. When `result == _end`, the
529+
// following check is needed to prevent advancing `_current` past the
530+
// representable bounds of the `Strideable` type unnecessarily.
531+
//
532+
// If the `Strideable` type is a fixed-width integer, overflowed results
533+
// are represented using a sentinel value for `_current.index`, `Int.min`.
534+
if result == _end && !_didReturnEnd && _current.index != .min {
446535
_didReturnEnd = true
447536
return result
448537
}

test/stdlib/Strideable.swift

+23-2
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,14 @@
22
//
33
// This source file is part of the Swift.org open source project
44
//
5-
// Copyright (c) 2014 - 2017 Apple Inc. and the Swift project authors
5+
// Copyright (c) 2014 - 2020 Apple Inc. and the Swift project authors
66
// Licensed under Apache License v2.0 with Runtime Library Exception
77
//
88
// See https://swift.org/LICENSE.txt for license information
99
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
1010
//
1111
//===----------------------------------------------------------------------===//
12-
// RUN: %target-run-simple-swift -swift-version=3
12+
// RUN: %target-run-simple-swift
1313
// REQUIRES: executable_test
1414
//
1515

@@ -202,6 +202,27 @@ StrideTestSuite.test("FloatingPointStride/rounding error") {
202202
expectEqual(7, c.count)
203203
expectEqual(1 as Double, c.last)
204204
}
205+
206+
if (1 as Float).addingProduct(0.9, 6) == 6.3999996 {
207+
let d = Array(stride(from: 1 as Float, through: 6.3999996, by: 0.9))
208+
expectEqual(7, d.count)
209+
// The reason that `d` has seven elements and not six is that the fused
210+
// multiply-add operation `(1 as Float).addingProduct(0.9, 6)` gives the
211+
// result `6.3999996`. This is nonetheless the desired behavior because
212+
// avoiding error accumulation and intermediate rounding error wherever
213+
// possible will produce better results more often than not (see SR-6377).
214+
//
215+
// If checking of end bounds has been inadvertently modified such that we're
216+
// computing the distance from the penultimate element to the end (in this
217+
// case, `6.3999996 - (1 as Float).addingProduct(0.9, 5)`), then the last
218+
// element will be omitted here.
219+
//
220+
// Therefore, if the test has failed, there may have been a regression in
221+
// the bounds-checking logic of `Stride*Iterator`. Restore the expected
222+
// behavior here by ensuring that floating-point strides are opted out of
223+
// any bounds checking that performs arithmetic with values other than the
224+
// bounds themselves and the stride.
225+
}
205226
}
206227

207228
func strideIteratorTest<

validation-test/stdlib/Stride.swift

+4
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,15 @@ var StrideTestSuite = TestSuite("Stride")
99
StrideTestSuite.test("to") {
1010
checkSequence(Array(0...4), stride(from: 0, to: 5, by: 1))
1111
checkSequence(Array(1...5).reversed(), stride(from: 5, to: 0, by: -1))
12+
checkSequence(stride(from: 0, to: 127, by: 3).map { Int8($0) },
13+
stride(from: 0, to: 127 as Int8, by: 3))
1214
}
1315

1416
StrideTestSuite.test("through") {
1517
checkSequence(Array(0...5), stride(from: 0, through: 5, by: 1))
1618
checkSequence(Array(0...5).reversed(), stride(from: 5, through: 0, by: -1))
19+
checkSequence(stride(from: 0, through: 127, by: 3).map { Int8($0) },
20+
stride(from: 0, through: 127 as Int8, by: 3))
1721
}
1822

1923
runAllTests()

0 commit comments

Comments
 (0)