Skip to content

Commit c381f5f

Browse files
authored
[AutoDiff] Add Type Checking for @transposing (#25684)
Do the type checking for top level functions and methods that are marked as the transpose of another function.
1 parent 309f6dc commit c381f5f

20 files changed

+1511
-7
lines changed

Diff for: include/swift/AST/Attr.def

+3
Original file line numberDiff line numberDiff line change
@@ -432,6 +432,9 @@ SIMPLE_DECL_ATTR(compilerEvaluable, CompilerEvaluable,
432432
/* Not serialized */ 91)
433433
SIMPLE_DECL_ATTR(noDerivative, NoDerivative,
434434
OnVar, 92)
435+
DECL_ATTR(transposing, Transposing,
436+
OnFunc | LongAttribute | AllowMultipleAttributes |
437+
NotSerialized, 93)
435438

436439
#undef TYPE_ATTR
437440
#undef DECL_ATTR_ALIAS

Diff for: include/swift/AST/Attr.h

+73
Original file line numberDiff line numberDiff line change
@@ -1682,6 +1682,79 @@ class DifferentiatingAttr final
16821682
return DA->getKind() == DAK_Differentiating;
16831683
}
16841684
};
1685+
1686+
/// Attribute that registers a function as a transpose of another function.
1687+
///
1688+
/// Examples:
1689+
/// @transposing(foo)
1690+
/// @transposing(+, wrt: (lhs, rhs))
1691+
class TransposingAttr final
1692+
: public DeclAttribute,
1693+
private llvm::TrailingObjects<DifferentiableAttr,
1694+
ParsedAutoDiffParameter> {
1695+
/// The base type of the original function.
1696+
/// This is non-null only when the original function is not top-level (i.e. it
1697+
/// is an instance/static method).
1698+
TypeRepr *BaseType;
1699+
/// The original function name.
1700+
DeclNameWithLoc Original;
1701+
/// The original function, resolved by the type checker.
1702+
FuncDecl *OriginalFunction = nullptr;
1703+
/// The number of parsed parameters specified in 'wrt:'.
1704+
unsigned NumParsedParameters = 0;
1705+
/// The differentiation parameters' indices, resolved by the type checker.
1706+
AutoDiffIndexSubset *ParameterIndexSubset = nullptr;
1707+
1708+
explicit TransposingAttr(ASTContext &context, bool implicit,
1709+
SourceLoc atLoc, SourceRange baseRange,
1710+
TypeRepr *baseType, DeclNameWithLoc original,
1711+
ArrayRef<ParsedAutoDiffParameter> params);
1712+
1713+
explicit TransposingAttr(ASTContext &context, bool implicit,
1714+
SourceLoc atLoc, SourceRange baseRange,
1715+
TypeRepr *baseType, DeclNameWithLoc original,
1716+
AutoDiffIndexSubset *indices);
1717+
1718+
public:
1719+
static TransposingAttr *create(ASTContext &context, bool implicit,
1720+
SourceLoc atLoc, SourceRange baseRange,
1721+
TypeRepr *baseType, DeclNameWithLoc original,
1722+
ArrayRef<ParsedAutoDiffParameter> params);
1723+
1724+
static TransposingAttr *create(ASTContext &context, bool implicit,
1725+
SourceLoc atLoc, SourceRange baseRange,
1726+
TypeRepr *baseType, DeclNameWithLoc original,
1727+
AutoDiffIndexSubset *indices);
1728+
1729+
TypeRepr *getBaseType() const { return BaseType; }
1730+
DeclNameWithLoc getOriginal() const { return Original; }
1731+
1732+
FuncDecl *getOriginalFunction() const { return OriginalFunction; }
1733+
void setOriginalFunction(FuncDecl *decl) { OriginalFunction = decl; }
1734+
1735+
/// The parsed transposing parameters, i.e. the list of parameters
1736+
/// specified in 'wrt:'.
1737+
ArrayRef<ParsedAutoDiffParameter> getParsedParameters() const {
1738+
return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
1739+
}
1740+
MutableArrayRef<ParsedAutoDiffParameter> getParsedParameters() {
1741+
return {getTrailingObjects<ParsedAutoDiffParameter>(), NumParsedParameters};
1742+
}
1743+
size_t numTrailingObjects(OverloadToken<ParsedAutoDiffParameter>) const {
1744+
return NumParsedParameters;
1745+
}
1746+
1747+
AutoDiffIndexSubset *getParameterIndexSubset() const {
1748+
return ParameterIndexSubset;
1749+
}
1750+
void setParameterIndices(AutoDiffIndexSubset *pi) {
1751+
ParameterIndexSubset = pi;
1752+
}
1753+
1754+
static bool classof(const DeclAttribute *DA) {
1755+
return DA->getKind() == DAK_Transposing;
1756+
}
1757+
};
16851758

