This repository was archived by the owner on Jul 1, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 137
/
Copy path Random.swift
648 lines (576 loc) · 19.3 KB
/
Random.swift
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#if os(macOS) || os(iOS) || os(watchOS) || os(tvOS)
import Darwin
#elseif os(Windows)
import ucrt
#else
import Glibc
#endif
public typealias TensorFlowSeed = (graph: Int32, op: Int32)
/// Generates a new random seed for TensorFlow.
///
/// - Parameter seed: An optional user-provided seed. When present, its `graph` and `op`
///   components are packed into a single 64-bit value, so distinct seeds stay distinct. When
///   `nil`, a fresh 64-bit value is drawn with `UInt64.random(in:)`.
/// - Returns: A seed pair built from the first eight bytes of the `sha512()` hash of that 64-bit
///   value: bytes 0-3 form `graph` and bytes 4-7 form `op`.
public func randomSeedForTensorFlow(using seed: TensorFlowSeed? = nil) -> TensorFlowSeed {
  let strongSeed: UInt64
  if let s = seed {
    let bytes = (s.graph.bytes() + s.op.bytes())[...]
    // The reference implementation reduces the seed modulo 2^(8 * 8). Every `UInt64` already
    // lies in that range, so the packed value is used as-is. (The previous code computed
    // `(seed % 2) ^ 64`, which collapsed every user seed to either 0 or 1.)
    strongSeed = UInt64(bytes: bytes, startingAt: bytes.startIndex)
  } else {
    // Closed range so that every 64-bit value, including `UInt64.max`, is reachable.
    strongSeed = UInt64.random(in: UInt64.min...UInt64.max)
  }
  // Many machine learning systems are likely to have many random number generators active at
  // once (e.g., in reinforcement learning we may have an environment running in multiple
  // processes). There is literature indicating that having linear correlations between seeds of
  // multiple PRNG's can correlate the outputs:
  // - http://blogs.unity3d.com/2015/01/07/a-primer-on-repeatable-random-numbers
  // - http://stackoverflow.com/questions/1554958/how-different-do-random-seeds-need-to-be
  // - http://dl.acm.org/citation.cfm?id=1276928
  // Thus, for sanity we hash the generated seed before using it. This scheme is likely not
  // crypto-strength, but it should be good enough to get rid of simple correlations.
  // Reference: https://github.com/openai/gym/blob/master/gym/utils/seeding.py
  let hash = strongSeed.bytes().sha512()
  let graph = Int32(bytes: [hash[0], hash[1], hash[2], hash[3]], startingAt: 0)
  let op = Int32(bytes: [hash[4], hash[5], hash[6], hash[7]], startingAt: 0)
  return (graph: graph, op: op)
}
//===------------------------------------------------------------------------------------------===//
// Random Number Generators
//===------------------------------------------------------------------------------------------===//
/// A random number generator that hides the concrete type of the generator it wraps.
///
/// Each call to `next()` is forwarded to the wrapped generator, so the wrapper yields exactly
/// the values that generator would have produced.
public struct AnyRandomNumberGenerator: RandomNumberGenerator {
  @usableFromInline
  var _rng: RandomNumberGenerator

  /// Wraps `rng`, erasing its concrete type.
  ///
  /// - Parameter rng: The generator that services every subsequent `next()` call.
  @inlinable
  public init(_ rng: RandomNumberGenerator) {
    _rng = rng
  }

