Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit d1decbe

Browse files
authored
Adds a mechanism to manually promote constants to inputs in Lazy Tensor. (#299)
1 parent 12d7030 commit d1decbe

File tree

4 files changed

+217
-20
lines changed

4 files changed

+217
-20
lines changed

Sources/TensorFlow/Core/LazyTensorOperation.swift

+51
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,57 @@ class LazyTensor: _AnyTensorHandle {
130130
static var _materializationCallback: (String) -> () = { _ in }
131131
}
132132

133+
extension _AnyTensorHandle {
134+
/// Returns a concrete `LazyTensor` with an additional constraint that the
135+
/// underlying concrete `LazyTensor` should be marked to be promoted as an
136+
/// input when used in an extracted trace. This provides a **temporary**
137+
/// mechanism to promote a concrete lazy tensor to an input in extracted
138+
/// traces. (Note that this may trigger materialization.)
139+
var _concreteInputLazyTensor: LazyTensor {
140+
LazyTensor(_materialized: self._tfeTensorHandle)
141+
}
142+
}
143+
144+
extension TensorHandle {
145+
/// Returns `Self` that wraps `_concreteInputLazyTensor` of the underlying
146+
/// `_AnyTensorHandle`
147+
public var _concreteInputLazyTensor: TensorHandle {
148+
TensorHandle(handle: handle._concreteInputLazyTensor)
149+
}
150+
}
151+
152+
extension Tensor {
153+
/// Returns `Self` that wraps `_concreteInputLazyTensor` of the underlying
154+
/// `_AnyTensorHandle`
155+
public var _concreteInputLazyTensor: Tensor {
156+
Tensor(handle: handle._concreteInputLazyTensor)
157+
}
158+
}
159+
160+
extension StringTensor {
161+
/// Returns `Self` that wraps `_concreteInputLazyTensor` of the underlying
162+
/// `_AnyTensorHandle`
163+
public var _concreteInputLazyTensor: StringTensor {
164+
StringTensor(handle: handle._concreteInputLazyTensor)
165+
}
166+
}
167+
168+
extension VariantHandle {
169+
/// Returns `Self` that wraps `_concreteInputLazyTensor` of the underlying
170+
/// `_AnyTensorHandle`
171+
public var _concreteInputLazyTensor: VariantHandle {
172+
VariantHandle(handle: handle._concreteInputLazyTensor)
173+
}
174+
}
175+
176+
extension ResourceHandle {
177+
/// Returns `Self` that wraps `_concreteInputLazyTensor` of the underlying
178+
/// `_AnyTensorHandle`
179+
public var _concreteInputLazyTensor: ResourceHandle {
180+
ResourceHandle(handle: handle._concreteInputLazyTensor)
181+
}
182+
}
183+
133184
class LazyTensorOperation: TensorOperation {
134185
typealias TensorValueHandle = LazyTensor
135186

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
@testable import TensorFlow
16+
17+
protocol _LazyTensorCompatible {
18+
/// The underlying `LazyTensor` (if any).
19+
var _lazyTensor: LazyTensor? { get }
20+
21+
/// Returns `Self` that wraps a concrete `LazyTensor`.
22+
/// (Triggers materialization if needed.)
23+
var _concreteLazyTensor: Self { get }
24+
25+
/// Similar to the `concreteLazyTensor` with an additional constraint that
26+
/// the underlying concrete `LazyTensor` should be marked to be promoted as
27+
/// an input when used in an extracted trace.
28+
var _concreteInputLazyTensor: Self { get }
29+
}
30+
31+
extension _AnyTensorHandle {
32+
var _lazyTensor: LazyTensor? {
33+
if let handle = self as? LazyTensor {
34+
return handle
35+
} else {
36+
return nil
37+
}
38+
}
39+
var _concreteLazyTensor: LazyTensor { LazyTensor(self._tfeTensorHandle) }
40+
}
41+
42+
extension TensorHandle: _LazyTensorCompatible {
43+
var _lazyTensor: LazyTensor? { handle._lazyTensor }
44+
public var _concreteLazyTensor: TensorHandle {
45+
TensorHandle(handle: handle._concreteLazyTensor)
46+
}
47+
}
48+
49+
extension Tensor: _LazyTensorCompatible {
50+
var _lazyTensor: LazyTensor? { handle._lazyTensor }
51+
public var _concreteLazyTensor: Tensor {
52+
Tensor(handle: handle._concreteLazyTensor)
53+
}
54+
}
55+
56+
extension StringTensor: _LazyTensorCompatible {
57+
var _lazyTensor: LazyTensor? { handle._lazyTensor }
58+
public var _concreteLazyTensor: StringTensor {
59+
StringTensor(handle: handle._concreteLazyTensor)
60+
}
61+
}
62+
63+
extension VariantHandle: _LazyTensorCompatible {
64+
var _lazyTensor: LazyTensor? { handle._lazyTensor }
65+
public var _concreteLazyTensor: VariantHandle {
66+
VariantHandle(handle: handle._concreteLazyTensor)
67+
}
68+
}
69+
70+
extension ResourceHandle: _LazyTensorCompatible {
71+
var _lazyTensor: LazyTensor? { handle._lazyTensor }
72+
public var _concreteLazyTensor: ResourceHandle {
73+
ResourceHandle(handle: handle._concreteLazyTensor)
74+
}
75+
}

Tests/TensorFlowTests/LazyTensorTests.swift

+46-8
Original file line numberDiff line numberDiff line change
@@ -75,13 +75,6 @@ final class LazyTensorTests: XCTestCase {
7575
XCTAssertEqual(expectedAllOps, actualAllOps)
7676
}
7777

78-
func isSymbolic(_ t: LazyTensor) -> Bool {
79-
switch t.handle {
80-
case .symbolic(_): return true
81-
case .concrete(_): return false
82-
}
83-
}
84-
8578
let op0 = LazyTensorOperation(
8679
_id: "0", name: "IdentityN", outputCount: 2)
8780
let op1 = LazyTensorOperation(
@@ -112,9 +105,54 @@ final class LazyTensorTests: XCTestCase {
112105
XCTAssertTrue(isSymbolic(t0))
113106
}
114107

108+
private func checkConversions<T: _LazyTensorCompatible>(_ x: T) {
109+
let concreteLazyX = x._concreteLazyTensor
110+
let concreteInputLazyX = x._concreteInputLazyTensor
111+
XCTAssertFalse(isSymbolic(concreteLazyX._lazyTensor))
112+
XCTAssertFalse(isSymbolic(concreteInputLazyX._lazyTensor))
113+
XCTAssertFalse(isMaterializedConcrete(concreteLazyX._lazyTensor))
114+
XCTAssertTrue(isMaterializedConcrete(concreteInputLazyX._lazyTensor))
115+
}
116+
117+
func testTensorToLazyTensorConversions() {
118+
checkConversions(Tensor<Float>(10.0))
119+
checkConversions(StringTensor("Hello!"))
120+
121+
// ResourceHandle and VariantHandle conversions.
122+
let elements1: Tensor<Int32> = [0, 1, 2]
123+
let elements2: Tensor<Int32> = [10, 11, 12]
124+
let outputTypes = [Int32.tensorFlowDataType, Int32.tensorFlowDataType]
125+
let outputShapes: [TensorShape?] = [nil, nil]
126+
let dataset: VariantHandle = Raw.tensorSliceDataset(
127+
components: [elements1, elements2],
128+
outputShapes: outputShapes
129+
)
130+
let iterator: ResourceHandle = Raw.iteratorV2(sharedName: "blah",
131+
container: "earth", outputTypes: outputTypes, outputShapes: outputShapes
132+
)
133+
checkConversions(dataset)
134+
checkConversions(iterator)
135+
}
136+
137+
private func isSymbolic(_ t: LazyTensor?) -> Bool {
138+
guard let t = t else { return false }
139+
switch t.handle {
140+
case .symbolic(_): return true
141+
case .concrete(_): return false
142+
}
143+
}
144+
145+
private func isMaterializedConcrete(_ t: LazyTensor?) -> Bool {
146+
guard let t = t else { return false }
147+
switch t.handle {
148+
case .symbolic(_): return true
149+
case .concrete(_, let isMaterialized): return isMaterialized
150+
}
151+
}
152+
115153
static var allTests = [
116154
("testConstructions", testConstructions),
117155
("testLivenessTracking", testLivenessTracking),
156+
("testTensorToLazyTensorConversions", testTensorToLazyTensorConversions)
118157
]
119-
120158
}

Tests/TensorFlowTests/LazyTensorTraceTests.swift

+45-12
Original file line numberDiff line numberDiff line change
@@ -107,15 +107,13 @@ final class LazyTensorTraceTests: XCTestCase {
107107
""")
108108
}
109109

110-
func testConstPromotion() {
110+
func testManualConstPromotion() {
111111
let a = Tensor<Float>(10.0)
112112
let b = Tensor<Float>(2.0)
113-
let concreteA = a.handle.handle._tfeTensorHandle
114113

115-
let lazyHandle = LazyTensor(concreteA)
116-
let lazyA = Tensor(handle: TensorHandle<Float>(handle: lazyHandle))
117-
// Since `lazyA` is not marked as a materialized concrete
118-
// tensor, this will be burnt into the trace as a constant.
114+
// Since `lazyA` is not marked as an input, this will
115+
// be burnt into the trace as a constant.
116+
let lazyA = a._concreteLazyTensor
119117
let w1 = lazyA * b
120118
let w1Trace = lazyTrace(w1)!
121119
XCTAssertEqual(w1Trace.description,
@@ -127,12 +125,11 @@ final class LazyTensorTraceTests: XCTestCase {
127125
}
128126
""")
129127
XCTAssertEqual(w1Trace.inputValues.count, 0)
130-
let materializedHandle = LazyTensor(_materialized: concreteA)
131-
let materializedLazyA = Tensor(
132-
handle: TensorHandle<Float>(handle: materializedHandle))
133-
// Since `materializedLazyA` is marked as a materialized concrete
134-
// tensor, this will be promoted to an input for the trace.
135-
let w2 = materializedLazyA * b
128+
129+
// Since `lazyInputA` is marked as an input, this will
130+
// be promoted to an input for the trace.
131+
let inputLazyA = a._concreteInputLazyTensor
132+
let w2 = inputLazyA * b
136133
let w2Trace = lazyTrace(w2)!
137134
XCTAssertEqual(w2Trace.description,
138135
"""
@@ -146,6 +143,41 @@ final class LazyTensorTraceTests: XCTestCase {
146143
XCTAssertEqual(w2Trace.inputValues[0].valueDescription, "10.0")
147144
}
148145

146+
func testConstPromotion() {
147+
let a = Tensor<Float>(1.0)
148+
let b = Tensor<Float>(2.0)
149+
let c = Tensor<Float>(3.0)
150+
let y = a + b
151+
let z = y * c
152+
153+
XCTAssertEqual(
154+
lazyTrace(y)!.description,
155+
"""
156+
lazyTrace_3() -> (%2) {
157+
%0 = Const[dtype: float, value: 1.0]()
158+
%1 = Const[dtype: float, value: 2.0]()
159+
%2 = Add[T: float](%0, %1)
160+
}
161+
""")
162+
XCTAssertEqual(y.scalarized(), 3.0)
163+
164+
/// Now that `y` is materialized and a constant,
165+
/// the trace for `z` will use that as a constant.
166+
let zTrace = lazyTrace(z)!
167+
XCTAssertEqual(
168+
zTrace.description,
169+
"""
170+
lazyTrace_3(%0: float) -> (%2) {
171+
%1 = Const[dtype: float, value: 3.0]()
172+
%2 = Mul[T: float](%0, %1)
173+
}
174+
""")
175+
// Make sure that the promoted constants are gathered as `inputValues`.
176+
XCTAssertEqual(zTrace.inputValues.count, 1)
177+
XCTAssertEqual(zTrace.inputValues[0].valueDescription, "3.0")
178+
XCTAssertEqual(z.scalarized(), 9.0)
179+
}
180+
149181
private func lazyTrace<T: TensorFlowScalar>(
150182
_ input: Tensor<T>
151183
) -> LazyTensorTrace? {
@@ -165,6 +197,7 @@ final class LazyTensorTraceTests: XCTestCase {
165197
("testSingleLiveTensor", testSingleLiveTensor),
166198
("testMultipleLiveTensors", testMultipleLiveTensors),
167199
("testSimpleControlFlow", testSimpleControlFlow),
200+
("testManualConstPromotion", testManualConstPromotion),
168201
("testConstPromotion", testConstPromotion)
169202
]
170203
}

0 commit comments

Comments
 (0)