@@ -129,22 +129,27 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint,
129
129
Scalar. RawSignificand: FixedWidthInteger {
130
130
/// Performs Glorot uniform initialization for the specified shape, creating a tensor by
131
131
/// randomly sampling scalar values from a uniform distribution between `-limit` and `limit`,
132
- /// where limit is `sqrt(6 / (fanIn + fanOut))`.
132
+ /// where limit is `sqrt(6 / (fanIn + fanOut))` and `fanIn`/`fanOut` represent the number of
133
+ /// input and output features multiplied by the receptive field if present.
133
134
///
134
135
/// - Parameters:
135
136
/// - shape: The dimensions of the tensor.
136
137
/// - generator: Random number generator to use.
137
138
///
138
139
init < G: RandomNumberGenerator > ( glorotUniform shape: TensorShape , generator: inout G ) {
139
- let fanIn = shape [ shape. count - 2 ]
140
- let fanOut = shape [ shape. count - 1 ]
140
+ let spatialDimCount = shape. count - 2
141
+ let receptiveField = shape [ 0 ..< spatialDimCount] . contiguousSize
142
+ let fanIn = shape [ shape. count - 2 ] * receptiveField
143
+ let fanOut = shape [ shape. count - 1 ] * receptiveField
141
144
let minusOneToOne = 2 * Tensor( randomUniform: shape, generator: & generator) - 1
142
145
self = sqrt ( Scalar ( 6 ) / Scalar( fanIn + fanOut) ) * minusOneToOne
143
146
}
144
147
145
148
/// Creates a tensor by performing Glorot uniform initialization for the specified shape,
146
149
/// randomly sampling scalar values from a uniform distribution between `-limit` and `limit`,
147
- /// where limit is `sqrt(6 / (fanIn + fanOut))`, using the default random number generator.
150
+ /// generated by the default random number generator, where limit is
151
+ /// `sqrt(6 / (fanIn + fanOut))` and `fanIn`/`fanOut` represent the number of input and output
152
+ /// features multiplied by the receptive field if present.
148
153
///
149
154
/// - Parameters:
150
155
/// - shape: The dimensions of the tensor.
0 commit comments