  /// Returns the next value from the wrapped generator and advances its state.
  @inlinable
  public mutating func next() -> UInt64 {
    let value = _rng.next()
    return value
  }
}
/// A type that provides seedable deterministic pseudo-random data.
///
/// A SeedableRandomNumberGenerator can be used anywhere where a
/// RandomNumberGenerator would be used. It is useful when the pseudo-random
/// data needs to be reproducible across runs.
///
/// Conforming to the SeedableRandomNumberGenerator Protocol
/// ========================================================
///
/// To make a custom type conform to the `SeedableRandomNumberGenerator`
/// protocol, implement the `init(seed: [UInt8])` initializer, as well as the
/// requirements for `RandomNumberGenerator`. The values returned by `next()`
/// must form a deterministic sequence that depends only on the seed provided
/// upon initialization. A default `init(seed:)` for integer seeds is provided.
public protocol SeedableRandomNumberGenerator: RandomNumberGenerator {
  /// Creates a generator whose output sequence is determined by `seed`.
  init(seed: [UInt8])
  /// Creates a generator from the bytes of an integer seed.
  init<T: BinaryInteger>(seed: T)
}
extension SeedableRandomNumberGenerator {
  /// Creates a generator by splitting `seed` into bytes and forwarding them to
  /// `init(seed: [UInt8])`. The least significant byte comes first.
  ///
  /// The byte count is `seed.bitWidth` rounded up to whole bytes. This means every bit of the
  /// seed is used, and an integer narrower than a byte still yields one byte instead of an
  /// empty array. For the standard fixed-width integer types this is exactly `bitWidth / 8`.
  public init<T: BinaryInteger>(seed: T) {
    let byteCount = (seed.bitWidth + UInt8.bitWidth - 1) / UInt8.bitWidth
    let bytes = (0..<byteCount).map { index in
      UInt8(truncatingIfNeeded: seed >> (UInt8.bitWidth * index))
    }
    self.init(seed: bytes)
  }
}
/// A seedable pseudo-random number generator built on the ARC4 (RC4) keystream.
///
/// The seed bytes act as the ARC4 key. They drive the key-scheduling pass that permutes the
/// 256-byte `state`, and `next()` then packs successive keystream bytes into `UInt64` values.
///
/// ARC4 is described in Schneier, B., "Applied Cryptography: Protocols,
/// Algorithms, and Source Code in C", 2nd Edition, 1996.
///
/// A single generator must not be shared across threads without synchronization, but separate
/// generators never share state. The output is not suitable for cryptographic applications.
@frozen
public struct ARC4RandomNumberGenerator: SeedableRandomNumberGenerator {
  /// A shared generator seeded from the current time (`time(nil)`).
  public static var global = ARC4RandomNumberGenerator(seed: UInt32(time(nil)))
  var state: [UInt8] = Array(0...255)
  var iPos: UInt8 = 0
  var jPos: UInt8 = 0

  /// Runs the ARC4 key schedule using `seed` as the key.
  ///
  /// - Precondition: `seed` holds between 1 and 256 bytes.
  public init(seed: [UInt8]) {
    precondition(seed.count > 0, "Length of seed must be positive")
    precondition(seed.count <= 256, "Length of seed must be at most 256")
    var mix: UInt8 = 0
    for position in state.indices {
      // The key bytes repeat cyclically when the seed is shorter than 256 bytes.
      mix = mix &+ state[position] &+ seed[position % seed.count]
      state.swapAt(position, Int(mix))
    }
  }

  /// Returns the next eight keystream bytes packed big-endian, so the first byte drawn becomes
  /// the most significant byte.
  public mutating func next() -> UInt64 {
    var word: UInt64 = 0
    for _ in 0..<(UInt64.bitWidth / UInt8.bitWidth) {
      word = (word << UInt8.bitWidth) | UInt64(keystreamByte())
    }
    return word
  }

