Conversation
@Shashi456 This does pass locally for me
Thanks. This looks great!
inputSize: Int,
hiddenSize: Int,
seed: TensorFlowSeed = Context.local.randomSeed
) {
public var updateBias, outputBias, resetBias: Tensor<Scalar>

@noDerivative public var stateShape: TensorShape {
    TensorShape([1, updateWeight.shape[0]])
Suggested change: TensorShape([1, updateWeight.shape[0]]) → [1, updateWeight.shape[0]]
Use literal initialization when a contextual type exists.
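The rule follows the Google Swift Style Guide: when the compiler already knows the expected type, an array literal is preferred over an explicit initializer call. A minimal standalone sketch of the same pattern using a standard-library type (TensorShape, which lives in swift-apis, is ExpressibleByArrayLiteral in the same way):

```swift
import Foundation

// Explicit initializer call: works, but redundant when the type is known.
let a = Set([1, 2, 3])

// Preferred: literal initialization, because the contextual type (Set<Int>)
// already tells the compiler how to build the value.
let b: Set<Int> = [1, 2, 3]

// Both spellings produce the same value; only the second relies on context.
print(a == b)
```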
@@ -200,6 +200,75 @@ public struct LSTMCell<Scalar: TensorFlowFloatingPoint>: RNNCell {
    }
}

/// A GRU cell.
public struct GRUCell<Scalar: TensorFlowFloatingPoint>: RNNCell {
    public var updateWeight, updateWeight2, resetWeight, resetWeight2, outputWeight, outputWeight2: Tensor<Scalar>
Break this into multiple var declarations so that it fits within 100 columns.
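The requested split might look like the sketch below. It is runnable on its own because it uses a placeholder Tensor type; in the actual PR these properties are Tensor<Scalar> from swift-apis, and GRUCellSketch is a hypothetical stand-in name:

```swift
// Placeholder so the declaration style compiles outside swift-apis.
struct Tensor<Scalar> { var shape: [Int] = [] }

struct GRUCellSketch<Scalar> {
    // One `var` line per gate keeps every line well under 100 columns.
    var updateWeight, updateWeight2: Tensor<Scalar>
    var resetWeight, resetWeight2: Tensor<Scalar>
    var outputWeight, outputWeight2: Tensor<Scalar>
}

let cell = GRUCellSketch<Float>(
    updateWeight: Tensor(shape: [4, 4]), updateWeight2: Tensor(shape: [4, 4]),
    resetWeight: Tensor(shape: [4, 4]), resetWeight2: Tensor(shape: [4, 4]),
    outputWeight: Tensor(shape: [4, 4]), outputWeight2: Tensor(shape: [4, 4]))
print(cell.updateWeight.shape == [4, 4])
```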
/// - Returns: The hidden state.
@differentiable
public func callAsFunction(_ input: Input) -> Output {
    let resetGate = sigmoid(matmul(input.input, resetWeight) + matmul(input.state.hidden, resetWeight2) + resetBias)
Make sure all lines fit within 100 columns.
let x = Tensor<Float>(rangeFrom: 0.0, to: 0.4, stride: 0.1).rankLifted()
let inputs: [Tensor<Float>] = Array(repeating: x, count: 4)
let rnn = RNN(GRUCell<Float>(inputSize: 4, hiddenSize: 4,
    seed: (0xFeed, 0xBeef)))
Move this to the end of the previous line.
@differentiable
public func callAsFunction(_ input: Input) -> Output {
    let resetGate = sigmoid(matmul(input.input, resetWeight) +
        matmul(input.state.hidden, resetWeight2) + resetBias)
Indent each line wrapping by 4 from the previous line. The only difference from the Google Swift Style Guide is that we use 4-space indentation instead of 2-space.
Same for other occurrences below.
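Concretely, the convention indents each continuation line 4 spaces from the line it wraps. A plain-Swift sketch with made-up scalar stand-ins for the tensor operation (the real code uses matmul on Tensors):

```swift
import Foundation

func sigmoid(_ x: Double) -> Double { 1 / (1 + exp(-x)) }

// Scalar stand-ins, for illustration only.
let input = 0.5, hidden = 0.25
let resetWeight = 0.1, resetWeight2 = 0.2, resetBias = 0.0

// The continuation line is indented 4 spaces relative to `let resetGate`.
let resetGate = sigmoid(input * resetWeight +
    hidden * resetWeight2 + resetBias)
print(resetGate > 0 && resetGate < 1)
```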
fixed 🙂
@@ -200,6 +200,80 @@ public struct LSTMCell<Scalar: TensorFlowFloatingPoint>: RNNCell {
    }
}

/// A GRU cell.
public struct GRUCell<Scalar: TensorFlowFloatingPoint>: RNNCell {
    public var updateWeight, updateWeight2: Tensor<Scalar>
One last thing: I think it's better to rename variables that have a 2 variant to have a 1 suffix. So updateWeight1, resetWeight1, and outputWeight1. What do you think?
I don't feel strongly at all, so I made the change.
Sorry, I meant updating updateWeight to be updateWeight1 so that you'll have updateWeight1 and updateWeight2.
Haha that makes way more sense. I agree this is an improvement. Made the changes accordingly
Thanks!
Force-pushed from 5936f1f to 81046d8
CI has been broken recently and we are looking into it. The PR looks good to me so we'll merge it when CI gets fixed.
@dhasl002 Could you pull master? The errors are from PR changes; I guess this should pass after that.
inputSize: Int,
hiddenSize: Int,
seed: TensorFlowSeed = Context.local.randomSeed
) {
Could you switch to using the initialization convention we use for other layers? See, for example, how the Dense layer initializers are defined.
Could you be more specific? I see some small differences between the initializers, but I'm not exactly sure which conventions you are referring to.
Sorry for the confusion. I was referring to the initialization method being used (e.g., zeros vs glorotUniform vs others). The approach followed for other layers allows the user to provide a custom initialization method for the layer parameters if they want to. You should also modify this one to support custom initialization methods.
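For illustration, an init following that convention might look like the sketch below. GRUCellSketch is a hypothetical stand-in, and the placeholder Tensor and initializer functions stand in for swift-apis' Tensor, ParameterInitializer, glorotUniform(), and zeros() so the sketch runs on its own:

```swift
// Minimal stand-ins; in swift-apis, ParameterInitializer<Scalar> builds a
// real Tensor<Scalar> from a shape.
struct Tensor { var shape: [Int] }
typealias ParameterInitializer = ([Int]) -> Tensor

func glorotUniform() -> ParameterInitializer { { Tensor(shape: $0) } }
func zeros() -> ParameterInitializer { { Tensor(shape: $0) } }

struct GRUCellSketch {
    var updateWeight1, updateWeight2, updateBias: Tensor
    var resetWeight1, resetWeight2, resetBias: Tensor
    var outputWeight1, outputWeight2, outputBias: Tensor

    // The Dense-style convention: accept initializers instead of a seed.
    init(
        inputSize: Int,
        hiddenSize: Int,
        weightInitializer: ParameterInitializer = glorotUniform(),
        biasInitializer: ParameterInitializer = zeros()
    ) {
        // Input-to-hidden and hidden-to-hidden weights for each gate.
        updateWeight1 = weightInitializer([inputSize, hiddenSize])
        updateWeight2 = weightInitializer([hiddenSize, hiddenSize])
        updateBias = biasInitializer([hiddenSize])
        resetWeight1 = weightInitializer([inputSize, hiddenSize])
        resetWeight2 = weightInitializer([hiddenSize, hiddenSize])
        resetBias = biasInitializer([hiddenSize])
        outputWeight1 = weightInitializer([inputSize, hiddenSize])
        outputWeight2 = weightInitializer([hiddenSize, hiddenSize])
        outputBias = biasInitializer([hiddenSize])
    }
}

let cell = GRUCellSketch(inputSize: 4, hiddenSize: 4)
print(cell.updateWeight1.shape == [4, 4] && cell.updateBias.shape == [4])
```

A user can then pass, say, a seeded or zero-centered initializer without the cell itself needing a seed parameter.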
@eaplatanios Sorry, I am still confused. Could you tell me if I am understanding you correctly?
In other words, you would like me to add the weights and biases to the initializer so that someone could initialize them differently?
@eaplatanios @sgugger @marcrasi This is the last request on this PR, could you give me the requested info so that we can get this merged in. Thanks 😄
Sorry I missed that comment. The requested change is to pass a weightInitializer and biasInitializer in this init, on top of the sizes, and use them (no need to take a seed then, since you can seed your initializer). Here is an example for the initialization of Dense:
init(
    inputSize: Int,
    outputSize: Int,
    activation: @escaping Activation = identity,
    weightInitializer: ParameterInitializer<Scalar> = glorotUniform(),
    biasInitializer: ParameterInitializer<Scalar> = zeros()
) {
    self.init(
        weight: weightInitializer([inputSize, outputSize]),
        bias: biasInitializer([outputSize]),
        activation: activation)
}
@sgugger Thanks for the help! Everything should be finished now.
Hi @dhasl002 ! Thank you so much for this PR. Do you think you might be able to make some of the changes @eaplatanios suggested? If not, no worries! -Brennan
@saeta Thank you for reminding me, I will work on this today.
Hi again @dhasl002 . Did you have time to work on those changes? Please let us know if you have any questions or if you don't have any time for this.
fixed spacing
variable renaming
correct variable renaming
Thanks a lot for your help!
Based on the fully gated unit seen below:



These are the formulas that I used:

