@@ -605,10 +605,36 @@ public struct UpSampling2D<Scalar: TensorFlowFloatingPoint>: Layer {
605
605
public func applied( to input: Tensor < Scalar > , in _: Context ) -> Tensor < Scalar > {
606
606
let shape = input. shape
607
607
let ( batchSize, height, width, channels) = ( shape [ 0 ] , shape [ 1 ] , shape [ 2 ] , shape [ 3 ] )
608
- let reshapeSize = Tensor < Int32 > ( [ batchSize, height, 1 , width, 1 , channels] )
609
608
let scaleOnes = Tensor < Scalar > ( ones: [ 1 , 1 , size, 1 , size, 1 ] )
610
- let upSampling = input. reshaped ( toShape: reshapeSize) * scaleOnes
611
- let upSampledShape = Tensor < Int32 > ( [ batchSize, height * size, width * size, channels] )
612
- return upSampling. reshaped ( toShape: upSampledShape)
609
+ let upSampling = input. reshaped ( to: [ batchSize, height, 1 , width, 1 , channels] ) * scaleOnes
610
+ return upSampling. reshaped ( to: [ batchSize, height * size, width * size, channels] )
611
+ }
612
+ }
613
+
614
+ @_fixed_layout
615
+ public struct Flatten < Scalar: TensorFlowFloatingPoint > : Layer {
616
+ @differentiable
617
+ public func applied( to input: Tensor < Scalar > , in _: Context ) -> Tensor < Scalar > {
618
+ let batchSize = input. shape [ 0 ]
619
+ let remaining = input. shape [ 1 ..< input. rank] . contiguousSize
620
+ return input. reshaped ( to: [ batchSize, remaining] )
621
+ }
622
+ }
623
+
624
+ @_fixed_layout
625
+ public struct Reshape < Scalar: TensorFlowFloatingPoint > : Layer {
626
+ @noDerivative public let shape : Tensor < Int32 >
627
+
628
+ public init ( shape: Tensor < Int32 > ) {
629
+ self . shape = shape
630
+ }
631
+
632
+ public init ( _ shape: TensorShape ) {
633
+ self . init ( shape: Tensor ( shape. dimensions) )
634
+ }
635
+
636
+ @differentiable
637
+ public func applied( to input: Tensor < Scalar > , in _: Context ) -> Tensor < Scalar > {
638
+ return input. reshaped ( toShape: shape)
613
639
}
614
640
}
0 commit comments