This repository was archived by the owner on Jul 1, 2023. It is now read-only.

Tensor-level annotations #1064

Merged · 17 commits · Aug 20, 2020
Changes from 1 commit
Initial annotations prototype
texasmichelle committed Apr 22, 2020
commit 73e0971ec1a2093024139e01a7f7c9181d98357a
9 changes: 9 additions & 0 deletions Sources/CX10/xla_tensor_wrapper.cc
@@ -315,6 +315,10 @@ OpaqueXLATensor* XLATensor_all(OpaqueXLATensor* input, Int64ArrayRef dimensions,
XlaHelpers::I64List(dimensions.slice()),
keep_reduced_dimensions));
}
OpaqueXLATensor* XLATensor_annotate(OpaqueXLATensor* a,
const char* annotation) {
return new XLATensor(XLATensor::annotate(*a, std::string(annotation)));
}
OpaqueXLATensor* XLATensor_any(OpaqueXLATensor* input, Int64ArrayRef dimensions,
bool keep_reduced_dimensions) {
return new XLATensor(XLATensor::any(*input,
@@ -429,6 +433,11 @@ OpaqueXLATensor* XLATensor_full(Int64ArrayRef size, XLAScalar value,
OpaqueXLATensor* XLATensor_ge(OpaqueXLATensor* x, OpaqueXLATensor* y) {
return new XLATensor(XLATensor::ge(*x, *y));
}
OpaqueString* XLATensor_get_annotations(OpaqueXLATensor* a) {
std::string ir_dag_text =
swift_xla::ir::DumpUtil::GetAnnotations({a->GetIrValue().node.get()});
return new std::string(ir_dag_text);
}
OpaqueXLATensor* XLATensor_gt(OpaqueXLATensor* x, OpaqueXLATensor* y) {
return new XLATensor(XLATensor::gt(*x, *y));
}
4 changes: 3 additions & 1 deletion Sources/CX10/xla_tensor_wrapper.h
@@ -226,6 +226,7 @@ XLA_API OpaqueXLATensor* XLATensor_add(OpaqueXLATensor* a, OpaqueXLATensor* b);
XLA_API OpaqueXLATensor* XLATensor_all(OpaqueXLATensor* input,
Int64ArrayRef dimensions,
bool keep_reduced_dimensions);
XLA_API OpaqueXLATensor* XLATensor_annotate(OpaqueXLATensor* a, const char*);
XLA_API OpaqueXLATensor* XLATensor_any(OpaqueXLATensor* input,
Int64ArrayRef dimensions,
bool keep_reduced_dimensions);
@@ -277,10 +278,12 @@ XLA_API OpaqueXLATensor*
XLATensor_full(Int64ArrayRef size, XLAScalar value, const struct CDevice device,
enum XLATensorScalarType type);
XLA_API OpaqueXLATensor* XLATensor_ge(OpaqueXLATensor* x, OpaqueXLATensor* y);
XLA_API OpaqueString* XLATensor_get_annotations(OpaqueXLATensor* a);
XLA_API OpaqueXLATensor* XLATensor_gt(OpaqueXLATensor* x, OpaqueXLATensor* y);
XLA_API OpaqueXLATensor* XLATensor_index(OpaqueXLATensor* input,
OpaqueXLATensorArrayRef indices,
int64_t start_dim);
XLA_API OpaqueString* XLATensor_ir_text(OpaqueXLATensor* a);
XLA_API OpaqueXLATensor* XLATensor_is_finite(OpaqueXLATensor* input);
XLA_API OpaqueXLATensor* XLATensor_is_inf(OpaqueXLATensor* input);
XLA_API OpaqueXLATensor* XLATensor_is_nan(OpaqueXLATensor* input);
@@ -355,7 +358,6 @@ XLA_API OpaqueXLATensor* XLATensor_sqrt(OpaqueXLATensor* a);
XLA_API OpaqueXLATensor* XLATensor_squeeze(OpaqueXLATensor* a, int64_t dim);
XLA_API OpaqueXLATensor*
XLATensor_stack(OpaqueXLATensorArrayRef tensors, int64_t dim);
XLA_API OpaqueString* XLATensor_ir_text(OpaqueXLATensor* a);
XLA_API OpaqueXLATensor* XLATensor_sub(OpaqueXLATensor* a, OpaqueXLATensor* b);
XLA_API OpaqueXLATensor* XLATensor_sum(OpaqueXLATensor* a, Int64ArrayRef dims,
bool keep_reduced_dimensions,
64 changes: 64 additions & 0 deletions Sources/TensorFlow/Core/Tensor.swift
@@ -40,6 +40,70 @@ public struct Tensor<Scalar: TensorFlowScalar> {
}
}

public protocol TensorProtocol {
associatedtype Scalar: TensorFlowScalar
init(repeating repeatedValue: Scalar, shape: TensorShape, on device: Device)
var annotations: String { get }
var shape: TensorShape { get }
var summary: String { get }
}

public protocol DifferentiableTensorProtocol:
TensorProtocol & Differentiable & EuclideanDifferentiable
where Scalar: TensorFlowFloatingPoint {
@differentiable(wrt: self)
func annotate(_ annotation: String) -> Self
}

extension Tensor: TensorProtocol & DifferentiableTensorProtocol where Scalar: TensorFlowFloatingPoint {

public var annotations: String {
#if USING_X10_BACKEND
switch handle.backend {
case .XLA:
let rawAnnotations = XLATensor.annotations(xlaTensor)

// TODO(michellecasbon): Add formatting.

let formattedAnnotations = """
Layer Output Shape Attributes
============================= ==================== ======================
\(rawAnnotations)
"""

return formattedAnnotations

case .TF_EAGER:
return Device.defaultTFEager.annotationsAvailable
}
#else
return ""
#endif
}

public var summary: String { annotations }

@differentiable(wrt: self)
public func annotate(_ annotation: String) -> Tensor<Scalar> {
#if USING_X10_BACKEND
switch handle.backend {
case .XLA:
return Tensor<Scalar>(_xla: XLATensor.annotate(xlaTensor, annotation))
case .TF_EAGER:
return self
}
#else
return self
#endif
}

@derivative(of: annotate)
@usableFromInline
func vjpAnnotate(_ annotation: String) -> (value: Tensor<Scalar>, pullback: (Tensor<Scalar>) -> Tensor<Scalar>) {
(annotate(annotation), { $0 })
}
}

extension Tensor: AnyTensor {
public var _rawTensorHandle: CTensorHandle { return handle._cTensorHandle }
public var _tensorFlowDataType: TensorDataType { return Scalar.tensorFlowDataType }
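
For context, a minimal usage sketch of the new tensor-level API added above (the tensor shape and the annotation string are illustrative, not part of this commit):

// Annotate a tensor on the XLA backend, then read the collected annotations back.
let x = Tensor<Float>(repeating: 0, shape: [4, 16], on: Device.defaultXLA)
let tagged = x.annotate("type=ExampleLayer")  // records an annotate node in the XLA IR
print(tagged.annotations)                     // formatted header plus one line per collected annotation

On the TF_EAGER backend, annotate(_:) returns self unchanged and annotations returns the Device.annotationsAvailable message instead.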
90 changes: 90 additions & 0 deletions Sources/TensorFlow/Layer.swift
@@ -27,6 +27,50 @@ where
/// - Returns: The output.
@differentiable(wrt: self)
func callAsFunction(_ input: Input) -> Output

@differentiable(wrt: self)
func forward(_ input: Input) -> Output
}

extension Module {
@differentiable(wrt: self)
public func forward(_ input: Input) -> Output {
return callAsFunction(input)
}
}

extension Module where Input: TensorProtocol, Output: DifferentiableTensorProtocol {
@differentiable(wrt: self)
public func callAsFunction(_ input: Input) -> Output {
let activation = forward(input)

return annotated(activation)
}

@differentiable
public func annotated(_ output: Output) -> Output {
#if USING_X10_BACKEND
let selfType = String(describing: Self.self)
let annotation = "type=" + selfType
let annotated = output.annotate(annotation)
return annotated
#else
return output
#endif
}

/// Returns the annotations obtained from applying the layer to the given input.
///
/// - Parameter input: The input to the layer.
/// - Returns: All collected annotations from the XLA graph.
public func annotations(input: Input) -> String {
let output = self.callAsFunction(input)
return output.annotations
}

public func summary(input: Input) -> String {
return self.annotations(input: input)
}
}

/// A neural network layer.
@@ -43,13 +87,59 @@ public protocol Layer: Module where Input: Differentiable {
/// - Returns: The output.
@differentiable
func callAsFunction(_ input: Input) -> Output

@differentiable
func forward(_ input: Input) -> Output
}

extension Layer {
@differentiable
public func call(_ input: Input) -> Output {
callAsFunction(input)
}

@differentiable
public func forward(_ input: Input) -> Output {
return callAsFunction(input)
}
}

extension Layer where Input: DifferentiableTensorProtocol, Output: DifferentiableTensorProtocol {
@differentiable
public func callAsFunction(_ input: Input) -> Output {
let activation = forward(input)

return annotated(activation)
}

/// Returns the annotations obtained from applying the layer to the given input.
///
/// - Parameter input: The input to the layer.
/// - Returns: All collected annotations from the XLA graph.
public func annotations(input: Input) -> String {
return self.annotations(inputShape: input.shape)
}

/// Returns the annotations obtained from applying the layer to the given input.
///
/// - Parameter inputShape: The shape of the input to the layer.
/// - Returns: All collected annotations from the XLA graph.
public func annotations(inputShape: TensorShape) -> String {
#if USING_X10_BACKEND
LazyTensorBarrier()
let zeros = Input.init(repeating: 0, shape: inputShape, on: Device.defaultXLA)
let model = Self.self.init(copying: self, to: Device.defaultXLA)
let output = model(zeros)

return output.annotations
#else
return ""
#endif
}

public func summary(inputShape: TensorShape) -> String {
return self.annotations(inputShape: inputShape)
}
}

/// An empty struct representing empty `TangentVector`s for parameterless layers.
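
For context, a sketch of how a layer author interacts with the new forward(_:) requirement and the summary helpers above (MyModel, its layer sizes, and the input shape are hypothetical, not part of this commit):

// Layers now implement forward(_:); callAsFunction(_:) wraps it and tags the output with "type=<LayerType>".
struct MyModel: Layer {
  var dense = Dense<Float>(inputSize: 4, outputSize: 2)

  @differentiable
  func forward(_ input: Tensor<Float>) -> Tensor<Float> {
    dense(input)
  }
}

let model = MyModel()
// Builds a zero input of the given shape on the default XLA device, copies the model to that
// device, runs it, and returns the annotations collected from the resulting XLA graph.
print(model.summary(inputShape: [1, 4]))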
4 changes: 3 additions & 1 deletion Sources/TensorFlow/Layers/Dense.swift
@@ -24,6 +24,8 @@
/// weight and bias for each element in input batch.
@frozen
public struct Dense<Scalar: TensorFlowFloatingPoint>: Layer {
public typealias Input = Tensor<Scalar>
public typealias Output = Tensor<Scalar>
/// The weight matrix.
public var weight: Tensor<Scalar>
/// The bias vector.
@@ -76,7 +78,7 @@ public struct Dense<Scalar: TensorFlowFloatingPoint>: Layer {
/// - Parameter input: The input to the layer.
/// - Returns: The output.
@differentiable
public func callAsFunction(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
public func forward(_ input: Tensor<Scalar>) -> Tensor<Scalar> {
if batched {
let hidden = matmul(input.expandingShape(at: 1), weight).squeezingShape(at: 1)
return activation(useBias ? hidden + bias : hidden)
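
With callAsFunction renamed to forward here, a call like dense(input) is routed through the Layer extension above, so on the XLA backend it behaves roughly like this sketch (not literal library code):

let output = dense.annotated(dense.forward(input))  // forward() computes; annotated() tags the result with "type=Dense<Float>"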
13 changes: 13 additions & 0 deletions Sources/x10/swift_bindings/Device.swift
@@ -72,6 +72,13 @@ public struct Device {
case .XLA: return "XLA"
}
}

var annotationsAvailable: String {
switch self {
case .TF_EAGER: return "Annotations not available in TF_EAGER."
case .XLA: return "Annotations available in XLA."
}
}
}

/// A device kind.
@@ -206,6 +213,12 @@ extension Device: CustomStringConvertible {
}
}

extension Device {
public var annotationsAvailable: String {
"\(backend.annotationsAvailable)"
}
}

extension CDevice {
var device: Device {
return Device(kind: hw_type.kind, ordinal: Int(ordinal), backend: .XLA)
23 changes: 17 additions & 6 deletions Sources/x10/swift_bindings/XLATensor.swift
@@ -242,6 +242,17 @@ extension XLATensor {
}
}

static func annotate(_ a: XLATensor, _ annotation: String) -> XLATensor {
return XLATensor(_handle: XLATensor_annotate(a.handle, annotation))
}

static func annotations(_ a: XLATensor) -> String {
// TODO(michellecasbon): Format with header.
let str = XLATensor_get_annotations(a.handle)
defer { DeleteString(str) }
return String(cString: GetStringCStr(str))
}

static func any(_ input: XLATensor, _ reductionIndices: [Int64], _ keepDims: Bool) -> XLATensor {
defer { _fixLifetime(input) }
return reductionIndices.withArrayRef { reductionIndices in
@@ -474,6 +485,12 @@ extension XLATensor {
}
}

static func irText(_ a: XLATensor) -> String {
let str = XLATensor_ir_text(a.handle)
defer { DeleteString(str) }
return String(cString: GetStringCStr(str))
}

static func isFinite(_ input: XLATensor) -> XLATensor {
defer { _fixLifetime(input) }
return XLATensor(_handle: XLATensor_is_finite(input.handle))
@@ -807,12 +824,6 @@ extension XLATensor {
}
}

static func irText(_ a: XLATensor) -> String {
let str = XLATensor_ir_text(a.handle)
defer { DeleteString(str) }
return String(cString: GetStringCStr(str))
}

static func sub(_ a: XLATensor, _ b: XLATensor) -> XLATensor {
defer { _fixLifetime(a) }
defer { _fixLifetime(b) }
1 change: 1 addition & 0 deletions Sources/x10/xla_tensor/aten_compat.h
@@ -219,6 +219,7 @@
_(aten, all) \
_(aten, allclose) \
_(aten, alpha_dropout) \
_(aten, annotate) \
_(aten, any) \
_(aten, arange) \
_(aten, argmax) \
32 changes: 32 additions & 0 deletions Sources/x10/xla_tensor/ir_dump_util.cpp
@@ -186,6 +186,22 @@ std::string GenerateTextNodeSpec(const Node* node, const NodeIdMap& id_map) {
return ss.str();
}

std::string GenerateTextAnnotation(const Node* node) {
// TODO(michellecasbon): Use json.
std::stringstream ss;
ss << " shape=[";
size_t i = 0;
for (auto& dimension : node->shape().dimensions()) {
if ((i++) != 0) ss << ", ";
ss << dimension;
}
ss << "] ";
for (auto& tag : GetNodeTags(node)) {
ss << tag.value;
}
return ss.str();
}

} // namespace

std::string DumpUtil::ToDot(absl::Span<const Node* const> nodes) {
@@ -261,5 +277,21 @@ std::string DumpUtil::ToHlo(absl::Span<const Value> values) {
return ConsumeValue(xla::util::GetComputationHloText(computation));
}

std::string DumpUtil::GetAnnotations(absl::Span<const Node* const> nodes) {
auto post_order = Util::ComputePostOrder(nodes);

NodeIdMap id_map = GenerateIdMap(post_order);
std::stringstream ss;
ss << "{\n";
for (auto node : post_order) {
// Only process annotations
if (node->op().ToString() != "x10::annotate") continue;

ss << GenerateTextAnnotation(node) << "\n";
}
ss << "}\n";
return ss.str();
}

} // namespace ir
} // namespace swift_xla
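
For reference, a rough sketch of the text GetAnnotations produces for a graph containing two annotate nodes (the shapes and tag contents are made up; the exact layout follows GenerateTextAnnotation above):

{
  shape=[1, 4] type=Dense<Float>
  shape=[1, 2] type=Dense<Float>
}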
2 changes: 2 additions & 0 deletions Sources/x10/xla_tensor/ir_dump_util.h
@@ -37,6 +37,8 @@ class DumpUtil {
absl::Span<const Node* const> roots);

static std::string ToHlo(absl::Span<const Value> values);

static std::string GetAnnotations(absl::Span<const Node* const> nodes);
};

} // namespace ir