  /// Produces one keystream byte, advancing `iPos`, `jPos`, and `state`.
  private mutating func keystreamByte() -> UInt8 {
    iPos &+= 1
    jPos &+= state[Int(iPos)]
    state.swapAt(Int(iPos), Int(jPos))
    return state[Int(state[Int(iPos)] &+ state[Int(jPos)])]
  }
}
private typealias UInt32x2 = (UInt32, UInt32)
private typealias UInt32x4 = (UInt32, UInt32, UInt32, UInt32)
/// An implementation of `SeedableRandomNumberGenerator` using Threefry.
/// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.
/// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf
///
/// This struct implements a 20-round Threefry2x32 PRNG. It must be seeded with
/// a 64-bit value.
///
/// An individual generator is not thread-safe, but distinct generators do not
/// share state. The random data generated is of high-quality, but is not
/// suitable for cryptographic applications.
public struct ThreefryRandomNumberGenerator: SeedableRandomNumberGenerator {
  public static var global = ThreefryRandomNumberGenerator(
    uint64Seed: UInt64(time(nil))
  )
  /// Rotation amounts. Rounds 1-4, 9-12 and 17-20 use `rot.0`...`rot.3`; rounds 5-8 and 13-16
  /// use `rot.4`...`rot.7`.
  private let rot: (UInt32, UInt32, UInt32, UInt32, UInt32, UInt32, UInt32, UInt32) = (
    13, 15, 26, 6, 17, 29, 16, 24
  )
  /// Rotates `value` left by `n` bits (mod 32).
  private func rotl32(value: UInt32, n: UInt32) -> UInt32 {
    return (value << (n & 31)) | (value >> ((32 - n) & 31))
  }
  /// Applies one Threefry mix round (add, rotate, xor) to the two-word state in place.
  private func mixRound(_ x0: inout UInt32, _ x1: inout UInt32, rotation: UInt32) {
    x0 &+= x1
    x1 = rotl32(value: x1, n: rotation)
    x1 ^= x0
  }
  private var ctr: UInt64 = 0
  private let key: UInt32x2
  /// Evaluates the 20-round Threefry2x32 block function on counter `ctr` under `key`.
  private func random(forCtr ctr: UInt32x2, key: UInt32x2) -> UInt32x2 {
    let skeinKsParity32: UInt32 = 0x1BD1_1BDA
    let ks0 = key.0
    let ks1 = key.1
    let ks2 = skeinKsParity32 ^ key.0 ^ key.1
    // All key-schedule arithmetic is modulo 2^32. The injection counters are therefore added
    // with `&+`; plain `+` traps for keys near `UInt32.max` (e.g. a seed of `UInt64.max`).
    // Key injection (r = 0)
    var x0 = ctr.0 &+ ks0
    var x1 = ctr.1 &+ ks1
    // R1-R4
    mixRound(&x0, &x1, rotation: rot.0)
    mixRound(&x0, &x1, rotation: rot.1)
    mixRound(&x0, &x1, rotation: rot.2)
    mixRound(&x0, &x1, rotation: rot.3)
    // Key injection (r = 1)
    x0 &+= ks1
    x1 &+= ks2 &+ 1
    // R5-R8
    mixRound(&x0, &x1, rotation: rot.4)
    mixRound(&x0, &x1, rotation: rot.5)
    mixRound(&x0, &x1, rotation: rot.6)
    mixRound(&x0, &x1, rotation: rot.7)
    // Key injection (r = 2)
    x0 &+= ks2
    x1 &+= ks0 &+ 2
    // R9-R12
    mixRound(&x0, &x1, rotation: rot.0)
    mixRound(&x0, &x1, rotation: rot.1)
    mixRound(&x0, &x1, rotation: rot.2)
    mixRound(&x0, &x1, rotation: rot.3)
    // Key injection (r = 3)
    x0 &+= ks0
    x1 &+= ks1 &+ 3
    // R13-R16
    mixRound(&x0, &x1, rotation: rot.4)
    mixRound(&x0, &x1, rotation: rot.5)
    mixRound(&x0, &x1, rotation: rot.6)
    mixRound(&x0, &x1, rotation: rot.7)
    // Key injection (r = 4)
    x0 &+= ks1
    x1 &+= ks2 &+ 4
    // R17-R20
    mixRound(&x0, &x1, rotation: rot.0)
    mixRound(&x0, &x1, rotation: rot.1)
    mixRound(&x0, &x1, rotation: rot.2)
    mixRound(&x0, &x1, rotation: rot.3)
    // Key injection (r = 5)
    x0 &+= ks2
    x1 &+= ks0 &+ 5
    return (x0, x1)
  }
  internal init(uint64Seed seed: UInt64) {
    key = seed.vector2
  }
  /// Creates a generator from 1 to 8 seed bytes, packed little-endian into a 64-bit key.
  public init(seed: [UInt8]) {
    precondition(seed.count > 0, "Length of seed must be positive")
    precondition(seed.count <= 8, "Length of seed must be at most 8")
    var combinedSeed: UInt64 = 0
    for (i, byte) in seed.enumerated() {
      combinedSeed += UInt64(byte) << UInt64(8 * i)
    }
    self.init(uint64Seed: combinedSeed)
  }
  /// Returns the block function's output for the current counter, then increments the counter.
  public mutating func next() -> UInt64 {
    defer { ctr += 1 }
    return UInt64(vector: random(forCtr: ctr.vector2, key: key))
  }
}
/// An implementation of `SeedableRandomNumberGenerator` using Philox.
/// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.
/// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf
///
/// This generator evaluates a 10-round Philox4x32 block function on a 64-bit counter, keyed by
/// a 64-bit seed. Each block yields two `UInt64` values, so the block function runs only on
/// every other call to `next()`.
///
/// A single generator is not thread-safe, but separate generators never share state. The
/// output is not suitable for cryptographic applications.
public struct PhiloxRandomNumberGenerator: SeedableRandomNumberGenerator {
  public static var global = PhiloxRandomNumberGenerator(uint64Seed: UInt64(time(nil)))
  private var ctr: UInt64 = 0
  private let key: UInt32x2
  // When `useNextValue` is set, `nextValue` holds the second half of the latest block and is
  // returned by the following call to `next()`.
  private var useNextValue = false
  private var nextValue: UInt64 = 0

