Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 0dbccf3

Browse files
authored
Handle input_tensors during shape inference. (#453)
1 parent 3629ddc commit 0dbccf3

File tree

2 files changed

+73
-12
lines changed

2 files changed

+73
-12
lines changed

Diff for: Sources/TensorFlow/Core/LazyTensorShapeInference.swift

+42-11
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,36 @@ extension LazyTensorOperation {
6666
}
6767
}
6868

69+
// Returns the `CTensor`, selectively materializing it if needed.
70+
func cTensor(handle: LazyTensorHandle) -> CTensor? {
71+
switch handle.handle {
72+
case .concrete(let h, _):
73+
let cTensor = TFE_TensorHandleResolve(h._cTensorHandle, status)
74+
checkOk(status)
75+
return cTensor
76+
case .symbolic(let op, _, _):
77+
// TODO(https://bugs.swift.org/browse/TF-765): "Pack" is used
78+
// for creating tensors from array literals. So, allow
79+
// materialization for 'Pack' so that we can get the shape for
80+
// array literals. We should revisit this heuristic.
81+
if op.name != "Pack" { return nil }
82+
let cTensor = TFE_TensorHandleResolve(handle._cTensorHandle, status)
83+
checkOk(status)
84+
return cTensor
85+
}
86+
}
87+
88+
// Create `inputTensors` consisting of *only* materialized inputs.
89+
var inputTensors: [CTensor?] = []
90+
for input in inputs {
91+
switch input {
92+
case .single(let v):
93+
inputTensors.append(cTensor(handle: v))
94+
case .list(let values):
95+
inputTensors.append(contentsOf: values.lazy.map { cTensor(handle: $0) } )
96+
}
97+
}
98+
6999
// This will be filled in by `TFE_InferShapes` and should be freed later.
70100
var outputShapeListPtr = UnsafeMutablePointer<TF_ShapeAndTypeList>(nil)
71101
defer { TF_DeleteShapeAndTypeList(outputShapeListPtr) }
@@ -76,17 +106,18 @@ extension LazyTensorOperation {
76106
TF_DeleteStatus(tfeOp.status)
77107
}
78108

79-
TFE_InferShapes(
80-
tfeOp.op,
81-
/*input_shapes*/ inputShapeList,
82-
/*input_tensors*/ nil,
83-
/*input_tensors_as_shapes*/ nil,
84-
/*input_resource_shapes_and_types*/ nil,
85-
/*output_shapes*/ &outputShapeListPtr,
86-
/*output_resource_shapes_and_types*/ nil,
87-
status)
88-
89-
checkOk(status)
109+
inputTensors.withUnsafeMutableBufferPointer { buffer in
110+
TFE_InferShapes(
111+
tfeOp.op,
112+
/*input_shapes*/ inputShapeList,
113+
/*input_tensors*/ buffer.baseAddress!,
114+
/*input_tensors_as_shapes*/ nil,
115+
/*input_resource_shapes_and_types*/ nil,
116+
/*output_shapes*/ &outputShapeListPtr,
117+
/*output_resource_shapes_and_types*/ nil,
118+
status)
119+
checkOk(status)
120+
}
90121

91122
precondition(outputShapeListPtr != nil, "TFE_InferShapes returned nil for output shapes")
92123
let outputShapeList = outputShapeListPtr!.pointee

Diff for: Tests/TensorFlowTests/LazyTensorShapeInferenceTests.swift

+31-1
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,38 @@ final class LazyTensorShapeInferenceTests: XCTestCase {
6464
XCTAssertTrue(xLazyTensorOperation.isMaterialized)
6565
}
6666

67+
/// Checks scenarios where shapes are computed from input tensors.
68+
func testShapeComputationsWithInputTensors() {
69+
let a = Tensor<Float>(shape: [3, 1], scalars: [1.0, 2.0, 3.0])
70+
let b = a.reshaped(toShape: [1, 3])
71+
72+
let bLazyTensorOperation = b._lazyTensor!.lazyTensorOperation!
73+
XCTAssertFalse(bLazyTensorOperation.isMaterialized)
74+
75+
let bShape = b.shape
76+
XCTAssertEqual(bShape.rank, 2)
77+
XCTAssertEqual(bShape.dimensions, [1, 3])
78+
XCTAssertFalse(bLazyTensorOperation.isMaterialized)
79+
80+
let c = Tensor<Float>(repeating: 5, shape: [4, 5, 6])
81+
let cLazyTensorOperation = c._lazyTensor!.lazyTensorOperation!
82+
XCTAssertFalse(cLazyTensorOperation.isMaterialized)
83+
84+
let cShape = c.shape
85+
XCTAssertEqual(cShape.rank, 3)
86+
XCTAssertEqual(cShape.dimensions, [4, 5, 6])
87+
XCTAssertFalse(cLazyTensorOperation.isMaterialized)
88+
89+
// Trigger materialization.
90+
let _ = b._rawTensorHandle
91+
let _ = c._rawTensorHandle
92+
XCTAssertTrue(bLazyTensorOperation.isMaterialized)
93+
XCTAssertTrue(cLazyTensorOperation.isMaterialized)
94+
}
95+
6796
static var allTests = [
68-
("testSimpleShapeComputations", testSimpleShapeComputations)
97+
("testSimpleShapeComputations", testSimpleShapeComputations),
98+
("testShapeComputationsWithInputTensors", testShapeComputationsWithInputTensors)
6999
]
70100
}
71101

0 commit comments

Comments
 (0)