Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit 3b2108c

Browse files
authored
Track the device name in LazyTensorOperation (#262)
1 parent ec7f583 commit 3b2108c

File tree

2 files changed

+15
-1
lines changed

2 files changed

+15
-1
lines changed

Diff for: Sources/TensorFlow/Core/LazyTensorOperation.swift

+2
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ class LazyTensorOperation: TensorOperation {
162162
let outputCount: Int
163163
var inputs: [Input]
164164
var attributes: [String: Attribute]
165+
var deviceName: String?
165166
var outputs: [TFETensorHandle]?
166167
var id: String?
167168

@@ -203,6 +204,7 @@ class LazyTensorOperation: TensorOperation {
203204
self.name = name
204205
self.inputs = []
205206
self.attributes = [:]
207+
self.deviceName = _ExecutionContext.global.currentDeviceName
206208
self.outputCount = outputCount
207209
self.outputs = nil
208210
self.id = id

Diff for: Tests/TensorFlowTests/LazyTensorOperationTests.swift

+13-1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import XCTest
1717
import CTensorFlow
1818

1919
final class LazyTensorOperationTests: XCTestCase {
20+
2021
func testNoInput() {
2122
let placeholder = LazyTensorOperation(
2223
_id: "V", name: "Placeholder", outputCount: 1)
@@ -204,6 +205,15 @@ final class LazyTensorOperationTests: XCTestCase {
204205
XCTAssertEqual(op0.description, "%0 = Nop[fn: TFFunction(ExampleFunction)]()")
205206
}
206207

208+
func testDeviceTracking() {
209+
let op0 = LazyTensorOperation(_id: "0", name: "Nop", outputCount: 1)
210+
XCTAssertEqual(op0.deviceName, nil)
211+
withDevice(named: "/job:localhost/replica:0/task:0/device:CPU:0") {
212+
let op1 = LazyTensorOperation(_id: "0", name: "Nop", outputCount: 1)
213+
XCTAssertEqual(op1.deviceName ?? "", "/job:localhost/replica:0/task:0/device:CPU:0")
214+
}
215+
}
216+
207217
static var allTests = [
208218
("testNoInput", testNoInput),
209219
("testSingleInput", testSingleInput),
@@ -223,6 +233,8 @@ final class LazyTensorOperationTests: XCTestCase {
223233
testOptionalTensorShapeArrayAttribute),
224234
("testArrayAttributes", testArrayAttributes),
225235
("testMultipleAttributes", testMultipleAttributes),
226-
("testFunctionAttribute", testFunctionAttribute)
236+
("testFunctionAttribute", testFunctionAttribute),
237+
("testDeviceTracking", testDeviceTracking)
238+
227239
]
228240
}

0 commit comments

Comments
 (0)