  /// Advances the round key by adding fixed 32-bit constants to each word (wrapping).
  private func bump(key: UInt32x2) -> UInt32x2 {
    let (highIncrement, lowIncrement): (UInt32, UInt32) = (0x9E37_79B9, 0xBB67_AE85)
    return (key.0 &+ highIncrement, key.1 &+ lowIncrement)
  }

  /// Performs one Philox round: two full-width 32x32 multiplies mixed with the counter and key.
  private func round(ctr: UInt32x4, key: UInt32x2) -> UInt32x4 {
    let (hi0, lo0) = UInt32(0xD251_1F53).multipliedFullWidth(by: ctr.0)
    let (hi1, lo1) = UInt32(0xCD9E_8D57).multipliedFullWidth(by: ctr.2)
    return (hi1 ^ ctr.1 ^ key.0, lo1, hi0 ^ ctr.3 ^ key.1, lo0)
  }

  /// Runs ten rounds. Round 1 uses the key as given; each later round bumps the key first.
  private func random(forCtr initialCtr: UInt32x4, key initialKey: UInt32x2) -> UInt32x4 {
    var block = round(ctr: initialCtr, key: initialKey)
    var roundKey = initialKey
    for _ in 2...10 {
      roundKey = bump(key: roundKey)
      block = round(ctr: block, key: roundKey)
    }
    return block
  }

  public init(uint64Seed seed: UInt64) {
    key = seed.vector2
  }

  /// Creates a generator from 1 to 8 seed bytes, packed little-endian into a 64-bit key.
  public init(seed: [UInt8]) {
    precondition(seed.count > 0, "Length of seed must be positive")
    precondition(seed.count <= 8, "Length of seed must be at most 8")
    let packed = seed.enumerated().reduce(UInt64(0)) { partial, entry in
      partial | UInt64(entry.element) << UInt64(8 * entry.offset)
    }
    self.init(uint64Seed: packed)
  }

  /// Returns the cached second half of the previous block if one is pending. Otherwise it
  /// computes a new block, caches its second half, advances the counter, and returns the
  /// first half.
  public mutating func next() -> UInt64 {
    guard !useNextValue else {
      useNextValue = false
      return nextValue
    }
    let (first, second) = makeUInt64Pair(random(forCtr: ctr.vector4, key: key))
    nextValue = second
    useNextValue = true
    ctr += 1
    return first
  }
}
/// Private helpers.
extension UInt64 {
  /// Splits the value into `(high 32 bits, low 32 bits)`.
  fileprivate var vector2: UInt32x2 {
    return (UInt32(truncatingIfNeeded: self >> 32), UInt32(truncatingIfNeeded: self))
  }
  /// Places the value in the last two lanes of a four-word vector (high word first) and
  /// zeroes the first two lanes.
  fileprivate var vector4: UInt32x4 {
    let (high, low) = vector2
    return (0, 0, high, low)
  }
  /// Reassembles a value from `(high 32 bits, low 32 bits)`.
  fileprivate init(vector: UInt32x2) {
    self = UInt64(vector.0) << 32 | UInt64(vector.1)
  }
}
/// Combines lanes 0-1 and lanes 2-3 of `vector` into two values, with the high word first in
/// each pair.
private func makeUInt64Pair(_ vector: UInt32x4) -> (UInt64, UInt64) {
  return (UInt64(vector: (vector.0, vector.1)), UInt64(vector: (vector.2, vector.3)))
}
//===------------------------------------------------------------------------------------------===//
// Distributions
//===------------------------------------------------------------------------------------------===//
/// A source of samples drawn from a probability distribution.
///
/// Each call to `next(using:)` produces one `Sample`, consuming randomness from the supplied
/// generator.
public protocol RandomDistribution {
  /// The type of value the distribution produces.
  associatedtype Sample
  /// Draws one sample, advancing `generator`.
  func next<G>(using generator: inout G) -> Sample where G: RandomNumberGenerator
}
/// A uniform distribution over every integer in the closed range `lowerBound...upperBound`.
@frozen
public struct UniformIntegerDistribution<T: FixedWidthInteger>: RandomDistribution {
  /// The smallest value that can be sampled.
  public let lowerBound: T
  /// The largest value that can be sampled (inclusive).
  public let upperBound: T

