@@ -66,6 +66,36 @@ extension LazyTensorOperation {
66
66
}
67
67
}
68
68
69
+ // Returns the `CTensor`, selectively materializing it if needed.
70
+ func cTensor( handle: LazyTensorHandle ) -> CTensor ? {
71
+ switch handle. handle {
72
+ case . concrete( let h, _) :
73
+ let cTensor = TFE_TensorHandleResolve ( h. _cTensorHandle, status)
74
+ checkOk ( status)
75
+ return cTensor
76
+ case . symbolic( let op, _, _) :
77
+ // TODO(https://bugs.swift.org/browse/TF-765): "Pack" is used
78
+ // for creating tensors from array literals. So, allow
79
+ // materialization for 'Pack' so that we can get the shape for
80
+ // array literals. We should revisit this heuristic.
81
+ if op. name != " Pack " { return nil }
82
+ let cTensor = TFE_TensorHandleResolve ( handle. _cTensorHandle, status)
83
+ checkOk ( status)
84
+ return cTensor
85
+ }
86
+ }
87
+
88
+ // Create `inputTensors` consisting of *only* materialized inputs.
89
+ var inputTensors : [ CTensor ? ] = [ ]
90
+ for input in inputs {
91
+ switch input {
92
+ case . single( let v) :
93
+ inputTensors. append ( cTensor ( handle: v) )
94
+ case . list( let values) :
95
+ inputTensors. append ( contentsOf: values. lazy. map { cTensor ( handle: $0) } )
96
+ }
97
+ }
98
+
69
99
// This will be filled in by `TFE_InferShapes` and should be freed later.
70
100
var outputShapeListPtr = UnsafeMutablePointer < TF_ShapeAndTypeList > ( nil )
71
101
defer { TF_DeleteShapeAndTypeList ( outputShapeListPtr) }
@@ -76,17 +106,18 @@ extension LazyTensorOperation {
76
106
TF_DeleteStatus ( tfeOp. status)
77
107
}
78
108
79
- TFE_InferShapes (
80
- tfeOp. op,
81
- /*input_shapes*/ inputShapeList,
82
- /*input_tensors*/ nil ,
83
- /*input_tensors_as_shapes*/ nil ,
84
- /*input_resource_shapes_and_types*/ nil ,
85
- /*output_shapes*/ & outputShapeListPtr,
86
- /*output_resource_shapes_and_types*/ nil ,
87
- status)
88
-
89
- checkOk ( status)
109
+ inputTensors. withUnsafeMutableBufferPointer { buffer in
110
+ TFE_InferShapes (
111
+ tfeOp. op,
112
+ /*input_shapes*/ inputShapeList,
113
+ /*input_tensors*/ buffer. baseAddress!,
114
+ /*input_tensors_as_shapes*/ nil ,
115
+ /*input_resource_shapes_and_types*/ nil ,
116
+ /*output_shapes*/ & outputShapeListPtr,
117
+ /*output_resource_shapes_and_types*/ nil ,
118
+ status)
119
+ checkOk ( status)
120
+ }
90
121
91
122
precondition ( outputShapeListPtr != nil , " TFE_InferShapes returned nil for output shapes " )
92
123
let outputShapeList = outputShapeListPtr!. pointee
0 commit comments