@@ -436,6 +436,18 @@ bool COWArrayOpt::checkSafeArrayAddressUses(UserList &AddressUsers) {
436
436
return true ;
437
437
}
438
438
439
+ template <typename UserRange>
440
+ ArraySemanticsCall getEndMutationCall (const UserRange &AddressUsers) {
441
+ for (auto *UseInst : AddressUsers) {
442
+ if (auto *AI = dyn_cast<ApplyInst>(UseInst)) {
443
+ ArraySemanticsCall ASC (AI);
444
+ if (ASC.getKind () == ArrayCallKind::kEndMutation )
445
+ return ASC;
446
+ }
447
+ }
448
+ return ArraySemanticsCall ();
449
+ }
450
+
439
451
// / Returns true if this instruction is a safe array use if all of its users are
440
452
// / also safe array users.
441
453
static SILValue isTransitiveSafeUser (SILInstruction *I) {
@@ -811,8 +823,14 @@ void COWArrayOpt::hoistAddressProjections(Operand &ArrayOp) {
811
823
}
812
824
}
813
825
814
- // / Check if this call to "make_mutable" is hoistable, and move it, or delete it
815
- // / if it's already hoisted.
826
+ // / Check if this call to "make_mutable" is hoistable, and copy it, along with
827
+ // / the corresponding end_mutation call, to the loop pre-header.
828
+ // /
829
+ // / The origial make_mutable/end_mutation calls remain in the loop, because
830
+ // / removing them would violate the COW representation rules.
831
+ // / Having those calls in the pre-header will then enable COWOpts (after
832
+ // / inlining) to constant fold the uniqueness check of the begin_cow_mutation
833
+ // / in the loop.
816
834
bool COWArrayOpt::hoistMakeMutable (ArraySemanticsCall MakeMutable,
817
835
bool dominatesExits) {
818
836
LLVM_DEBUG (llvm::dbgs () << " Checking mutable array: " <<CurrentArrayAddr);
@@ -872,6 +890,18 @@ bool COWArrayOpt::hoistMakeMutable(ArraySemanticsCall MakeMutable,
872
890
return false ;
873
891
}
874
892
893
+ auto ArrayUsers = llvm::map_range (MakeMutable.getSelf ()->getUses (),
894
+ ValueBase::UseToUser ());
895
+
896
+ // There should be a call to end_mutation. Find it so that we can copy it to
897
+ // the pre-header.
898
+ ArraySemanticsCall EndMutation = getEndMutationCall (ArrayUsers);
899
+ if (!EndMutation) {
900
+ EndMutation = getEndMutationCall (StructUses.StructAddressUsers );
901
+ if (!EndMutation)
902
+ return false ;
903
+ }
904
+
875
905
// Hoist the make_mutable.
876
906
LLVM_DEBUG (llvm::dbgs () << " Hoisting make_mutable: " << *MakeMutable);
877
907
@@ -880,12 +910,18 @@ bool COWArrayOpt::hoistMakeMutable(ArraySemanticsCall MakeMutable,
880
910
assert (MakeMutable.canHoist (Preheader->getTerminator (), DomTree) &&
881
911
" Should be able to hoist make_mutable" );
882
912
883
- MakeMutable.hoist (Preheader->getTerminator (), DomTree);
913
+ // Copy the make_mutable and end_mutation calls to the pre-header.
914
+ TermInst *insertionPoint = Preheader->getTerminator ();
915
+ ApplyInst *hoistedMM = MakeMutable.copyTo (insertionPoint, DomTree);
916
+ ApplyInst *EMInst = EndMutation;
917
+ ApplyInst *hoistedEM = cast<ApplyInst>(EMInst->clone (insertionPoint));
918
+ hoistedEM->setArgument (0 , hoistedMM->getArgument (0 ));
919
+ placeFuncRef (hoistedEM, DomTree);
884
920
885
921
// Register array loads. This is needed for hoisting make_mutable calls of
886
922
// inner arrays in the two-dimensional case.
887
923
if (arrayContainerIsUnique &&
888
- StructUses.hasSingleAddressUse ((ApplyInst *)MakeMutable)) {
924
+ StructUses.hasOnlyAddressUses ((ApplyInst *)MakeMutable, EMInst )) {
889
925
for (auto use : MakeMutable.getSelf ()->getUses ()) {
890
926
if (auto *LI = dyn_cast<LoadInst>(use->getUser ()))
891
927
HoistableLoads.insert (LI);
@@ -917,39 +953,33 @@ bool COWArrayOpt::run() {
917
953
// is only mapped to a call once the analysis has determined that no
918
954
// make_mutable calls are required within the loop body for that array.
919
955
llvm::SmallDenseMap<SILValue, ApplyInst*> ArrayMakeMutableMap;
920
-
956
+
957
+ llvm::SmallVector<ArraySemanticsCall, 8 > makeMutableCalls;
958
+
921
959
for (auto *BB : Loop->getBlocks ()) {
922
960
if (ColdBlocks.isCold (BB))
923
961
continue ;
924
- bool dominatesExits = dominatesExitingBlocks (BB);
925
- for ( auto II = BB-> begin (), IE = BB-> end (); II != IE;) {
926
- // Inst may be moved by hoistMakeMutable .
927
- SILInstruction *Inst = &*II;
928
- ++II ;
929
- ArraySemanticsCall MakeMutableCall (Inst, " array.make_mutable " );
930
- if (! MakeMutableCall)
931
- continue ;
962
+
963
+ // Instructions are getting moved around. To not mess with iterator
964
+ // invalidation, first collect all calls, and then do the transformation .
965
+ for ( SILInstruction &I : *BB) {
966
+ ArraySemanticsCall MakeMutableCall (&I, " array.make_mutable " ) ;
967
+ if (MakeMutableCall)
968
+ makeMutableCalls. push_back ( MakeMutableCall);
969
+ }
932
970
971
+ bool dominatesExits = dominatesExitingBlocks (BB);
972
+ for (ArraySemanticsCall MakeMutableCall : makeMutableCalls) {
933
973
CurrentArrayAddr = MakeMutableCall.getSelf ();
934
974
auto HoistedCallEntry = ArrayMakeMutableMap.find (CurrentArrayAddr);
935
975
if (HoistedCallEntry == ArrayMakeMutableMap.end ()) {
936
- if (!hoistMakeMutable (MakeMutableCall, dominatesExits)) {
976
+ if (hoistMakeMutable (MakeMutableCall, dominatesExits)) {
977
+ ArrayMakeMutableMap[CurrentArrayAddr] = MakeMutableCall;
978
+ HasChanged = true ;
979
+ } else {
937
980
ArrayMakeMutableMap[CurrentArrayAddr] = nullptr ;
938
- continue ;
939
981
}
940
-
941
- ArrayMakeMutableMap[CurrentArrayAddr] = MakeMutableCall;
942
- HasChanged = true ;
943
- continue ;
944
982
}
945
-
946
- if (!HoistedCallEntry->second )
947
- continue ;
948
-
949
- LLVM_DEBUG (llvm::dbgs () << " Removing make_mutable call: "
950
- << *MakeMutableCall);
951
- MakeMutableCall.removeCall ();
952
- HasChanged = true ;
953
983
}
954
984
}
955
985
return HasChanged;
0 commit comments