@@ -211,7 +211,7 @@ public struct LSTMCell<Scalar: TensorFlowFloatingPoint>: RecurrentLayerCell {
211
211
/// Concatenates two values.
212
212
@differentiable
213
213
public static func concatenate( _ lhs: Self , _ rhs: Self ) -> Self {
214
- // TODO: Remove workaround after https://github.com/tensorflow/swift-apis/issues/1087 is fixed .
214
+ // TODO(TF-1005) : Remove workaround for differenting concatenated .
215
215
let concatCell = lhs. cell. concatenated ( with: rhs. cell, alongAxis: - 1 )
216
216
let concatHidden = lhs. hidden. concatenated ( with: rhs. hidden, alongAxis: - 1 )
217
217
let cell = concatCell. withDerivative { [ shape = concatCell. shape] in
@@ -228,6 +228,33 @@ public struct LSTMCell<Scalar: TensorFlowFloatingPoint>: RecurrentLayerCell {
228
228
public static func sum( _ lhs: Self , _ rhs: Self ) -> Self {
229
229
Self ( cell: lhs. cell + rhs. cell, hidden: lhs. hidden + rhs. hidden)
230
230
}
231
+
232
+ /// Averages two values.
233
+ @differentiable
234
+ public static func average( _ lhs: Self , _ rhs: Self ) -> Self {
235
+ Self ( cell: ( lhs. cell + rhs. cell) / 2 , hidden: ( lhs. hidden + rhs. hidden) / 2 )
236
+ }
237
+
238
+ /// Multiplies two values.
239
+ @differentiable
240
+ public static func multiply( _ lhs: Self , _ rhs: Self ) -> Self {
241
+ Self ( cell: lhs. cell * rhs. cell, hidden: lhs. hidden * rhs. hidden)
242
+ }
243
+
244
+ /// Stack two values.
245
+ @differentiable
246
+ public static func stack( _ lhs: Self , _ rhs: Self ) -> Self {
247
+ // TODO(TF-1005): Remove workaround for differenting stacking.
248
+ let stackCell = Tensor ( stacking: [ lhs. cell, rhs. cell] )
249
+ let stackHidden = Tensor ( stacking: [ lhs. hidden, rhs. hidden] )
250
+ let cell = stackCell. withDerivative { [ shape = stackCell. shape] in
251
+ if $0 == Tensor ( 0 ) { $0 = Tensor ( zeros: shape) }
252
+ }
253
+ let hidden = stackHidden. withDerivative { [ shape = stackHidden. shape] in
254
+ if $0 == Tensor ( 0 ) { $0 = Tensor ( zeros: shape) }
255
+ }
256
+ return Self ( cell: cell, hidden: hidden)
257
+ }
231
258
}
232
259
233
260
/// Returns a zero-valued state with shape compatible with the provided input.
@@ -455,13 +482,25 @@ public protocol Mergeable: Differentiable, AdditiveArithmetic {
455
482
/// `Mergeable` (SR-13229).
456
483
@differentiable
457
484
static func sum( _ lhs: Self , _ rhs: Self ) -> Self
485
+
486
+ /// Averages two values.
487
+ @differentiable
488
+ static func average( _ lhs: Self , _ rhs: Self ) -> Self
489
+
490
+ /// Multiplies two values.
491
+ @differentiable
492
+ static func multiply( _ lhs: Self , _ rhs: Self ) -> Self
493
+
494
+ /// Stack two values.
495
+ @differentiable
496
+ static func stack( _ lhs: Self , _ rhs: Self ) -> Self
458
497
}
459
498
460
499
extension Tensor: Mergeable where Scalar: TensorFlowFloatingPoint {
461
500
/// Concatenates two tensors along last axis.
462
501
@differentiable
463
502
public static func concatenate( _ lhs: Tensor , _ rhs: Tensor ) -> Tensor {
464
- // TODO: Remove workaround after https://github.com/tensorflow/swift-apis/issues/1087 is fixed .
503
+ // TODO(TF-1005) : Remove workaround for differenting concatenated .
465
504
let concat = lhs. concatenated ( with: rhs, alongAxis: - 1 )
466
505
return concat. withDerivative { [ shape = concat. shape] in
467
506
if $0 == Tensor ( 0 ) { $0 = Tensor ( zeros: shape) }
@@ -473,6 +512,28 @@ extension Tensor: Mergeable where Scalar: TensorFlowFloatingPoint {
473
512
public static func sum( _ lhs: Tensor , _ rhs: Tensor ) -> Tensor {
474
513
lhs + rhs
475
514
}
515
+
516
+ /// Averages two values.
517
+ @differentiable
518
+ public static func average( _ lhs: Tensor , _ rhs: Tensor ) -> Tensor {
519
+ ( lhs + rhs) / 2
520
+ }
521
+
522
+ /// Multiplies two values.
523
+ @differentiable
524
+ public static func multiply( _ lhs: Tensor , _ rhs: Tensor ) -> Tensor {
525
+ lhs * rhs
526
+ }
527
+
528
+ /// Stack two values.
529
+ @differentiable
530
+ public static func stack( _ lhs: Tensor , _ rhs: Tensor ) -> Tensor {
531
+ // TODO(TF-1005): Remove workaround for differenting stacking.
532
+ let stack = Tensor ( stacking: [ lhs, rhs] )
533
+ return stack. withDerivative { [ shape = stack. shape] in
534
+ if $0 == Tensor ( 0 ) { $0 = Tensor ( zeros: shape) }
535
+ }
536
+ }
476
537
}
477
538
478
539
/// Concatenates two values.
@@ -493,6 +554,33 @@ public func sum<T: Mergeable>(
493
554
T . sum ( first, second)
494
555
}
495
556
557
+ /// Averages two values.
558
+ @differentiable
559
+ public func average< T: Mergeable > (
560
+ _ first: T ,
561
+ _ second: T
562
+ ) -> T {
563
+ T . average ( first, second)
564
+ }
565
+
566
+ /// Multiplies two values.
567
+ @differentiable
568
+ public func multiply< T: Mergeable > (
569
+ _ first: T ,
570
+ _ second: T
571
+ ) -> T {
572
+ T . multiply ( first, second)
573
+ }
574
+
575
+ /// Stack two values.
576
+ @differentiable
577
+ public func stack< T: Mergeable > (
578
+ _ first: T ,
579
+ _ second: T
580
+ ) -> T {
581
+ T . stack ( first, second)
582
+ }
583
+
496
584
public struct BidirectionalRecurrentLayer < Cell: RecurrentLayerCell > : Layer
497
585
where Cell. TimeStepOutput: Mergeable {
498
586
public typealias Input = [ Cell . TimeStepInput ]
0 commit comments