16861759
/// Attributes that may be applied to declarations.
16871760
class DeclAttributes {

Diff for: include/swift/AST/AutoDiff.h

+7
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,13 @@ class AutoDiffIndexSubset : public llvm::FoldingSetNode {
354354
unsigned getNumIndices() const {
355355
return (unsigned)std::distance(begin(), end());
356356
}
357+
358+
SmallBitVector getBitVector() const {
359+
SmallBitVector indicesBitVec(capacity, false);
360+
for (auto index : getIndices())
361+
indicesBitVec.set(index);
362+
return indicesBitVec;
363+
}
357364

358365
bool contains(unsigned index) const {
359366
unsigned bitWordIndex, offset;

Diff for: include/swift/AST/DiagnosticsParse.def

+12-2
Original file line numberDiff line numberDiff line change
@@ -1512,13 +1512,23 @@ ERROR(attr_differentiating_expected_original_name,PointsToFirstBadToken,
15121512
ERROR(attr_differentiating_expected_label_linear_or_wrt,none,
15131513
"expected either 'linear' or 'wrt:'", ())
15141514

1515+
// transposing
1516+
ERROR(attr_transposing_expected_original_name,PointsToFirstBadToken,
1517+
"expected an original function name", ())
1518+
ERROR(attr_transposing_expected_label_linear_or_wrt,none,
1519+
"expected 'wrt:'", ())
1520+
1521+
// transposing `wrt` parameters clause
1522+
ERROR(transposing_params_clause_expected_parameter,PointsToFirstBadToken,
1523+
"expected a parameter, which can be a 'unsigned int' parameter number "
1524+
"or 'self'", ())
1525+
15151526
// differentiation `wrt` parameters clause
15161527
ERROR(expected_colon_after_label,PointsToFirstBadToken,
15171528
"expected a colon ':' after '%0'", (StringRef))
15181529
ERROR(diff_params_clause_expected_parameter,PointsToFirstBadToken,
15191530
"expected a parameter, which can be a function parameter name, "
1520-
"parameter index, or 'self'",
1521-
())
1531+
"parameter index, or 'self'", ())
15221532

15231533
// [differentiable ...] (sil-decl attr)
15241534
ERROR(sil_attr_differentiable_expected_keyword,PointsToFirstBadToken,

Diff for: include/swift/AST/DiagnosticsSema.def

+13
Original file line numberDiff line numberDiff line change
@@ -2826,6 +2826,19 @@ ERROR(differentiating_attr_not_in_same_file_as_original,none,
28262826
ERROR(differentiating_attr_original_already_has_derivative,none,
28272827
"a derivative already exists for %0", (DeclName))
28282828

2829+
// transposing
2830+
ERROR(transpose_params_clause_param_not_differentiable,none,
2831+
"can only transpose with respect to parameters that conform to "
2832+
"'Differentiable' and where '%0 == %0.TangentVector'", (StringRef))
2833+
ERROR(transposing_attr_overload_not_found,none,
2834+
"could not find function %0 with expected type %1", (DeclName, Type))
2835+
ERROR(transposing_attr_cant_use_named_wrt_params,none,
2836+
"cannot use named wrt parameters in '@transposing' attribute, found %0",
2837+
(Identifier))
2838+
ERROR(transposing_attr_result_value_not_differentiable,none,
2839+
"'@transposing' attribute requires original function result to "
2840+
"conform to 'Differentiable'", (Type))
2841+
28292842
// differentiation `wrt` parameters clause
28302843
ERROR(diff_function_no_parameters,none,
28312844
"%0 has no parameters to differentiate with respect to", (DeclName))

Diff for: include/swift/AST/Types.h

+9-1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
// SWIFT_ENABLE_TENSORFLOW
2121
#include "swift/AST/AutoDiff.h"
22+
#include "swift/AST/Attr.h"
2223
#include "swift/AST/DeclContext.h"
2324
#include "swift/AST/GenericParamKey.h"
2425
#include "swift/AST/Identifier.h"
@@ -3093,7 +3094,7 @@ class AnyFunctionType : public TypeBase {
30933094
///
30943095
/// If `makeSelfParamFirst` is true, self's tangent is reordered to appear
30953096
/// first. This should be used during type-checking, e.g. type-checking
3096-
/// `@differentiable` and `@differentiating` attributes.
3097+
/// `@differentiable`, `@differentiating`, and `@transposing` attributes.
30973098
///
30983099
/// \note The original function type (`self`) need not be `@differentiable`.
30993100
/// The resulting function will preserve all `ExtInfo` of the original
@@ -3108,6 +3109,13 @@ class AnyFunctionType : public TypeBase {
31083109
/// Given the type of an autodiff associated function, returns the
31093110
/// corresponding original function type.
31103111
AnyFunctionType *getAutoDiffOriginalFunctionType();
3112+
3113+
/// Given the type of a transposing associated function, returns the
3114+
/// corresponding original function type.
3115+
AnyFunctionType *
3116+
getTransposeOriginalFunctionType(TransposingAttr *attr,
3117+
AutoDiffIndexSubset *wrtParamIndices,
3118+
bool wrtSelf);
31113119

31123120
AnyFunctionType *getWithoutDifferentiability() const;
31133121

Diff for: include/swift/Parse/Parser.h

+7
Original file line numberDiff line numberDiff line change
@@ -960,10 +960,17 @@ class Parser {
960960
/// Parse a differentiation parameters clause.
961961
bool parseDifferentiationParametersClause(
962962
SmallVectorImpl<ParsedAutoDiffParameter> &params, StringRef attrName);
963+
964+
/// Parse a transposing parameters clause.
965+
bool parseTransposingParametersClause(
966+
SmallVectorImpl<ParsedAutoDiffParameter> &params, StringRef attrName);
963967

964968
/// Parse the @differentiating attribute.
965969
ParserResult<DifferentiatingAttr>
966970
parseDifferentiatingAttribute(SourceLoc AtLoc, SourceLoc Loc);
971+
972+
ParserResult<TransposingAttr> parseTransposingAttribute(SourceLoc AtLoc,
973+
SourceLoc Loc);
967974

968975
/// Parse a specific attribute.
969976
ParserStatus parseDeclAttribute(DeclAttributes &Attributes, SourceLoc AtLoc);

Diff for: include/swift/Serialization/ModuleFormat.h

+9
Original file line numberDiff line numberDiff line change
@@ -1670,6 +1670,15 @@ namespace decls_block {
16701670
DeclIDField, // Original function declaration.
16711671
BCArray<BCFixed<1>> // Differentiation parameter indices' bitvector.
16721672
>;
1673+
1674+
// SWIFT_ENABLE_TENSORFLOW
1675+
using TransposingDeclAttrLayout = BCRecordLayout<
1676+
Transposing_DECL_ATTR,
1677+
BCFixed<1>, // Implicit flag.
1678+
IdentifierIDField, // Original name.
1679+
DeclIDField, // Original function declaration.
1680+
BCArray<BCFixed<1>> // Differentiation parameter indices' bitvector.
1681+
>;
16731682

16741683
#define SIMPLE_DECL_ATTR(X, CLASS, ...) \
16751684
using CLASS##DeclAttrLayout = BCRecordLayout< \

Diff for: lib/AST/Attr.cpp

+121-2
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,67 @@ static std::string getDifferentiationParametersClauseString(
389389
return printer.str();
390390
}
391391

392+
// Returns the differentiation parameters clause string for the given function,
393+
// parameter indices, and parsed parameters.
394+
static std::string getTransposingParametersClauseString(
395+
const AbstractFunctionDecl *function, AutoDiffIndexSubset *indices,
396+
ArrayRef<ParsedAutoDiffParameter> parsedParams) {
397+
bool isInstanceMethod = function && function->isInstanceMember();
398+
399+
std::string result;
400+
llvm::raw_string_ostream printer(result);
401+
402+
// Use parameters from `AutoDiffIndexSubset`, if specified.
403+
if (indices) {
404+
SmallBitVector parameters(indices->getBitVector());
405+
auto parameterCount = parameters.count();
406+
printer << "wrt: ";
407+
if (parameterCount > 1)
408+
printer << '(';
409+
// Check if differentiating wrt `self`. If so, manually print it first.
410+
if (isInstanceMethod && parameters.test(parameters.size() - 1)) {
411+
parameters.reset(parameters.size() - 1);
412+
printer << "self";
413+
if (parameters.any())
414+
printer << ", ";
415+
}
416+
// Print remaining differentiation parameters.
417+
interleave(parameters.set_bits(), [&](unsigned index) { printer << index; },
418+
[&] { printer << ", "; });
419+
if (parameterCount > 1)
420+
printer << ')';
421+
}
422+
// Otherwise, use the parsed parameters.
423+
else if (!parsedParams.empty()) {
424+
printer << "wrt: ";
425+
if (parsedParams.size() > 1)
426+
printer << '(';
427+
interleave(
428+
parsedParams,
429+
[&](const ParsedAutoDiffParameter &param) {
430+
switch (param.getKind()) {
431+
case ParsedAutoDiffParameter::Kind::Named:
432+
printer << param.getName();
433+
break;
434+
case ParsedAutoDiffParameter::Kind::Self:
435+
printer << "self";
436+
break;
437+
case ParsedAutoDiffParameter::Kind::Ordered:
438+
assert((param.getIndex() < function->getParameters()->size()) &&
439+
"'wrt:' parameter index should be less than the number "
440+
"of parameters");
441+
auto *funcParam = function->getParameters()->get(param.getIndex());
442+
printer << funcParam->getNameStr();
443+
break;
444+
}
445+
},
446+
[&] { printer << ", "; });
447+
if (parsedParams.size() > 1)
448+
printer << ')';
449+
}
450+
return printer.str();
451+
}
452+
392453
// SWIFT_ENABLE_TENSORFLOW
393454
// Print the arguments of the given `@differentiable` attribute.
394455
static void printDifferentiableAttrArguments(
@@ -803,6 +864,22 @@ bool DeclAttribute::printImpl(ASTPrinter &Printer, const PrintOptions &Options,
803864
Printer << ')';
804865
break;
805866
}
867+
868+
// SWIFT_ENABLE_TENSORFLOW
869+
case DAK_Transposing: {
870+
Printer.printAttrName("@transposing");
871+
Printer << '(';
872+
auto *attr = cast<TransposingAttr>(this);
873+
auto *transpose = dyn_cast_or_null<AbstractFunctionDecl>(D);
874+
Printer << attr->getOriginal().Name;
875+
auto diffParamsString = getTransposingParametersClauseString(
876+
transpose, attr->getParameterIndexSubset(),
877+
attr->getParsedParameters());
878+
if (!diffParamsString.empty())
879+
Printer << ", " << diffParamsString;
880+
Printer << ')';
881+
break;
882+
}
806883

807884
case DAK_DynamicReplacement: {
808885
Printer.printAttrName("@_dynamicReplacement");
@@ -955,6 +1032,8 @@ StringRef DeclAttribute::getAttrName() const {
9551032
return "differentiable";
9561033
case DAK_Differentiating:
9571034
return "differentiating";
1035+
case DAK_Transposing:
1036+
return "transposing";
9581037
}
9591038
llvm_unreachable("bad DeclAttrKind");
9601039
}
@@ -1438,9 +1517,49 @@ DifferentiatingAttr::create(ASTContext &context, bool implicit,
14381517
std::move(original), linear, indices);
14391518
}
14401519

1520+
TransposingAttr::TransposingAttr(ASTContext &context, bool implicit,
1521+
SourceLoc atLoc, SourceRange baseRange,
1522+
TypeRepr *baseType, DeclNameWithLoc original,
1523+
ArrayRef<ParsedAutoDiffParameter> params)
1524+
: DeclAttribute(DAK_Transposing, atLoc, baseRange, implicit),
1525+
BaseType(baseType), Original(std::move(original)),
1526+
NumParsedParameters(params.size()) {
1527+
std::uninitialized_copy(params.begin(), params.end(),
1528+
getTrailingObjects<ParsedAutoDiffParameter>());
1529+
}
1530+
1531+
TransposingAttr::TransposingAttr(ASTContext &context, bool implicit,
1532+
SourceLoc atLoc, SourceRange baseRange,
1533+
TypeRepr *baseType, DeclNameWithLoc original,
1534+
AutoDiffIndexSubset *indices)
1535+
: DeclAttribute(DAK_Differentiating, atLoc, baseRange, implicit),
1536+
BaseType(baseType), Original(std::move(original)),
1537+
ParameterIndexSubset(indices) {}
1538+
1539+
TransposingAttr *
1540+
TransposingAttr::create(ASTContext &context, bool implicit, SourceLoc atLoc,
1541+
SourceRange baseRange, TypeRepr *baseType,
1542+
DeclNameWithLoc original,
1543+
ArrayRef<ParsedAutoDiffParameter> params) {
1544+
unsigned size = totalSizeToAlloc<ParsedAutoDiffParameter>(params.size());
1545+
void *mem = context.Allocate(size, alignof(TransposingAttr));
1546+
return new (mem) TransposingAttr(context, implicit, atLoc, baseRange,
1547+
baseType, std::move(original), params);
1548+
}
1549+
1550+
TransposingAttr *
1551+
TransposingAttr::create(ASTContext &context, bool implicit, SourceLoc atLoc,
1552+
SourceRange baseRange, TypeRepr *baseType,
1553+
DeclNameWithLoc original,
1554+
AutoDiffIndexSubset *indices) {
1555+
void *mem =
1556+
context.Allocate(sizeof(TransposingAttr), alignof(TransposingAttr));
1557+
return new (mem) TransposingAttr(context, implicit, atLoc, baseRange,
1558+
baseType, std::move(original), indices);
1559+
}
1560+
14411561
ImplementsAttr::ImplementsAttr(SourceLoc atLoc, SourceRange range,
1442-
TypeLoc ProtocolType,
1443-
DeclName MemberName,
1562+
TypeLoc ProtocolType, DeclName MemberName,
14441563
DeclNameLoc MemberNameLoc)
14451564
: DeclAttribute(DAK_Implements, atLoc, range, /*Implicit=*/false),
14461565
ProtocolType(ProtocolType),

0 commit comments

Comments
 (0)