This repository was archived by the owner on Jul 1, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 137
/
Copy pathLazyTensorTestHelper.swift
92 lines (80 loc) · 3.1 KB
/
LazyTensorTestHelper.swift
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import XCTest
@testable import TensorFlow
/// Base class for test cases that run with the lazy tensor runtime enabled.
///
/// The class-level `setUp()` does the following:
/// - Turns on `_ThreadLocalState.useLazyTensor`.
/// - Turns off `LazyTensorContext.local.shouldPromoteConstants`.
///
/// The class-level `tearDown()` restores both settings to the values they had
/// before `setUp()` ran.
class LazyTensorTestCase: XCTestCase {
  /// Value of `LazyTensorContext.local.shouldPromoteConstants` saved by
  /// `setUp()` and restored by `tearDown()`.
  static var shouldPromoteConstants = true

  /// Value of `_ThreadLocalState.useLazyTensor` saved by `setUp()` and
  /// restored by `tearDown()`. It defaults to `false`, which matches what the
  /// original teardown always wrote.
  private static var savedUseLazyTensor = false

  override class func setUp() {
    super.setUp()
    // Save the current settings so tearDown() can put them back rather than
    // forcing fixed values onto later tests.
    savedUseLazyTensor = _ThreadLocalState.useLazyTensor
    _ThreadLocalState.useLazyTensor = true
    shouldPromoteConstants = LazyTensorContext.local.shouldPromoteConstants
    LazyTensorContext.local.shouldPromoteConstants = false
  }

  override class func tearDown() {
    // Undo this class's changes in reverse order, then let the superclass
    // run its teardown last. This mirrors setUp(), which calls super first.
    LazyTensorContext.local.shouldPromoteConstants = shouldPromoteConstants
    _ThreadLocalState.useLazyTensor = savedUseLazyTensor
    super.tearDown()
  }
}
/// A tensor-like type whose storage may be backed by a `LazyTensorHandle`.
protocol _LazyTensorCompatible {
  /// The underlying `LazyTensorHandle`, or `nil` if the value is not backed
  /// by one.
  var _lazyTensor: LazyTensorHandle? { get }

  /// Returns `Self` wrapping a concrete `LazyTensorHandle`.
  /// (Triggers materialization if needed.)
  var _concreteLazyTensor: Self { get }

  /// Similar to `_concreteLazyTensor`, with one additional constraint: the
  /// underlying concrete `LazyTensorHandle` should be marked for promotion to
  /// an input when it is used in an extracted trace.
  var _concreteInputLazyTensor: Self { get }
}
extension _AnyTensorHandle {
  /// This handle as a `LazyTensorHandle`, or `nil` if it is any other kind
  /// of handle.
  var _lazyTensor: LazyTensorHandle? { self as? LazyTensorHandle }

  // NOTE(review): for a handle that is already lazy, reading
  // `_tfeTensorHandle` presumably forces materialization. Confirm this
  // against the `LazyTensorHandle` implementation.
  /// A new `LazyTensorHandle` built from this handle's `_tfeTensorHandle`.
  var _concreteLazyTensor: LazyTensorHandle { LazyTensorHandle(_tfeTensorHandle) }
}
extension TensorHandle: _LazyTensorCompatible {
  /// A `TensorHandle` whose wrapped handle is the concrete
  /// `LazyTensorHandle` derived from this handle's underlying handle.
  public var _concreteLazyTensor: TensorHandle {
    get {
      return TensorHandle(handle: self.handle._concreteLazyTensor)
    }
  }

  /// The `LazyTensorHandle` this handle wraps, if the wrapped handle is lazy.
  var _lazyTensor: LazyTensorHandle? {
    get {
      return self.handle._lazyTensor
    }
  }
}
extension Tensor: _LazyTensorCompatible {
  /// The lazy handle backing this tensor's `handle`, or `nil` if there is
  /// none.
  var _lazyTensor: LazyTensorHandle? {
    return self.handle._lazyTensor
  }

  /// A tensor built from the concrete-lazy form of this tensor's `handle`.
  public var _concreteLazyTensor: Tensor {
    return Tensor(handle: self.handle._concreteLazyTensor)
  }
}
extension StringTensor: _LazyTensorCompatible {
  /// A string tensor built from the concrete-lazy form of this tensor's
  /// `handle`.
  public var _concreteLazyTensor: StringTensor {
    return StringTensor(handle: self.handle._concreteLazyTensor)
  }

  /// The lazy handle backing this string tensor's `handle`, or `nil` if
  /// there is none.
  var _lazyTensor: LazyTensorHandle? {
    return self.handle._lazyTensor
  }
}
extension VariantHandle: _LazyTensorCompatible {
  /// The lazy handle behind this variant's `handle`, when there is one.
  var _lazyTensor: LazyTensorHandle? {
    get { return self.handle._lazyTensor }
  }

  /// A variant handle rebuilt around the concrete-lazy form of `handle`.
  public var _concreteLazyTensor: VariantHandle {
    get { return VariantHandle(handle: self.handle._concreteLazyTensor) }
  }
}
extension ResourceHandle: _LazyTensorCompatible {
  /// A resource handle rebuilt around the concrete-lazy form of `handle`.
  public var _concreteLazyTensor: ResourceHandle {
    get { return ResourceHandle(handle: self.handle._concreteLazyTensor) }
  }

  /// The lazy handle behind this resource's `handle`, when there is one.
  var _lazyTensor: LazyTensorHandle? {
    get { return self.handle._lazyTensor }
  }
}