Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AutoDiff] [stdlib] Deprecate 'CotangentVector' in favor of 'TangentVector'. #24825

Merged
merged 3 commits into from
May 16, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions include/swift/AST/ASTContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,6 @@ namespace swift {
class VarDecl;
class UnifiedStatsReporter;
// SWIFT_ENABLE_TENSORFLOW
enum class AutoDiffAssociatedVectorSpaceKind : unsigned;
class VectorSpace;
class AutoDiffParameterIndices;
class DifferentiableAttr;
Expand Down Expand Up @@ -276,8 +275,7 @@ class ASTContext final {
llvm::StringMap<Type> RemappedTypes;

/// Cache of autodiff-associated vector spaces.
llvm::DenseMap<std::pair<Type, unsigned>,
Optional<VectorSpace>> AutoDiffVectorSpaces;
llvm::DenseMap<Type, Optional<VectorSpace>> AutoDiffVectorSpaces;

/// Cache of `@differentiable` attributes keyed by parameter indices. This
/// helps us diagnose multiple `@differentiable`s that are with respect to the
Expand Down
5 changes: 0 additions & 5 deletions include/swift/AST/AutoDiff.h
Original file line number Diff line number Diff line change
Expand Up @@ -551,11 +551,6 @@ class AutoDiffAssociatedFunctionIdentifier : public llvm::FoldingSetNode {
}
};

/// The kind of an associated type.
enum class AutoDiffAssociatedVectorSpaceKind : unsigned {
Tangent = 0, Cotangent = 1
};

/// Automatic differentiation utility namespace.
namespace autodiff {

Expand Down
2 changes: 1 addition & 1 deletion include/swift/AST/DiagnosticsSema.def
Original file line number Diff line number Diff line change
Expand Up @@ -2715,7 +2715,7 @@ NOTE(protocol_witness_missing_specific_differentiable_attr,none,
// @differentiating
ERROR(differentiating_attr_expected_result_tuple,none,
"'@differentiating' attribute requires function to return a two-element tuple of type "
"'(value: T..., pullback: (U.CotangentVector) -> T.CotangentVector...)' or "
"'(value: T..., pullback: (U.TangentVector) -> T.TangentVector...)' or "
"'(value: T..., differential: (T.TangentVector...) -> U.TangentVector)'", ())
ERROR(differentiating_attr_invalid_result_tuple_value_label,none,
"'@differentiating' attribute requires function to return a two-element "
Expand Down
2 changes: 0 additions & 2 deletions include/swift/AST/KnownIdentifiers.def
Original file line number Diff line number Diff line change
Expand Up @@ -135,11 +135,9 @@ IDENTIFIER(zero)
IDENTIFIER(Scalar)
// Differentiable
IDENTIFIER(AllDifferentiableVariables)
IDENTIFIER(CotangentVector)
IDENTIFIER(TangentVector)
IDENTIFIER(allDifferentiableVariables)
IDENTIFIER(moved)
IDENTIFIER(tangentVector)

// Kinds of layout constraints
IDENTIFIER_WITH_NAME(UnknownLayout, "_UnknownLayout")
Expand Down
3 changes: 0 additions & 3 deletions include/swift/AST/KnownProtocols.def
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,6 @@ PROTOCOL(TensorGroup)
PROTOCOL_(TensorFlowDataTypeCompatible)
PROTOCOL(TensorProtocol)
PROTOCOL(VectorNumeric)
// TODO(TF-213): Remove underscore `Differentiable` protocols.
PROTOCOL(__Differentiable)
PROTOCOL(_Differentiable)
PROTOCOL(Differentiable)

PROTOCOL_(ObjectiveCBridgeable)
Expand Down
29 changes: 13 additions & 16 deletions include/swift/AST/Types.h
Original file line number Diff line number Diff line change
Expand Up @@ -1096,20 +1096,17 @@ class alignas(1 << TypeAlignInBits) TypeBase {
TypeTraitResult canBeClass();

// SWIFT_ENABLE_TENSORFLOW
/// Return the associated tangent or cotangent type. Return the null type if
/// there is no associated tangent/cotangent type.
///
/// `kind` specifies whether to return the tangent or cotangent type.
/// Return the associated tangent type. Return the null type if there is no
/// associated tangent type.
///
/// If the type conforms to `Differentiable`, then the associated
/// tangent/cotangent type is the associated `TangentVector`/`CotangentVector`
/// from the `Differentiable` requirement. If the type is a tuple, then the
/// associated tangent/cotangent type is the elementwise tangent/cotangent
/// type of its elements. If the type is a builtin float, then the associated
/// tangent/cotangent type is itself. Otherwise, there is no associated type.
/// tangent type is the associated `TangentVector` from the `Differentiable`
/// requirement. If the type is a tuple, then the associated tangent type is
/// the elementwise tangent type of its elements. If the type is a builtin
/// float, then the associated tangent type is itself. Otherwise, there is no
/// associated type.
Optional<VectorSpace>
getAutoDiffAssociatedVectorSpace(AutoDiffAssociatedVectorSpaceKind kind,
LookupConformanceFn lookupConformance);
getAutoDiffAssociatedTangentSpace(LookupConformanceFn lookupConformance);

private:
// Make vanilla new/delete illegal for Types.
Expand Down Expand Up @@ -3074,12 +3071,12 @@ class AnyFunctionType : public TypeBase {
///
/// By default, if the original type has a self parameter list and parameter
/// indices include self, the computed associated function type will return a
/// linear map taking/returning self's tangent/cotangent *last* instead of
/// first, for consistency with SIL.
/// linear map taking/returning self's tangent *last* instead of first, for
/// consistency with SIL.
///
/// If `makeSelfParamFirst` is true, self's tangent/cotangent is reordered to
/// appear first. This should be used during type-checking, e.g.
/// type-checking `@differentiable` and `@differentiating` attributes.
/// If `makeSelfParamFirst` is true, self's tangent is reordered to appear
/// first. This should be used during type-checking, e.g. type-checking
/// `@differentiable` and `@differentiating` attributes.
///
/// \note The original function type (`self`) need not be `@differentiable`.
/// The resulting function will preserve all `ExtInfo` of the original
Expand Down
2 changes: 1 addition & 1 deletion lib/AST/Builtins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -985,7 +985,7 @@ static ValueDecl *getAutoDiffApplyAssociatedFunction(
// rethrows -> (R, (...T.TangentVector) -> R.TangentVector)
// VJP:
// <...T...(arity), R> (@differentiable (...T) throws -> R, ...T)
// rethrows -> (R, (R.CotangentVector) -> ...T.CotangentVector)
// rethrows -> (R, (R.TangentVector) -> ...T.TangentVector)
unsigned numGenericParams = 1 + arity;
BuiltinGenericSignatureBuilder builder(Context, numGenericParams);
// Look up the Differentiable protocol.
Expand Down
72 changes: 28 additions & 44 deletions lib/AST/Type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4364,13 +4364,12 @@ makeFunctionType(AnyFunctionType *copy, ArrayRef<AnyFunctionType::Param> params,
return FunctionType::get(params, retTy, copy->getExtInfo());
}

Optional<VectorSpace> TypeBase::getAutoDiffAssociatedVectorSpace(
AutoDiffAssociatedVectorSpaceKind kind,
Optional<VectorSpace> TypeBase::getAutoDiffAssociatedTangentSpace(
LookupConformanceFn lookupConformance) {
assert(lookupConformance);
auto &ctx = getASTContext();

std::pair<Type, unsigned> cacheKey {this, (unsigned)kind};
Type cacheKey = this;
auto lookup = ctx.AutoDiffVectorSpaces.find(cacheKey);
if (lookup != ctx.AutoDiffVectorSpaces.end())
return lookup->getSecond();
Expand All @@ -4379,24 +4378,24 @@ Optional<VectorSpace> TypeBase::getAutoDiffAssociatedVectorSpace(
return vs;
};

// Functions' tangent/cotangent is the same function except the innermost
// return type being replaced by its tangent/cotangent.
// Functions' tangent is the same function except the innermost return type
// being replaced by its tangent.
if (auto *fnTy = getAs<AnyFunctionType>()) {
auto resultSpace = fnTy->getResult()->getAutoDiffAssociatedVectorSpace(
kind, lookupConformance);
auto resultSpace = fnTy->getResult()->getAutoDiffAssociatedTangentSpace(
lookupConformance);
if (!resultSpace)
return cache(None);
return cache(VectorSpace::getFunction(
makeFunctionType(fnTy, fnTy->getParams(), resultSpace->getType(),
fnTy->getOptGenericSignature())));
}

// Tuples' tangent/cotangent is a tuple of each element's Tangent/Cotangent.
// Tuples' tangent is a tuple of each element's Tangent.
if (auto *tupleTy = getAs<TupleType>()) {
SmallVector<TupleTypeElt, 8> newElts;
for (auto elt : tupleTy->getElements()) {
auto eltSpace = elt.getType()
->getAutoDiffAssociatedVectorSpace(kind, lookupConformance);
->getAutoDiffAssociatedTangentSpace(lookupConformance);
if (!eltSpace)
continue;
newElts.push_back(elt.getWithType(eltSpace->getType()));
Expand All @@ -4410,22 +4409,12 @@ Optional<VectorSpace> TypeBase::getAutoDiffAssociatedVectorSpace(
return cache(VectorSpace::getTuple(tupleType));
}

// Find the TangentVector/CotangentVector associated type on the
// Differentiable protocol.
// Find the TangentVector associated type on the Differentiable protocol.
auto *differentiableProtocol =
ctx.getProtocol(KnownProtocolKind::__Differentiable);
assert(differentiableProtocol && "Could not find __Differentiable protocol");
Identifier associatedTypeIdentifier;
switch (kind) {
case AutoDiffAssociatedVectorSpaceKind::Tangent:
associatedTypeIdentifier = ctx.Id_TangentVector;
break;
case AutoDiffAssociatedVectorSpaceKind::Cotangent:
associatedTypeIdentifier = ctx.Id_CotangentVector;
break;
}
ctx.getProtocol(KnownProtocolKind::Differentiable);
assert(differentiableProtocol && "Could not find Differentiable protocol");
auto associatedTypeLookup =
differentiableProtocol->lookupDirect(associatedTypeIdentifier);
differentiableProtocol->lookupDirect(ctx.Id_TangentVector);
assert(associatedTypeLookup.size() == 1);
auto *dependentType = DependentMemberType::get(
differentiableProtocol->getDeclaredInterfaceType(),
Expand All @@ -4448,7 +4437,7 @@ AnyFunctionType *AnyFunctionType::getAutoDiffAssociatedFunctionType(
// JVP: (T...) -> ((R...),
// (T.TangentVector...) -> (R.TangentVector...))
// VJP: (T...) -> ((R...),
// (R.CotangentVector...) -> (T.CotangentVector...))
// (R.TangentVector...) -> (T.TangentVector...))
//
// Note that both can be written as "(T...) -> ((R...), Closure)", so we build
// "Closure" and then use common code to wrap "Closure" in the outer function
Expand Down Expand Up @@ -4487,23 +4476,20 @@ AnyFunctionType *AnyFunctionType::getAutoDiffAssociatedFunctionType(
SmallVector<AnyFunctionType::Param, 8> differentialParams;
for (auto wrtParamType : wrtParamTypes)
differentialParams.push_back(
AnyFunctionType::Param(wrtParamType->getAutoDiffAssociatedVectorSpace(
AutoDiffAssociatedVectorSpaceKind::Tangent, lookupConformance)
AnyFunctionType::Param(
wrtParamType->getAutoDiffAssociatedTangentSpace(lookupConformance)
->getType()));

SmallVector<TupleTypeElt, 8> differentialResults;
if (auto *resultTuple = originalResult->getAs<TupleType>()) {
auto resultTupleEltType = resultTuple->getElementType(resultIndex);
differentialResults.push_back(
resultTupleEltType->getAutoDiffAssociatedVectorSpace(
AutoDiffAssociatedVectorSpaceKind::Tangent, lookupConformance)
->getType());
differentialResults.push_back(resultTupleEltType
->getAutoDiffAssociatedTangentSpace(lookupConformance)->getType());
} else {
assert(resultIndex == 0 && "resultIndex out of bounds");
differentialResults.push_back(
originalResult->getAutoDiffAssociatedVectorSpace(
AutoDiffAssociatedVectorSpaceKind::Tangent, lookupConformance)
->getType());
originalResult->getAutoDiffAssociatedTangentSpace(lookupConformance)
->getType());
}
Type differentialResult =
differentialResults.size() > 1
Expand All @@ -4515,28 +4501,26 @@ AnyFunctionType *AnyFunctionType::getAutoDiffAssociatedFunctionType(
}
case AutoDiffAssociatedFunctionKind::VJP: {
// closure is the VJP "pullback":
// (R.CotangentVector...) -> (T.CotangentVector...)
// (R.TangentVector...) -> (T.TangentVector...)
SmallVector<AnyFunctionType::Param, 8> pullbackParams;
if (auto *resultTuple = originalResult->getAs<TupleType>()) {
auto resultTupleEltType = resultTuple->getElementType(resultIndex);
pullbackParams.push_back(
AnyFunctionType::Param(
resultTupleEltType->getAutoDiffAssociatedVectorSpace(
AutoDiffAssociatedVectorSpaceKind::Cotangent,
lookupConformance)->getType()));
AnyFunctionType::Param(resultTupleEltType
->getAutoDiffAssociatedTangentSpace(lookupConformance)
->getType()));
} else {
assert(resultIndex == 0 && "resultIndex out of bounds");
pullbackParams.push_back(
AnyFunctionType::Param(
originalResult->getAutoDiffAssociatedVectorSpace(
AutoDiffAssociatedVectorSpaceKind::Cotangent,
lookupConformance)->getType()));
AnyFunctionType::Param(originalResult
->getAutoDiffAssociatedTangentSpace(lookupConformance)
->getType()));
}

SmallVector<TupleTypeElt, 8> pullbackResults;
for (auto wrtParamType : wrtParamTypes)
pullbackResults.push_back(wrtParamType->getAutoDiffAssociatedVectorSpace(
AutoDiffAssociatedVectorSpaceKind::Cotangent, lookupConformance)
pullbackResults.push_back(wrtParamType
->getAutoDiffAssociatedTangentSpace(lookupConformance)
->getType());
Type pullbackResult = pullbackResults.size() > 1
? TupleType::get(pullbackResults, ctx)
Expand Down
3 changes: 0 additions & 3 deletions lib/IRGen/GenMeta.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4192,9 +4192,6 @@ SpecialProtocol irgen::getSpecialProtocolID(ProtocolDecl *P) {
case KnownProtocolKind::TensorFlowDataTypeCompatible:
case KnownProtocolKind::TensorProtocol:
case KnownProtocolKind::VectorNumeric:
// TODO(TF-213): Remove underscore `Differentiable` protocols.
case KnownProtocolKind::__Differentiable:
case KnownProtocolKind::_Differentiable:
case KnownProtocolKind::Differentiable:
return SpecialProtocol::None;
}
Expand Down
52 changes: 25 additions & 27 deletions lib/SIL/SILFunctionType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ CanSILFunctionType SILFunctionType::getAutoDiffAssociatedFunctionType(
// JVP: (T...) -> ((R...),
// (T.TangentVector...) -> (R.TangentVector...))
// VJP: (T...) -> ((R...),
// (R.CotangentVector...) -> (T.CotangentVector...))
// (R.TangentVector...) -> (T.TangentVector...))

auto &ctx = getASTContext();
auto &typeConverter = module.Types;
Expand All @@ -164,9 +164,10 @@ CanSILFunctionType SILFunctionType::getAutoDiffAssociatedFunctionType(
whereClauseGenSig = getGenericSignature();

// Given a type, returns its formal SIL parameter info.
auto getCotangentParameterInfoForOriginalResult = [&](
CanType cotanType, ResultConvention origResConv) -> SILParameterInfo {
auto &tl = typeConverter.getTypeLowering(cotanType, ResilienceExpansion::Minimal);
auto getTangentParameterInfoForOriginalResult = [&](
CanType tanType, ResultConvention origResConv) -> SILParameterInfo {
auto &tl = typeConverter.getTypeLowering(tanType,
ResilienceExpansion::Minimal);
ParameterConvention conv;
switch (origResConv) {
case ResultConvention::Owned:
Expand All @@ -183,13 +184,14 @@ CanSILFunctionType SILFunctionType::getAutoDiffAssociatedFunctionType(
conv = ParameterConvention::Indirect_In_Guaranteed;
break;
}
return {cotanType, conv};
return {tanType, conv};
};

// Given a type, returns its formal SIL result info.
auto getCotangentResultInfoForOriginalParameter = [&](
CanType cotanType, ParameterConvention origParamConv) -> SILResultInfo {
auto &tl = typeConverter.getTypeLowering(cotanType, ResilienceExpansion::Minimal);
auto getTangentResultInfoForOriginalParameter = [&](
CanType tanType, ParameterConvention origParamConv) -> SILResultInfo {
auto &tl = typeConverter.getTypeLowering(tanType,
ResilienceExpansion::Minimal);
ResultConvention conv;
switch (origParamConv) {
case ParameterConvention::Direct_Owned:
Expand All @@ -207,7 +209,7 @@ CanSILFunctionType SILFunctionType::getAutoDiffAssociatedFunctionType(
conv = ResultConvention::Indirect;
break;
}
return {cotanType, conv};
return {tanType, conv};
};

// Helper function testing if we are differentiating wrt this index.
Expand All @@ -228,17 +230,15 @@ CanSILFunctionType SILFunctionType::getAutoDiffAssociatedFunctionType(
SmallVector<SILParameterInfo, 8> differentialParams;
for (auto &param : wrtParams) {
differentialParams.push_back(
{param.getType()->getAutoDiffAssociatedVectorSpace(
AutoDiffAssociatedVectorSpaceKind::Tangent, lookupConformance)
->getCanonicalType(),
{param.getType()->getAutoDiffAssociatedTangentSpace(lookupConformance)
->getCanonicalType(),
param.getConvention()});
}
SmallVector<SILResultInfo, 8> differentialResults;
auto &result = getResults()[resultIndex];
differentialResults.push_back(
{result.getType()->getAutoDiffAssociatedVectorSpace(
AutoDiffAssociatedVectorSpaceKind::Tangent, lookupConformance)
->getCanonicalType(),
{result.getType()->getAutoDiffAssociatedTangentSpace(lookupConformance)
->getCanonicalType(),
result.getConvention()});
closureType = SILFunctionType::get(
/*genericSignature*/ nullptr, ExtInfo(), SILCoroutineKind::None,
Expand All @@ -249,22 +249,20 @@ CanSILFunctionType SILFunctionType::getAutoDiffAssociatedFunctionType(
case AutoDiffAssociatedFunctionKind::VJP: {
SmallVector<SILParameterInfo, 8> pullbackParams;
auto &origRes = getResults()[resultIndex];
auto cotangentAssocTy =
origRes.getType()->getAutoDiffAssociatedVectorSpace(
AutoDiffAssociatedVectorSpaceKind::Cotangent, lookupConformance)
->getCanonicalType();
auto tangentAssocTy =
origRes.getType()->getAutoDiffAssociatedTangentSpace(lookupConformance)
->getCanonicalType();
pullbackParams.push_back(
getCotangentParameterInfoForOriginalResult(cotangentAssocTy,
origRes.getConvention()));
getTangentParameterInfoForOriginalResult(tangentAssocTy,
origRes.getConvention()));
SmallVector<SILResultInfo, 8> pullbackResults;
for (auto &param : wrtParams) {
auto paramCotangentTy =
param.getType()->getAutoDiffAssociatedVectorSpace(
AutoDiffAssociatedVectorSpaceKind::Cotangent, lookupConformance)
->getCanonicalType();
auto paramTangentTy =
param.getType()->getAutoDiffAssociatedTangentSpace(lookupConformance)
->getCanonicalType();
pullbackResults.push_back(
getCotangentResultInfoForOriginalParameter(paramCotangentTy,
param.getConvention()));
getTangentResultInfoForOriginalParameter(paramTangentTy,
param.getConvention()));
}
closureType = SILFunctionType::get(
/*genericSignature*/ nullptr, ExtInfo(), SILCoroutineKind::None,
Expand Down
3 changes: 1 addition & 2 deletions lib/SIL/SILType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -595,7 +595,6 @@ bool SILType::isLoweringOf(SILModule &Mod, CanType formalType) {
// SWIFT_ENABLE_TENSORFLOW
/// Returns true if this SILType is a differentiable type.
bool SILType::isDifferentiable(SILModule &M) const {
return getASTType()->getAutoDiffAssociatedVectorSpace(
AutoDiffAssociatedVectorSpaceKind::Tangent,
return getASTType()->getAutoDiffAssociatedTangentSpace(
LookUpConformanceInModule(M.getSwiftModule())).hasValue();
}
Loading