Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 0fce911

Browse files
authoredAug 22, 2019
Added back a test for 'Tensor.init(stacking:alongAxis:)'. (#466)
1 parent 56df29b commit 0fce911

File tree

1 file changed

+12
-14
lines changed

1 file changed

+12
-14
lines changed
 

‎Tests/TensorFlowTests/TensorAutoDiffTests.swift

+12-14
Original file line numberDiff line numberDiff line change
@@ -153,18 +153,17 @@ final class TensorAutoDiffTests: XCTestCase {
153153
XCTAssertEqual(varianceGradAlongAxes(input), expected)
154154
}
155155

156-
// TODO: Uncomment once TF-653 is resolved.
157-
// func testTensorInitStacking() {
158-
// let a1 = Tensor<Float>([1, 2, 3, 4, 5])
159-
// let b1 = Tensor<Float>([6, 7, 8, 9, 10])
160-
// let a2 = Tensor<Float>([1, 1, 1, 1, 1])
161-
// let b2 = Tensor<Float>([1, 1, 1, 1, 1])
162-
// let grads = gradient(at: a2, b2) { a, b in
163-
// Tensor<Float>(stacking: [a1 * a, b1 * b], alongAxis: -1).sum()
164-
// }
165-
// XCTAssertEqual(a1, grads.0)
166-
// XCTAssertEqual(b1, grads.1)
167-
// }
156+
func testTensorInitStacking() {
157+
let a1 = Tensor<Float>([1, 2, 3, 4, 5])
158+
let b1 = Tensor<Float>([6, 7, 8, 9, 10])
159+
let a2 = Tensor<Float>([1, 1, 1, 1, 1])
160+
let b2 = Tensor<Float>([1, 1, 1, 1, 1])
161+
let grads = gradient(at: a2, b2) { a, b in
162+
Tensor<Float>(stacking: [a1 * a, b1 * b], alongAxis: -1).sum()
163+
}
164+
XCTAssertEqual(a1, grads.0)
165+
XCTAssertEqual(b1, grads.1)
166+
}
168167

169168
func testExpandingShape() {
170169
func f1(a: Tensor<Float>) -> Tensor<Float> { a.expandingShape(at: 0).squared() }
@@ -448,8 +447,7 @@ final class TensorAutoDiffTests: XCTestCase {
448447
("testSum", testSum),
449448
("testMean", testMean),
450449
("testVariance", testVariance),
451-
// TODO: Uncomment once TF-653 is resolved.
452-
// ("testTensorInitStacking", testTensorInitStacking),
450+
("testTensorInitStacking", testTensorInitStacking),
453451
("testExpandingShape", testExpandingShape),
454452
("testSqueezingShape", testSqueezingShape),
455453
("testReshapedBackprop", testReshapedBackprop),

0 commit comments

Comments
 (0)