Skip to content
This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Commit b9f6c3b

Browse files
Tensor-level annotations (#1064)
* Initial annotations prototype * Lint
1 parent b88bbf9 commit b9f6c3b

14 files changed

+272
-12
lines changed

Sources/CX10/xla_tensor_wrapper.cc

+9
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,10 @@ OpaqueXLATensor* XLATensor_all(OpaqueXLATensor* input, Int64ArrayRef dimensions,
315315
XlaHelpers::I64List(dimensions.slice()),
316316
keep_reduced_dimensions));
317317
}
318+
OpaqueXLATensor* XLATensor_annotate(OpaqueXLATensor* a,
319+
const char* annotation) {
320+
return new XLATensor(XLATensor::annotate(*a, std::string(annotation)));
321+
}
318322
OpaqueXLATensor* XLATensor_any(OpaqueXLATensor* input, Int64ArrayRef dimensions,
319323
bool keep_reduced_dimensions) {
320324
return new XLATensor(XLATensor::any(*input,
@@ -441,6 +445,11 @@ OpaqueXLATensor* XLATensor_full(Int64ArrayRef size, XLAScalar value,
441445
OpaqueXLATensor* XLATensor_ge(OpaqueXLATensor* x, OpaqueXLATensor* y) {
442446
return new XLATensor(XLATensor::ge(*x, *y));
443447
}
448+
OpaqueString* XLATensor_get_annotations(OpaqueXLATensor* a) {
449+
std::string ir_dag_text =
450+
swift_xla::ir::DumpUtil::GetAnnotations({a->GetIrValue().node.get()});
451+
return new std::string(ir_dag_text);
452+
}
444453
OpaqueXLATensor* XLATensor_gt(OpaqueXLATensor* x, OpaqueXLATensor* y) {
445454
return new XLATensor(XLATensor::gt(*x, *y));
446455
}

Sources/CX10/xla_tensor_wrapper.h

+3-1
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,7 @@ XLA_API OpaqueXLATensor* XLATensor_add(OpaqueXLATensor* a, OpaqueXLATensor* b);
227227
XLA_API OpaqueXLATensor* XLATensor_all(OpaqueXLATensor* input,
228228
Int64ArrayRef dimensions,
229229
bool keep_reduced_dimensions);
230+
XLA_API OpaqueXLATensor* XLATensor_annotate(OpaqueXLATensor* a, const char*);
230231
XLA_API OpaqueXLATensor* XLATensor_any(OpaqueXLATensor* input,
231232
Int64ArrayRef dimensions,
232233
bool keep_reduced_dimensions);
@@ -284,10 +285,12 @@ XLA_API OpaqueXLATensor*
284285
XLATensor_full(Int64ArrayRef size, XLAScalar value, const struct CDevice device,
285286
enum XLATensorScalarType type);
286287
XLA_API OpaqueXLATensor* XLATensor_ge(OpaqueXLATensor* x, OpaqueXLATensor* y);
288+
XLA_API OpaqueString* XLATensor_get_annotations(OpaqueXLATensor* a);
287289
XLA_API OpaqueXLATensor* XLATensor_gt(OpaqueXLATensor* x, OpaqueXLATensor* y);
288290
XLA_API OpaqueXLATensor* XLATensor_index(OpaqueXLATensor* input,
289291
OpaqueXLATensorArrayRef indices,
290292
int64_t start_dim);
293+
XLA_API OpaqueString* XLATensor_ir_text(OpaqueXLATensor* a);
291294
XLA_API OpaqueXLATensor* XLATensor_is_finite(OpaqueXLATensor* input);
292295
XLA_API OpaqueXLATensor* XLATensor_is_inf(OpaqueXLATensor* input);
293296
XLA_API OpaqueXLATensor* XLATensor_is_nan(OpaqueXLATensor* input);
@@ -367,7 +370,6 @@ XLA_API OpaqueXLATensor* XLATensor_sqrt(OpaqueXLATensor* a);
367370
XLA_API OpaqueXLATensor* XLATensor_squeeze(OpaqueXLATensor* a, int64_t dim);
368371
XLA_API OpaqueXLATensor*
369372
XLATensor_stack(OpaqueXLATensorArrayRef tensors, int64_t dim);
370-
XLA_API OpaqueString* XLATensor_ir_text(OpaqueXLATensor* a);
371373
XLA_API OpaqueXLATensor* XLATensor_sub(OpaqueXLATensor* a, OpaqueXLATensor* b);
372374
XLA_API OpaqueXLATensor* XLATensor_sum(OpaqueXLATensor* a, Int64ArrayRef dims,
373375
bool keep_reduced_dimensions,

Sources/TensorFlow/Core/Tensor.swift

+63-2
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15-
import _Differentiation
1615
import CTensorFlow
16+
import _Differentiation
1717

1818
infix operator .==: ComparisonPrecedence
1919
infix operator .!=: ComparisonPrecedence
@@ -24,7 +24,7 @@ public protocol AnyTensor {
2424
var _tensorFlowDataType: TensorDataType { get }
2525
}
2626

27-
/// A multidimensional array of elements that is a generalization of vectors and matrices to
27+
/// A multidimensional array of elements that is a generalization of vectors and matrices to
2828
/// potentially higher dimensions.
2929
///
3030
/// The generic parameter `Scalar` describes the type of scalars in the tensor (such as `Int32`,
@@ -41,6 +41,67 @@ public struct Tensor<Scalar: TensorFlowScalar> {
4141
}
4242
}
4343

44+
public protocol TensorProtocol {
45+
associatedtype Scalar: TensorFlowScalar
46+
init(repeating repeatedValue: Scalar, shape: TensorShape, on device: Device)
47+
var annotations: String { get }
48+
var shape: TensorShape { get }
49+
var summary: String { get }
50+
}
51+
52+
public protocol DifferentiableTensorProtocol:
53+
TensorProtocol & Differentiable & EuclideanDifferentiable
54+
where Scalar: TensorFlowFloatingPoint {
55+
@differentiable(wrt: self)
56+
func annotate(_ annotation: String) -> Self
57+
}
58+
59+
extension Tensor: TensorProtocol & DifferentiableTensorProtocol
60+
where Scalar: TensorFlowFloatingPoint {
61+
62+
public var annotations: String {
63+
#if USING_X10_BACKEND
64+
switch handle.backend {
65+
case .XLA:
66+
let rawAnnotations = XLATensor.annotations(xlaTensor)
67+
68+
// TODO(michellecasbon): Add formatting.
69+
70+
return rawAnnotations
71+
72+
case .TF_EAGER:
73+
return Device.defaultTFEager.annotationsAvailable
74+
}
75+
#else
76+
return "Annotations not available in TF_EAGER."
77+
#endif
78+
}
79+
80+
public var summary: String { annotations }
81+
82+
@differentiable(wrt: self)
83+
public func annotate(_ annotation: String) -> Tensor<Scalar> {
84+
#if USING_X10_BACKEND
85+
switch handle.backend {
86+
case .XLA:
87+
return Tensor<Scalar>(_xla: XLATensor.annotate(xlaTensor, annotation))
88+
case .TF_EAGER:
89+
return self
90+
}
91+
#else
92+
return self
93+
#endif
94+
}
95+
96+
@derivative(of: annotate)
97+
@usableFromInline
98+
func vjpAnnotate(_ annotation: String) -> (
99+
value: Tensor<Scalar>, pullback: (Tensor<Scalar>) -> Tensor<Scalar>
100+
) {
101+
(annotate(annotation), { $0 })
102+
}
103+
}
104+
44105
extension Tensor: AnyTensor {
45106
public var _rawTensorHandle: CTensorHandle { return handle._cTensorHandle }
46107
public var _tensorFlowDataType: TensorDataType { return Scalar.tensorFlowDataType }

Sources/x10/swift_bindings/Device.swift

+13
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,13 @@ public struct Device {
7272
case .XLA: return "XLA"
7373
}
7474
}
75+
76+
var annotationsAvailable: String {
77+
switch self {
78+
case .TF_EAGER: return "Annotations not available in TF_EAGER."
79+
case .XLA: return "Annotations available in XLA."
80+
}
81+
}
7582
}
7683

7784
/// A device kind.
@@ -208,6 +215,12 @@ extension Device: CustomStringConvertible {
208215
}
209216
}
210217

218+
extension Device {
219+
public var annotationsAvailable: String {
220+
"\(backend.annotationsAvailable)"
221+
}
222+
}
223+
211224
extension CDevice {
212225
var device: Device {
213226
return Device(kind: hw_type.kind, ordinal: Int(ordinal), backend: .XLA)

Sources/x10/swift_bindings/XLATensor.swift

+21-8
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,17 @@ extension XLATensor {
242242
}
243243
}
244244

245+
static func annotate(_ a: XLATensor, _ annotation: String) -> XLATensor {
246+
return XLATensor(_handle: XLATensor_annotate(a.handle, annotation))
247+
}
248+
249+
static func annotations(_ a: XLATensor) -> String {
250+
// TODO(michellecasbon): Format with header.
251+
let str = XLATensor_get_annotations(a.handle)
252+
defer { DeleteString(str) }
253+
return String(cString: GetStringCStr(str))
254+
}
255+
245256
static func any(_ input: XLATensor, _ reductionIndices: [Int64], _ keepDims: Bool) -> XLATensor {
246257
defer { _fixLifetime(input) }
247258
return reductionIndices.withArrayRef { reductionIndices in
@@ -407,7 +418,9 @@ extension XLATensor {
407418
return XLATensor(_handle: XLATensor_div(a.handle, b.handle))
408419
}
409420

410-
static func dynamic_slice(_ base: XLATensor, _ start_indices: [XLATensor], _ slice_shape: [Int64]) -> XLATensor {
421+
static func dynamic_slice(_ base: XLATensor, _ start_indices: [XLATensor], _ slice_shape: [Int64])
422+
-> XLATensor
423+
{
411424
start_indices.withArrayRef { start_indices in
412425
slice_shape.withArrayRef { slice_shape in
413426
return XLATensor(_handle: XLATensor_dynamic_slice(base.handle, start_indices, slice_shape))
@@ -491,6 +504,12 @@ extension XLATensor {
491504
}
492505
}
493506

507+
static func irText(_ a: XLATensor) -> String {
508+
let str = XLATensor_ir_text(a.handle)
509+
defer { DeleteString(str) }
510+
return String(cString: GetStringCStr(str))
511+
}
512+
494513
static func isFinite(_ input: XLATensor) -> XLATensor {
495514
defer { _fixLifetime(input) }
496515
return XLATensor(_handle: XLATensor_is_finite(input.handle))
@@ -761,7 +780,7 @@ extension XLATensor {
761780
}
762781

763782
static func replica_id(_ device: Device) -> XLATensor {
764-
return XLATensor(_handle: XLATensor_replica_id(device.cdevice));
783+
return XLATensor(_handle: XLATensor_replica_id(device.cdevice))
765784
}
766785

767786
static func resize_value(_ value: XLATensor, _ dims: [Int64]) -> XLATensor {
@@ -841,12 +860,6 @@ extension XLATensor {
841860
}
842861
}
843862

844-
static func irText(_ a: XLATensor) -> String {
845-
let str = XLATensor_ir_text(a.handle)
846-
defer { DeleteString(str) }
847-
return String(cString: GetStringCStr(str))
848-
}
849-
850863
static func sub(_ a: XLATensor, _ b: XLATensor) -> XLATensor {
851864
defer { _fixLifetime(a) }
852865
defer { _fixLifetime(b) }

Sources/x10/xla_tensor/aten_compat.h

+1
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,7 @@
219219
_(aten, all) \
220220
_(aten, allclose) \
221221
_(aten, alpha_dropout) \
222+
_(aten, annotate) \
222223
_(aten, any) \
223224
_(aten, arange) \
224225
_(aten, argmax) \

Sources/x10/xla_tensor/ir_dump_util.cpp

+32
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,22 @@ struct ChangeLogNode {
204204

205205
thread_local std::map<xla::hash_t, std::vector<ChangeLogNode>> g_change_logs;
206206

207+
std::string GenerateTextAnnotation(const Node* node) {
208+
// TODO(michellecasbon): Use json.
209+
std::stringstream ss;
210+
ss << " shape=[";
211+
size_t i = 0;
212+
for (auto& dimension : node->shape().dimensions()) {
213+
if ((i++) != 0) ss << ", ";
214+
ss << dimension;
215+
}
216+
ss << "] ";
217+
for (auto& tag : GetNodeTags(node)) {
218+
ss << tag.value;
219+
}
220+
return ss.str();
221+
}
222+
207223
} // namespace
208224

209225
std::string DumpUtil::ToDot(absl::Span<const Node* const> nodes) {
@@ -323,5 +339,21 @@ std::string DumpUtil::GetGraphChangeLog(absl::Span<const Node* const> roots) {
323339
return ss.str();
324340
}
325341

342+
std::string DumpUtil::GetAnnotations(absl::Span<const Node* const> nodes) {
343+
auto post_order = Util::ComputePostOrder(nodes);
344+
345+
NodeIdMap id_map = GenerateIdMap(post_order);
346+
std::stringstream ss;
347+
ss << "{";
348+
for (auto node : post_order) {
349+
// Only process annotations
350+
if (node->op().ToString() != "x10::annotate") continue;
351+
352+
ss << "\n" << GenerateTextAnnotation(node);
353+
}
354+
ss << "\n" << "}";
355+
return ss.str();
356+
}
357+
326358
} // namespace ir
327359
} // namespace swift_xla

Sources/x10/xla_tensor/ir_dump_util.h

+2
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ class DumpUtil {
4141
const Device& device);
4242

4343
static std::string GetGraphChangeLog(absl::Span<const Node* const> roots);
44+
45+
static std::string GetAnnotations(absl::Span<const Node* const> nodes);
4446
};
4547

4648
} // namespace ir
+48
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
// Copyright 2020 TensorFlow Authors
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "tensorflow/compiler/tf2xla/xla_tensor/ops/annotate.h"
16+
17+
#include "tensorflow/compiler/xla/xla_client/util.h"
18+
#include "tensorflow/compiler/tf2xla/xla_tensor/elementwise.h"
19+
#include "tensorflow/compiler/tf2xla/xla_tensor/lowering_context.h"
20+
21+
namespace swift_xla {
22+
namespace ir {
23+
namespace ops {
24+
25+
Annotate::Annotate(const Value& input, std::string annotation)
26+
: Node(ir::OpKind(at::aten::annotate), {input}, input.shape(),
27+
/*num_outputs=*/1, xla::util::MHash()),
28+
annotation_(annotation) {}
29+
30+
NodePtr Annotate::Clone(OpList operands) const {
31+
return MakeNode<Annotate>(operands.at(0), annotation_);
32+
}
33+
34+
XlaOpVector Annotate::Lower(LoweringContext* loctx) const {
35+
xla::XlaOp input = loctx->GetOutputOp(operand(0));
36+
return ReturnOp(input, loctx);
37+
}
38+
39+
std::string Annotate::ToString() const {
40+
std::stringstream ss;
41+
ss << Node::ToString() << ", annotation=" << annotation_;
42+
return ss.str();
43+
}
44+
45+
} // namespace ops
46+
} // namespace ir
47+
} // namespace swift_xla
48+

Sources/x10/xla_tensor/ops/annotate.h

+47
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
/*
2+
* Copyright 2020 TensorFlow Authors
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#pragma once
18+
19+
// #include <vector>
20+
21+
#include "tensorflow/compiler/tf2xla/xla_tensor/ir.h"
22+
23+
namespace swift_xla {
24+
namespace ir {
25+
namespace ops {
26+
27+
// IR node for collecting layer statistics.
28+
class Annotate : public Node {
29+
public:
30+
Annotate(const Value& input, std::string annotation);
31+
32+
NodePtr Clone(OpList operands) const override;
33+
34+
XlaOpVector Lower(LoweringContext* loctx) const override;
35+
36+
std::string ToString() const override;
37+
38+
const std::string& annotation() const { return annotation_; }
39+
40+
private:
41+
std::string annotation_;
42+
};
43+
44+
} // namespace ops
45+
} // namespace ir
46+
} // namespace swift_xla
47+

Sources/x10/xla_tensor/tensor.h

+2
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,8 @@ class XLATensor {
298298
std::vector<xla::int64> dimensions,
299299
bool keep_reduced_dimensions);
300300

301+
static XLATensor annotate(const XLATensor& input, std::string annotation);
302+
301303
static XLATensor any(const XLATensor& input,
302304
std::vector<xla::int64> dimensions,
303305
bool keep_reduced_dimensions);

0 commit comments

Comments
 (0)