Skip to content

Commit baebc7d

Browse files
committed
Handle mark_dependence in Differentiation
1 parent 61ce3a3 commit baebc7d

File tree

6 files changed

+108
-96
lines changed

6 files changed

+108
-96
lines changed

lib/SILOptimizer/Analysis/DifferentiableActivityAnalysis.cpp

+8-5
Original file line numberDiff line numberDiff line change
@@ -429,12 +429,15 @@ void DifferentiableActivityInfo::setUsefulThroughArrayInitialization(
429429
continue;
430430
// The second tuple field of the return value is the `RawPointer`.
431431
for (auto use : dti->getResult(1)->getUses()) {
432-
// The `RawPointer` passes through a `pointer_to_address`. That
433-
// instruction's first use is a `store` whose source is useful; its
432+
// The `RawPointer` passes through a `mark_dependence(pointer_to_address`.
433+
// That instruction's first use is a `store` whose source is useful; its
434434
// subsequent uses are `index_addr`s whose only use is a useful `store`.
435-
auto *ptai = dyn_cast<PointerToAddressInst>(use->getUser());
436-
assert(ptai && "Expected `pointer_to_address` user for uninitialized "
437-
"array intrinsic");
435+
auto *mdi = dyn_cast<MarkDependenceInst>(use->getUser());
436+
assert(
437+
mdi &&
438+
"Expected a mark_dependence user for uninitialized array intrinsic.");
439+
auto *ptai = dyn_cast<PointerToAddressInst>(getSingleNonDebugUser(mdi));
440+
assert(ptai && "Expected a pointer_to_address.");
438441
setUseful(ptai, dependentVariableIndex);
439442
// Propagate usefulness through array element addresses:
440443
// `pointer_to_address` and `index_addr` instructions.

lib/SILOptimizer/Differentiation/Common.cpp

+6-18
Original file line numberDiff line numberDiff line change
@@ -37,30 +37,18 @@ ApplyInst *getAllocateUninitializedArrayIntrinsicElementAddress(SILValue v) {
3737
ptai = dyn_cast<PointerToAddressInst>(iai->getOperand(0));
3838
if (!ptai)
3939
return nullptr;
40+
auto *mdi = dyn_cast<MarkDependenceInst>(
41+
ptai->getOperand()->getDefiningInstruction());
42+
if (!mdi)
43+
return nullptr;
4044
// Return the `array.uninitialized_intrinsic` application, if it exists.
4145
if (auto *dti = dyn_cast<DestructureTupleInst>(
42-
ptai->getOperand()->getDefiningInstruction()))
46+
mdi->getValue()->getDefiningInstruction()))
4347
return ArraySemanticsCall(dti->getOperand(),
4448
semantics::ARRAY_UNINITIALIZED_INTRINSIC);
4549
return nullptr;
4650
}
4751

48-
DestructureTupleInst *getSingleDestructureTupleUser(SILValue value) {
49-
bool foundDestructureTupleUser = false;
50-
if (!value->getType().is<TupleType>())
51-
return nullptr;
52-
DestructureTupleInst *result = nullptr;
53-
for (auto *use : value->getUses()) {
54-
if (auto *dti = dyn_cast<DestructureTupleInst>(use->getUser())) {
55-
assert(!foundDestructureTupleUser &&
56-
"There should only be one `destructure_tuple` user of a tuple");
57-
foundDestructureTupleUser = true;
58-
result = dti;
59-
}
60-
}
61-
return result;
62-
}
63-
6452
bool isSemanticMemberAccessor(SILFunction *original) {
6553
auto *dc = original->getDeclContext();
6654
if (!dc)
@@ -109,7 +97,7 @@ void forEachApplyDirectResult(
10997
resultCallback(ai);
11098
return;
11199
}
112-
if (auto *dti = getSingleDestructureTupleUser(ai))
100+
if (auto *dti = ai->getSingleUserOfType<DestructureTupleInst>())
113101
for (auto directResult : dti->getResults())
114102
resultCallback(directResult);
115103
break;

lib/SILOptimizer/Differentiation/JVPCloner.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -1312,7 +1312,8 @@ class JVPCloner::Implementation final
13121312
if (!origResult->getType().is<TupleType>()) {
13131313
setTangentValue(bb, origResult,
13141314
makeConcreteTangentValue(differentialResult));
1315-
} else if (auto *dti = getSingleDestructureTupleUser(ai)) {
1315+
} else if (auto *dti =
1316+
ai->getSingleUserOfType<DestructureTupleInst>()) {
13161317
bool notSetValue = true;
13171318
for (auto result : dti->getResults()) {
13181319
if (activityInfo.isActive(result, getConfig())) {

lib/SILOptimizer/Differentiation/PullbackCloner.cpp

+5-1
Original file line numberDiff line numberDiff line change
@@ -3331,7 +3331,11 @@ void PullbackCloner::Implementation::
33313331
builder.setCurrentDebugScope(remapScope(dti->getDebugScope()));
33323332
builder.setInsertionPoint(arrayAdjoint->getParentBlock());
33333333
for (auto use : dti->getResult(1)->getUses()) {
3334-
auto *ptai = dyn_cast<PointerToAddressInst>(use->getUser());
3334+
auto *mdi = dyn_cast<MarkDependenceInst>(use->getUser());
3335+
assert(mdi && "Expected mark_dependence user");
3336+
auto *ptai =
3337+
dyn_cast_or_null<PointerToAddressInst>(getSingleNonDebugUser(mdi));
3338+
assert(ptai && "Expected pointer_to_address user");
33353339
auto adjBuf = getAdjointBuffer(origBB, ptai);
33363340
auto *eltAdjBuf = getArrayAdjointElementBuffer(arrayAdjoint, 0, loc);
33373341
builder.emitInPlaceAdd(loc, adjBuf, eltAdjBuf);

lib/SILOptimizer/Transforms/ArrayElementValuePropagation.cpp

+6
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,12 @@ bool ArrayAllocation::recursivelyCollectUses(ValueBase *Def) {
151151
continue;
152152
}
153153

154+
if (auto *MDI = dyn_cast<MarkDependenceInst>(User)) {
155+
if (Def != MDI->getBase())
156+
return false;
157+
continue;
158+
}
159+
154160
// Check array semantic calls.
155161
ArraySemanticsCall ArrayOp(User);
156162
switch (ArrayOp.getKind()) {

0 commit comments

Comments
 (0)