Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 22c923a

Browse files
authoredFeb 22, 2019
Fix Glorot uniform initialization for convolutional layers (#22)
1 parent b9a05df commit 22c923a

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed
 

‎Sources/DeepLearning/Initializers.swift

+9-4
Original file line numberDiff line numberDiff line change
@@ -129,22 +129,27 @@ public extension Tensor where Scalar: TensorFlowFloatingPoint,
129129
Scalar.RawSignificand: FixedWidthInteger {
130130
/// Performs Glorot uniform initialization for the specified shape, creating a tensor by
131131
/// randomly sampling scalar values from a uniform distribution between `-limit` and `limit`,
132-
/// where limit is `sqrt(6 / (fanIn + fanOut))`.
132+
/// where limit is `sqrt(6 / (fanIn + fanOut))` and `fanIn`/`fanOut` represent the number of
133+
/// input and output features multiplied by the receptive field if present.
133134
///
134135
/// - Parameters:
135136
/// - shape: The dimensions of the tensor.
136137
/// - generator: Random number generator to use.
137138
///
138139
init<G: RandomNumberGenerator>(glorotUniform shape: TensorShape, generator: inout G) {
139-
let fanIn = shape[shape.count - 2]
140-
let fanOut = shape[shape.count - 1]
140+
let spatialDimCount = shape.count - 2
141+
let receptiveField = shape[0..<spatialDimCount].contiguousSize
142+
let fanIn = shape[shape.count - 2] * receptiveField
143+
let fanOut = shape[shape.count - 1] * receptiveField
141144
let minusOneToOne = 2 * Tensor(randomUniform: shape, generator: &generator) - 1
142145
self = sqrt(Scalar(6) / Scalar(fanIn + fanOut)) * minusOneToOne
143146
}
144147

145148
/// Creates a tensor by performing Glorot uniform initialization for the specified shape,
146149
/// randomly sampling scalar values from a uniform distribution between `-limit` and `limit`,
147-
/// where limit is `sqrt(6 / (fanIn + fanOut))`, using the default random number generator.
150+
/// generated by the default random number generator, where limit is
151+
/// `sqrt(6 / (fanIn + fanOut))` and `fanIn`/`fanOut` represent the number of input and output
152+
/// features multiplied by the receptive field if present.
148153
///
149154
/// - Parameters:
150155
/// - shape: The dimensions of the tensor.

0 commit comments

Comments
 (0)
This repository has been archived.