  /// Creates a distribution over `lowerBound...upperBound`. By default it covers the full
  /// range of `T`.
  public init(lowerBound: T = T.min, upperBound: T = T.max) {
    self.lowerBound = lowerBound
    self.upperBound = upperBound
  }

  /// Returns an integer drawn uniformly from `lowerBound...upperBound` using `rng`.
  public func next<G: RandomNumberGenerator>(using rng: inout G) -> T {
    let support = lowerBound...upperBound
    return T.random(in: support, using: &rng)
  }
}
/// A uniform distribution over the half-open interval `lowerBound..<upperBound`.
@frozen
public struct UniformFloatingPointDistribution<T: BinaryFloatingPoint>: RandomDistribution
where T.RawSignificand: FixedWidthInteger {
  /// The smallest value that can be sampled (inclusive).
  public let lowerBound: T
  /// The exclusive upper end of the interval.
  public let upperBound: T

  /// Creates a distribution over `lowerBound..<upperBound`. By default it covers `0..<1`.
  public init(lowerBound: T = 0, upperBound: T = 1) {
    self.lowerBound = lowerBound
    self.upperBound = upperBound
  }

  /// Returns a value drawn uniformly from `lowerBound..<upperBound` using `rng`.
  public func next<G: RandomNumberGenerator>(using rng: inout G) -> T {
    let interval: Range<T> = lowerBound..<upperBound
    return T.random(in: interval, using: &rng)
  }
}
/// A normal (Gaussian) distribution with the given mean and standard deviation, sampled with
/// the Box-Muller transform.
@frozen
public struct NormalDistribution<T: BinaryFloatingPoint>: RandomDistribution
where T.RawSignificand: FixedWidthInteger {
  public let mean: T
  public let standardDeviation: T
  private let uniformDist = UniformFloatingPointDistribution<T>()
  public init(mean: T = 0, standardDeviation: T = 1) {
    self.mean = mean
    self.standardDeviation = standardDeviation
  }
  /// Draws one sample: `mean + standardDeviation * sqrt(-2 log u1) * cos(2 pi u2)`.
  public func next<G: RandomNumberGenerator>(using rng: inout G) -> T {
    // FIXME: Box-Muller can generate two values for only a little more than the
    // cost of one.
    // `uniformDist` samples from [0, 1), but the transform takes log(u1). That is -infinity at
    // u1 == 0 and would make the sample non-finite. Redraw until u1 is nonzero. Every nonzero
    // draw is used unchanged, so the output matches the previous behavior whenever u1 != 0.
    var u1 = uniformDist.next(using: &rng)
    while u1 == 0 {
      u1 = uniformDist.next(using: &rng)
    }
    let u2 = uniformDist.next(using: &rng)
    let r = (-2 * T(log(Double(u1)))).squareRoot()
    let theta: Double = 2 * Double.pi * Double(u2)
    let normal01 = r * T(cos(theta))
    return mean + standardDeviation * normal01
  }
}
/// A beta distribution with shape parameters `alpha` and `beta`, sampled with Cheng's
/// rejection algorithms BB and BC.
@frozen
public struct BetaDistribution: RandomDistribution {
  public let alpha: Float
  public let beta: Float
  // NOTE(review): this property is not read by any method below. It is kept because removing
  // a stored property changes the layout of a `@frozen` struct.
  private let uniformDistribution = UniformFloatingPointDistribution<Float>()
  /// Creates a Beta(`alpha`, `beta`) distribution.
  ///
  /// NOTE(review): shape parameters must be positive for a valid beta distribution, so the
  /// default `alpha` of 0 is degenerate. Confirm whether callers rely on it.
  public init(alpha: Float = 0, beta: Float = 1) {
    self.alpha = alpha
    self.beta = beta
  }
  /// Draws one sample from Beta(`alpha`, `beta`).
  ///
  /// Uses algorithm BB when both shape parameters exceed 1 and algorithm BC otherwise.
  public func next<G: RandomNumberGenerator>(using rng: inout G) -> Float {
    // Generate a sample using Cheng's sampling algorithm from:
    // R. C. H. Cheng, "Generating beta variates with nonintegral shape
    // parameters.". Communications of the ACM, 21, 317-322, 1978.
    let a = min(alpha, beta)
    let b = max(alpha, beta)
    if a > 1 {
      return BetaDistribution.chengsAlgorithmBB(alpha, a, b, using: &rng)
    } else {
      return BetaDistribution.chengsAlgorithmBC(alpha, b, a, using: &rng)
    }
  }
  /// Returns one sample from a Beta(alpha, beta) distribution using Cheng's BB
  /// algorithm, when both alpha and beta are greater than 1.
  ///
  /// - Parameters:
  ///   - alpha0: The distribution's first shape parameter (`alpha`). It decides whether the
  ///     result is `w / (b + w)` or `b / (b + w)`.
  ///   - a: `min(alpha, beta)`.
  ///   - b: `max(alpha, beta)`.
  ///   - rng: Random number generator.
  ///
  /// - Returns: Sample obtained using Cheng's BB algorithm.
  private static func chengsAlgorithmBB<G: RandomNumberGenerator>(
    _ alpha0: Float,
    _ a: Float,
    _ b: Float,
    using rng: inout G
  ) -> Float {
    let alpha = a + b
    let beta = sqrtf((alpha - 2.0) / (2 * a * b - alpha))
    let gamma = a + 1 / beta
    var r: Float = 0.0
    var w: Float = 0.0
    var t: Float = 0.0
    repeat {
      let u1 = Float.random(in: 0.0...1.0, using: &rng)
      let u2 = Float.random(in: 0.0...1.0, using: &rng)
      let v = beta * (logf(u1) - log1pf(-u1))
      // 1.3862944 = log(4).
      r = gamma * v - 1.3862944
      let z = u1 * u1 * u2
      w = a * expf(v)
      let s = a + r - w
      // Quick acceptance test; 2.609438 = 1 + log(5).
      if s + 2.609438 >= 5 * z {
        break
      }
      t = logf(z)
      if s >= t {
        break
      }
    } while r + alpha * (logf(alpha) - logf(b + w)) < t
    w = min(w, Float.greatestFiniteMagnitude)
    return a == alpha0 ? w / (b + w) : b / (b + w)
  }
  /// Returns one sample from a Beta(alpha, beta) distribution using Cheng's BC
  /// algorithm, when at least one of alpha and beta is at most 1.
  ///
  /// - Parameters:
  ///   - alpha0: The distribution's first shape parameter (`alpha`). It decides whether the
  ///     result is `w / (b + w)` or `b / (b + w)`.
  ///   - a: `max(alpha, beta)`.
  ///   - b: `min(alpha, beta)`.
  ///   - rng: Random number generator.
  ///
  /// - Returns: Sample obtained using Cheng's BC algorithm.
  private static func chengsAlgorithmBC<G: RandomNumberGenerator>(
    _ alpha0: Float,
    _ a: Float,
    _ b: Float,
    using rng: inout G
  ) -> Float {
    let alpha = a + b
    let beta = 1 / b
    let delta = 1 + a - b
    let k1 = delta * (0.0138889 + 0.0416667 * b) / (a * beta - 0.777778)
    let k2 = 0.25 + (0.5 + 0.25 / delta) * b
    var w: Float = 0.0
    while true {
      let u1 = Float.random(in: 0.0...1.0, using: &rng)
      let u2 = Float.random(in: 0.0...1.0, using: &rng)
      let y = u1 * u2
      let z = u1 * y
      if u1 < 0.5 {
        if 0.25 * u2 + z - y >= k1 {
          continue
        }
      } else {
        if z <= 0.25 {
          let v = beta * (logf(u1) - log1pf(-u1))
          w = a * expf(v)
          break
        }
        if z >= k2 {
          continue
        }
      }
      let v = beta * (logf(u1) - log1pf(-u1))
      w = a * expf(v)
      // Cheng's final acceptance test: alpha * (log(alpha / (b + w)) + v) - log(4) >= log(z).
      // The denominator is b + w (it previously read b + 1, which skews the accepted samples).
      if alpha * (logf(alpha) - logf(b + w) + v) - 1.3862944 >= logf(z) {
        break
      }
    }
    w = min(w, Float.greatestFiniteMagnitude)
    return a == alpha0 ? w / (b + w) : b / (b + w)
  }
}