@@ -153,18 +153,17 @@ final class TensorAutoDiffTests: XCTestCase {
153
153
XCTAssertEqual ( varianceGradAlongAxes ( input) , expected)
154
154
}
155
155
156
- // TODO: Uncomment once TF-653 is resolved.
157
- // func testTensorInitStacking() {
158
- // let a1 = Tensor<Float>([1, 2, 3, 4, 5])
159
- // let b1 = Tensor<Float>([6, 7, 8, 9, 10])
160
- // let a2 = Tensor<Float>([1, 1, 1, 1, 1])
161
- // let b2 = Tensor<Float>([1, 1, 1, 1, 1])
162
- // let grads = gradient(at: a2, b2) { a, b in
163
- // Tensor<Float>(stacking: [a1 * a, b1 * b], alongAxis: -1).sum()
164
- // }
165
- // XCTAssertEqual(a1, grads.0)
166
- // XCTAssertEqual(b1, grads.1)
167
- // }
156
+ func testTensorInitStacking( ) {
157
+ let a1 = Tensor < Float > ( [ 1 , 2 , 3 , 4 , 5 ] )
158
+ let b1 = Tensor < Float > ( [ 6 , 7 , 8 , 9 , 10 ] )
159
+ let a2 = Tensor < Float > ( [ 1 , 1 , 1 , 1 , 1 ] )
160
+ let b2 = Tensor < Float > ( [ 1 , 1 , 1 , 1 , 1 ] )
161
+ let grads = gradient ( at: a2, b2) { a, b in
162
+ Tensor < Float > ( stacking: [ a1 * a, b1 * b] , alongAxis: - 1 ) . sum ( )
163
+ }
164
+ XCTAssertEqual ( a1, grads. 0 )
165
+ XCTAssertEqual ( b1, grads. 1 )
166
+ }
168
167
169
168
func testExpandingShape( ) {
170
169
func f1( a: Tensor < Float > ) -> Tensor < Float > { a. expandingShape ( at: 0 ) . squared ( ) }
@@ -448,8 +447,7 @@ final class TensorAutoDiffTests: XCTestCase {
448
447
( " testSum " , testSum) ,
449
448
( " testMean " , testMean) ,
450
449
( " testVariance " , testVariance) ,
451
- // TODO: Uncomment once TF-653 is resolved.
452
- // ("testTensorInitStacking", testTensorInitStacking),
450
+ ( " testTensorInitStacking " , testTensorInitStacking) ,
453
451
( " testExpandingShape " , testExpandingShape) ,
454
452
( " testSqueezingShape " , testSqueezingShape) ,
455
453
( " testReshapedBackprop " , testReshapedBackprop) ,
0 commit comments