Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 7465fdb

Browse files
authored
Add average, multiply and stack modes for BiRNNs (#1101)
1 parent d0ea5cb commit 7465fdb

File tree

5 files changed

+3838
-861
lines changed

5 files changed

+3838
-861
lines changed

Sources/TensorFlow/Layers/Recurrent.swift

+90-2
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ public struct LSTMCell<Scalar: TensorFlowFloatingPoint>: RecurrentLayerCell {
211211
/// Concatenates two values.
212212
@differentiable
213213
public static func concatenate(_ lhs: Self, _ rhs: Self) -> Self {
214-
// TODO: Remove workaround after https://github.com/tensorflow/swift-apis/issues/1087 is fixed.
214+
// TODO(TF-1005): Remove workaround for differentiating concatenated.
215215
let concatCell = lhs.cell.concatenated(with: rhs.cell, alongAxis: -1)
216216
let concatHidden = lhs.hidden.concatenated(with: rhs.hidden, alongAxis: -1)
217217
let cell = concatCell.withDerivative { [shape = concatCell.shape] in
@@ -228,6 +228,33 @@ public struct LSTMCell<Scalar: TensorFlowFloatingPoint>: RecurrentLayerCell {
228228
public static func sum(_ lhs: Self, _ rhs: Self) -> Self {
229229
Self(cell: lhs.cell + rhs.cell, hidden: lhs.hidden + rhs.hidden)
230230
}
231+
232+
/// Averages two values.
233+
@differentiable
234+
public static func average(_ lhs: Self, _ rhs: Self) -> Self {
235+
Self(cell: (lhs.cell + rhs.cell) / 2, hidden: (lhs.hidden + rhs.hidden) / 2)
236+
}
237+
238+
/// Multiplies two values.
239+
@differentiable
240+
public static func multiply(_ lhs: Self, _ rhs: Self) -> Self {
241+
Self(cell: lhs.cell * rhs.cell, hidden: lhs.hidden * rhs.hidden)
242+
}
243+
244+
/// Stacks two values.
245+
@differentiable
246+
public static func stack(_ lhs: Self, _ rhs: Self) -> Self {
247+
// TODO(TF-1005): Remove workaround for differentiating stacking.
248+
let stackCell = Tensor(stacking: [lhs.cell, rhs.cell])
249+
let stackHidden = Tensor(stacking: [lhs.hidden, rhs.hidden])
250+
let cell = stackCell.withDerivative { [shape = stackCell.shape] in
251+
if $0 == Tensor(0) { $0 = Tensor(zeros: shape) }
252+
}
253+
let hidden = stackHidden.withDerivative { [shape = stackHidden.shape] in
254+
if $0 == Tensor(0) { $0 = Tensor(zeros: shape) }
255+
}
256+
return Self(cell: cell, hidden: hidden)
257+
}
231258
}
232259

233260
/// Returns a zero-valued state with shape compatible with the provided input.
@@ -455,13 +482,25 @@ public protocol Mergeable: Differentiable, AdditiveArithmetic {
455482
/// `Mergeable` (SR-13229).
456483
@differentiable
457484
static func sum(_ lhs: Self, _ rhs: Self) -> Self
485+
486+
/// Averages two values.
487+
@differentiable
488+
static func average(_ lhs: Self, _ rhs: Self) -> Self
489+
490+
/// Multiplies two values.
491+
@differentiable
492+
static func multiply(_ lhs: Self, _ rhs: Self) -> Self
493+
494+
/// Stacks two values.
495+
@differentiable
496+
static func stack(_ lhs: Self, _ rhs: Self) -> Self
458497
}
459498

460499
extension Tensor: Mergeable where Scalar: TensorFlowFloatingPoint {
461500
/// Concatenates two tensors along last axis.
462501
@differentiable
463502
public static func concatenate(_ lhs: Tensor, _ rhs: Tensor) -> Tensor {
464-
// TODO: Remove workaround after https://github.com/tensorflow/swift-apis/issues/1087 is fixed.
503+
// TODO(TF-1005): Remove workaround for differentiating concatenated.
465504
let concat = lhs.concatenated(with: rhs, alongAxis: -1)
466505
return concat.withDerivative { [shape = concat.shape] in
467506
if $0 == Tensor(0) { $0 = Tensor(zeros: shape) }
@@ -473,6 +512,28 @@ extension Tensor: Mergeable where Scalar: TensorFlowFloatingPoint {
473512
public static func sum(_ lhs: Tensor, _ rhs: Tensor) -> Tensor {
474513
lhs + rhs
475514
}
515+
516+
/// Averages two values.
517+
@differentiable
518+
public static func average(_ lhs: Tensor, _ rhs: Tensor) -> Tensor {
519+
(lhs + rhs) / 2
520+
}
521+
522+
/// Multiplies two values.
523+
@differentiable
524+
public static func multiply(_ lhs: Tensor, _ rhs: Tensor) -> Tensor {
525+
lhs * rhs
526+
}
527+
528+
/// Stacks two values.
529+
@differentiable
530+
public static func stack(_ lhs: Tensor, _ rhs: Tensor) -> Tensor {
531+
// TODO(TF-1005): Remove workaround for differentiating stacking.
532+
let stack = Tensor(stacking: [lhs, rhs])
533+
return stack.withDerivative { [shape = stack.shape] in
534+
if $0 == Tensor(0) { $0 = Tensor(zeros: shape) }
535+
}
536+
}
476537
}
477538

478539
/// Concatenates two values.
@@ -493,6 +554,33 @@ public func sum<T: Mergeable>(
493554
T.sum(first, second)
494555
}
495556

557+
/// Averages two values.
558+
@differentiable
559+
public func average<T: Mergeable>(
560+
_ first: T,
561+
_ second: T
562+
) -> T {
563+
T.average(first, second)
564+
}
565+
566+
/// Multiplies two values.
567+
@differentiable
568+
public func multiply<T: Mergeable>(
569+
_ first: T,
570+
_ second: T
571+
) -> T {
572+
T.multiply(first, second)
573+
}
574+
575+
/// Stacks two values.
576+
@differentiable
577+
public func stack<T: Mergeable>(
578+
_ first: T,
579+
_ second: T
580+
) -> T {
581+
T.stack(first, second)
582+
}
583+
496584
public struct BidirectionalRecurrentLayer<Cell: RecurrentLayerCell>: Layer
497585
where Cell.TimeStepOutput: Mergeable {
498586
public typealias Input = [Cell.TimeStepInput]

0 commit comments

Comments
 (0)