Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 6f5b962

Browse files
tanmayb123rxwei
authored andcommitted
Add flatten & reshape layers (#32)
1 parent fc55d91 commit 6f5b962

File tree

1 file changed

+30
-4
lines changed

1 file changed

+30
-4
lines changed

Sources/DeepLearning/Layer.swift

+30-4
Original file line numberDiff line numberDiff line change
@@ -605,10 +605,36 @@ public struct UpSampling2D<Scalar: TensorFlowFloatingPoint>: Layer {
605605
public func applied(to input: Tensor<Scalar>, in _: Context) -> Tensor<Scalar> {
606606
let shape = input.shape
607607
let (batchSize, height, width, channels) = (shape[0], shape[1], shape[2], shape[3])
608-
let reshapeSize = Tensor<Int32>([batchSize, height, 1, width, 1, channels])
609608
let scaleOnes = Tensor<Scalar>(ones: [1, 1, size, 1, size, 1])
610-
let upSampling = input.reshaped(toShape: reshapeSize) * scaleOnes
611-
let upSampledShape = Tensor<Int32>([batchSize, height * size, width * size, channels])
612-
return upSampling.reshaped(toShape: upSampledShape)
609+
let upSampling = input.reshaped(to: [batchSize, height, 1, width, 1, channels]) * scaleOnes
610+
return upSampling.reshaped(to: [batchSize, height * size, width * size, channels])
611+
}
612+
}
613+
614+
@_fixed_layout
615+
public struct Flatten<Scalar: TensorFlowFloatingPoint>: Layer {
616+
@differentiable
617+
public func applied(to input: Tensor<Scalar>, in _: Context) -> Tensor<Scalar> {
618+
let batchSize = input.shape[0]
619+
let remaining = input.shape[1..<input.rank].contiguousSize
620+
return input.reshaped(to: [batchSize, remaining])
621+
}
622+
}
623+
624+
@_fixed_layout
625+
public struct Reshape<Scalar: TensorFlowFloatingPoint>: Layer {
626+
@noDerivative public let shape: Tensor<Int32>
627+
628+
public init(shape: Tensor<Int32>) {
629+
self.shape = shape
630+
}
631+
632+
public init(_ shape: TensorShape) {
633+
self.init(shape: Tensor(shape.dimensions))
634+
}
635+
636+
@differentiable
637+
public func applied(to input: Tensor<Scalar>, in _: Context) -> Tensor<Scalar> {
638+
return input.reshaped(toShape: shape)
613639
}
614640
}

0 commit comments

Comments
 (0)