Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit c3f38de

Browse files
authored
Add transposedConv2D (#596) (#672)
Rename `conv2DBackpropInput` to `transposedConv2D` and make it public. `transposedConv2D` is consistent with TensorFlow and PyTorch.
1 parent 7d0691b commit c3f38de

File tree

3 files changed

+75
-5
lines changed

3 files changed

+75
-5
lines changed

Sources/TensorFlow/Layers/Convolutional.swift

+2-2
Original file line numberDiff line numberDiff line change
@@ -437,7 +437,7 @@ public struct TransposedConv1D<Scalar: TensorFlowFloatingPoint>: Layer {
437437
stride + (filter.shape[0] * paddingIndex)
438438
let c = filter.shape[2]
439439
let newShape = Tensor<Int32>([Int32(batchSize), 1, Int32(w), Int32(c)])
440-
let conv = conv2DBackpropInput(
440+
let conv = transposedConv2D(
441441
input.expandingShape(at: 1),
442442
shape: newShape,
443443
filter: filter.expandingShape(at: 0),
@@ -541,7 +541,7 @@ public struct TransposedConv2D<Scalar: TensorFlowFloatingPoint>: Layer {
541541
strides.1 + (filter.shape[1] * paddingIndex)
542542
let c = filter.shape[2]
543543
let newShape = Tensor<Int32>([Int32(batchSize), Int32(h), Int32(w), Int32(c)])
544-
let conv = conv2DBackpropInput(
544+
let conv = transposedConv2D(
545545
input,
546546
shape: newShape,
547547
filter: filter,

Sources/TensorFlow/Operators/NN.swift

+27-1
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,32 @@ func _vjpConv2D<Scalar: TensorFlowFloatingPoint>(
149149
})
150150
}
151151

152+
/// Returns a 2-D transposed convolution with the specified input, filter, strides, and padding.
153+
///
154+
/// - Parameters:
155+
/// - input: The input.
156+
/// - shape: The output shape of the deconvolution operation.
157+
/// - filter: The convolution filter.
158+
/// - strides: The strides of the sliding filter for each dimension of the input.
159+
/// - padding: The padding for the operation
160+
/// - dilations: The dilation factor for each dimension of the input.
161+
/// - Precondition: `input` must have rank `4`.
162+
/// - Precondition: `filter` must have rank 4.
163+
@differentiable(wrt: (input, filter))
164+
public func transposedConv2D<Scalar: TensorFlowFloatingPoint>(
165+
_ input: Tensor<Scalar>,
166+
shape: Tensor<Int32>,
167+
filter: Tensor<Scalar>,
168+
strides: (Int, Int, Int, Int) = (1, 1, 1, 1),
169+
padding: Padding = .valid,
170+
dilations: (Int, Int, Int, Int) = (1, 1, 1, 1)
171+
) -> Tensor<Scalar> {
172+
precondition(input.shape.rank == 4, "The input must have rank 4.")
173+
precondition(filter.shape.rank == 4, "The filter must have rank 4.")
174+
return conv2DBackpropInput(input, shape: shape, filter: filter,
175+
strides: strides, padding: padding, dilations: dilations)
176+
}
177+
152178
/// TensorFlow builtin conv2d gradient helper for the input.
153179
@differentiable(wrt: (x, filter))
154180
@usableFromInline
@@ -170,8 +196,8 @@ func conv2DBackpropInput<Scalar: TensorFlowFloatingPoint>(
170196
dilations: [Int32(dilations.0), Int32(dilations.1), Int32(dilations.2), Int32(dilations.3)])
171197
}
172198

173-
@usableFromInline
174199
@derivative(of: conv2DBackpropInput)
200+
@usableFromInline
175201
func _vjpConv2DBackpropInput<Scalar: TensorFlowFloatingPoint>(
176202
_ x: Tensor<Scalar>,
177203
_ shape: Tensor<Int32>,

Tests/TensorFlowTests/LayerTests.swift

+46-2
Original file line numberDiff line numberDiff line change
@@ -360,8 +360,51 @@ final class LayerTests: XCTestCase {
360360
let expectedNoBias = Tensor<Float>(shape: [1, 4, 2, 1],
361361
scalars: [0, 4, 4, 20, 16, 56, 40, 104])
362362
XCTAssertEqual(outputNoBias, expectedNoBias)
363-
}
364-
363+
}
364+
365+
func testTransposedConv2DGradient() {
366+
let filter = Tensor(shape: [3, 3, 2, 4], scalars: (0..<72).map(Float.init))
367+
let bias = Tensor<Float>(zeros: [2])
368+
let layer = TransposedConv2D<Float>(filter: filter,
369+
bias: bias,
370+
activation: identity,
371+
strides: (2, 2),
372+
padding: .same)
373+
let input = Tensor(shape: [2, 2, 2, 4], scalars: (0..<32).map(Float.init))
374+
let grads = gradient( at: input, layer) { $1($0).sum() }
375+
// The expected value of the gradient was computed using the following Python code:
376+
// ```
377+
// import tensorflow as tf
378+
// x = tf.reshape(tf.range(32, dtype=tf.float32), [2, 2, 2, 4])
379+
// filter = tf.reshape(tf.range(72, dtype=tf.float32), [3, 3, 2, 4])
380+
// bias = tf.zeros([2])
381+
// with tf.GradientTape() as tape:
382+
// tape.watch([x, filter, bias])
383+
// y = tf.math.reduce_sum(tf.nn.conv2d_transpose(input=x,
384+
// filters=filter,
385+
// output_shape=[2, 4, 4, 2],
386+
// strides=[1, 2, 2, 1],
387+
// data_format="NHWC",
388+
// padding="SAME") + bias)
389+
// print(tape.gradient(y, [x, filter, bias]))
390+
// ```
391+
XCTAssertEqual(grads.0,
392+
[[[[612, 630, 648, 666], [360, 372, 384, 396]],
393+
[[264, 276, 288, 300], [144, 152, 160, 168]]],
394+
[[[612, 630, 648, 666], [360, 372, 384, 396]],
395+
[[264, 276, 288, 300], [144, 152, 160, 168]]]])
396+
XCTAssertEqual(grads.1.filter,
397+
[[[[112, 120, 128, 136], [112, 120, 128, 136]],
398+
[[112, 120, 128, 136], [112, 120, 128, 136]],
399+
[[ 48, 52, 56, 60], [ 48, 52, 56, 60]]],
400+
[[[112, 120, 128, 136], [112, 120, 128, 136]],
401+
[[112, 120, 128, 136], [112, 120, 128, 136]],
402+
[[ 48, 52, 56, 60], [ 48, 52, 56, 60]]],
403+
[[[ 40, 44, 48, 52], [ 40, 44, 48, 52]],
404+
[[ 40, 44, 48, 52], [ 40, 44, 48, 52]],
405+
[[ 16, 18, 20, 22], [ 16, 18, 20, 22]]]])
406+
XCTAssertEqual(grads.1.bias, [32, 32])
407+
}
365408

366409
func testTransposedConv3D() {
367410
let filter = Tensor(shape: [2, 2, 2, 1, 1], scalars: (0..<8).map(Float.init))
@@ -1615,6 +1658,7 @@ final class LayerTests: XCTestCase {
16151658
("testConv3DGradient", testConv3DGradient),
16161659
("testTransposedConv1D", testTransposedConv1D),
16171660
("testTransposedConv2D", testTransposedConv2D),
1661+
("testTransposedConv2DGradient", testTransposedConv2DGradient),
16181662
("testTransposedConv3D", testTransposedConv3D),
16191663
("testDepthwiseConv2D", testDepthwiseConv2D),
16201664
("testDepthwiseConv2DGradient", testDepthwiseConv2DGradient),

0 commit comments

Comments (0)