-
Notifications
You must be signed in to change notification settings - Fork 10.4k
/
Copy pathPullbackCloner.cpp
3624 lines (3328 loc) · 152 KB
/
PullbackCloner.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
//===--- PullbackCloner.cpp - Pullback function generation ---*- C++ -*----===//
//
// This source file is part of the Swift.org open source project
//
// Copyright (c) 2019 - 2020 Apple Inc. and the Swift project authors
// Licensed under Apache License v2.0 with Runtime Library Exception
//
// See https://swift.org/LICENSE.txt for license information
// See https://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
//
//===----------------------------------------------------------------------===//
//
// This file defines a helper class for generating pullback functions for
// automatic differentiation.
//
//===----------------------------------------------------------------------===//
#define DEBUG_TYPE "differentiation"
#include "swift/SILOptimizer/Differentiation/PullbackCloner.h"
#include "swift/SILOptimizer/Analysis/DifferentiableActivityAnalysis.h"
#include "swift/SILOptimizer/Differentiation/ADContext.h"
#include "swift/SILOptimizer/Differentiation/AdjointValue.h"
#include "swift/SILOptimizer/Differentiation/DifferentiationInvoker.h"
#include "swift/SILOptimizer/Differentiation/LinearMapInfo.h"
#include "swift/SILOptimizer/Differentiation/Thunk.h"
#include "swift/SILOptimizer/Differentiation/VJPCloner.h"
#include "swift/AST/ConformanceLookup.h"
#include "swift/AST/Expr.h"
#include "swift/AST/PropertyWrappers.h"
#include "swift/AST/TypeCheckRequests.h"
#include "swift/Basic/Assertions.h"
#include "swift/Basic/STLExtras.h"
#include "swift/SIL/ApplySite.h"
#include "swift/SIL/InstructionUtils.h"
#include "swift/SIL/Projection.h"
#include "swift/SIL/TypeSubstCloner.h"
#include "swift/SILOptimizer/PassManager/PrettyStackTrace.h"
#include "swift/SILOptimizer/Utils/SILOptFunctionBuilder.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallSet.h"
namespace swift {
class SILDifferentiabilityWitness;
class SILBasicBlock;
class SILFunction;
class SILInstruction;
namespace autodiff {
class ADContext;
class VJPCloner;
/// The implementation class for `PullbackCloner`.
///
/// The implementation class is a `SILInstructionVisitor`. Effectively, it acts
/// as a `SILCloner` that visits basic blocks in post-order and that visits
/// instructions per basic block in reverse order. This visitation order is
/// necessary for generating pullback functions, whose control flow graph is
/// ~a transposed version of the original function's control flow graph.
class PullbackCloner::Implementation final
: public SILInstructionVisitor<PullbackCloner::Implementation> {
public:
explicit Implementation(VJPCloner &vjpCloner);
private:
/// The parent VJP cloner.
VJPCloner &vjpCloner;
/// Dominance info for the original function.
DominanceInfo *domInfo = nullptr;
/// Post-dominance info for the original function.
PostDominanceInfo *postDomInfo = nullptr;
/// Post-order info for the original function.
PostOrderFunctionInfo *postOrderInfo = nullptr;
/// Mapping from original basic blocks to corresponding pullback basic blocks.
/// Pullback basic blocks always have the predecessor as the single argument.
llvm::DenseMap<SILBasicBlock *, SILBasicBlock *> pullbackBBMap;
/// Mapping from original basic blocks and original values to corresponding
/// adjoint values.
llvm::DenseMap<std::pair<SILBasicBlock *, SILValue>, AdjointValue> valueMap;
/// Mapping from original basic blocks and original values to corresponding
/// adjoint buffers.
llvm::DenseMap<std::pair<SILBasicBlock *, SILValue>, SILValue> bufferMap;
/// Mapping from original basic blocks to the pullback tuple elements
/// destructured from the linear map basic block argument. In the beginning
/// of each pullback basic block, the block's pullback tuple is destructured
/// into individual elements stored here.
llvm::DenseMap<SILBasicBlock*, SmallVector<SILValue, 4>> pullbackTupleElements;
/// Mapping from original basic blocks and successor basic blocks to
/// corresponding pullback trampoline basic blocks. Trampoline basic blocks
/// take additional arguments in addition to the predecessor enum argument.
llvm::DenseMap<std::pair<SILBasicBlock *, SILBasicBlock *>, SILBasicBlock *>
pullbackTrampolineBBMap;
/// Mapping from original basic blocks to dominated active values.
llvm::DenseMap<SILBasicBlock *, SmallVector<SILValue, 8>> activeValues;
/// Mapping from original basic blocks and original active values to
/// corresponding pullback block arguments.
llvm::DenseMap<std::pair<SILBasicBlock *, SILValue>, SILArgument *>
activeValuePullbackBBArgumentMap;
/// Mapping from original basic blocks to local temporary values to be cleaned
/// up. This is populated when pullback emission is run on one basic block and
/// cleaned before processing another basic block.
llvm::DenseMap<SILBasicBlock *, llvm::SmallSetVector<SILValue, 32>>
blockTemporaries;
/// The scope cloner.
ScopeCloner scopeCloner;
/// The main builder.
TangentBuilder builder;
/// An auxiliary local allocation builder.
TangentBuilder localAllocBuilder;
/// The original function's exit block.
SILBasicBlock *originalExitBlock = nullptr;
/// Stack buffers allocated for storing local adjoint values.
SmallVector<AllocStackInst *, 64> functionLocalAllocations;
/// Copies created to deal with destructive enum operations
/// (unchecked_take_enum_addr)
llvm::SmallDenseMap<InitEnumDataAddrInst*, SILValue> enumDataAdjCopies;
/// A set used to remember local allocations that were destroyed.
llvm::SmallDenseSet<SILValue> destroyedLocalAllocations;
/// The seed arguments of the pullback function.
SmallVector<SILArgument *, 4> seeds;
/// The `AutoDiffLinearMapContext` object, if any.
SILValue contextValue = nullptr;
llvm::BumpPtrAllocator allocator;
bool errorOccurred = false;
// Convenience accessors. Most of them forward directly to the parent
// `vjpCloner`; `getModule` goes through the AD context and `getASTContext`
// goes through the pullback function.

/// Returns the AD context of the parent VJP cloner.
ADContext &getContext() const { return vjpCloner.getContext(); }
/// Returns the SIL module of the AD context.
SILModule &getModule() const { return getContext().getModule(); }
/// Returns the AST context of the pullback function.
ASTContext &getASTContext() const { return getPullback().getASTContext(); }
/// Returns the original function, as reported by the parent VJP cloner.
SILFunction &getOriginal() const { return vjpCloner.getOriginal(); }
/// Returns the differentiability witness of the parent VJP cloner.
SILDifferentiabilityWitness *getWitness() const {
return vjpCloner.getWitness();
}
/// Returns the differentiation invoker of the parent VJP cloner.
DifferentiationInvoker getInvoker() const { return vjpCloner.getInvoker(); }
/// Returns the linear map info for the pullback, owned by the VJP cloner.
LinearMapInfo &getPullbackInfo() const { return vjpCloner.getPullbackInfo(); }
/// Returns the autodiff configuration of the parent VJP cloner.
const AutoDiffConfig &getConfig() const { return vjpCloner.getConfig(); }
/// Returns the activity analysis info of the parent VJP cloner.
const DifferentiableActivityInfo &getActivityInfo() const {
return vjpCloner.getActivityInfo();
}
//--------------------------------------------------------------------------//
// Pullback tuple mapping
//--------------------------------------------------------------------------//
void initializePullbackTupleElements(SILBasicBlock *origBB,
                                     SILInstructionResultArray values) {
  // Records `values` as the pullback tuple elements of `origBB`. The number
  // of values must match the arity of the block's linear map tuple type, and
  // the block must not have been initialized before.
  auto *pbTupleType = getPullbackInfo().getLinearMapTupleType(origBB);
  // The tuple type is only inspected by the assertion below; silence the
  // unused-variable warning in builds where `assert` compiles away.
  (void)pbTupleType;
  assert(pbTupleType->getNumElements() == values.size() &&
         "The number of pullback tuple fields must equal the number of "
         "pullback tuple element values");
  auto res =
      pullbackTupleElements.insert({origBB, {values.begin(), values.end()}});
  (void)res;
  assert(res.second && "A pullback tuple element already exists!");
}
void initializePullbackTupleElements(SILBasicBlock *origBB,
                                     const llvm::ArrayRef<SILArgument *> &values) {
  // Overload taking block arguments: records `values` as the pullback tuple
  // elements of `origBB`. The number of values must match the arity of the
  // block's linear map tuple type, and the block must not have been
  // initialized before.
  auto *pbTupleType = getPullbackInfo().getLinearMapTupleType(origBB);
  // The tuple type is only inspected by the assertion below; silence the
  // unused-variable warning in builds where `assert` compiles away.
  (void)pbTupleType;
  assert(pbTupleType->getNumElements() == values.size() &&
         "The number of pullback tuple fields must equal the number of "
         "pullback tuple element values");
  auto res =
      pullbackTupleElements.insert({origBB, {values.begin(), values.end()}});
  (void)res;
  // Message kept in sync with the `SILInstructionResultArray` overload.
  assert(res.second && "A pullback tuple element already exists!");
}
/// Returns the pullback tuple element value corresponding to the given
/// original block and apply inst.
SILValue getPullbackTupleElement(FullApplySite fai) {
  // Position of this apply's linear map within its block's pullback tuple.
  unsigned idx = getPullbackInfo().lookUpLinearMapIndex(fai);
  // Index 0 is only acceptable in the entry block: for every other block,
  // element 0 is the predecessor element (see `getPullbackPredTupleElement`).
  assert((idx > 0 || (idx == 0 && fai.getParent()->isEntry())) &&
         "impossible linear map index");
  // Access the stored elements in place. `DenseMap::lookup` returns the
  // mapped `SmallVector` by value, which would copy it on every query.
  auto it = pullbackTupleElements.find(fai.getParent());
  assert(it != pullbackTupleElements.end() &&
         idx < it->getSecond().size() &&
         "pullback tuple element for this apply does not exist!");
  return it->getSecond()[idx];
}
/// Returns the pullback tuple element value corresponding to the predecessor
/// for the given original block.
SILValue getPullbackPredTupleElement(SILBasicBlock *origBB) {
  assert(!origBB->isEntry() && "no predecessors for entry block");
  // Access the stored elements in place. `DenseMap::lookup` returns the
  // mapped `SmallVector` by value, which would copy it on every query.
  auto it = pullbackTupleElements.find(origBB);
  assert(it != pullbackTupleElements.end() && it->getSecond().size() &&
         "pullback tuple cannot be empty");
  // The predecessor element is always stored first.
  return it->getSecond()[0];
}
//--------------------------------------------------------------------------//
// Type transformer
//--------------------------------------------------------------------------//
/// Get the type lowering for the given AST type.
///
/// Explicitly use minimal type expansion context: in general, differentiation
/// happens on function types, so it cannot know if the original function is
/// resilient or not.
const Lowering::TypeLowering &getTypeLowering(Type type) {
  // The abstraction pattern is built from `type` reduced with respect to the
  // pullback's substituted generic signature.
  auto genericSig =
      getPullback().getLoweredFunctionType()->getSubstGenericSignature();
  auto reducedType = type->getReducedType(genericSig);
  Lowering::AbstractionPattern origPattern(genericSig, reducedType);
  // Always lower with the minimal type expansion context.
  auto &typeConverter = getContext().getTypeConverter();
  return typeConverter.getTypeLowering(origPattern, type,
                                       TypeExpansionContext::minimal());
}
/// Remap any archetypes into the current function's context.
SILType remapType(SILType ty) {
  auto &pullback = getPullback();
  // Step 1: move contextual (archetype-bearing) types out of their context.
  if (ty.hasArchetype())
    ty = ty.mapTypeOutOfContext();
  // Step 2: reduce with respect to the pullback's substituted generic
  // signature, preserving the value category of the input type.
  auto pbGenSig = pullback.getLoweredFunctionType()->getSubstGenericSignature();
  auto reducedASTType = ty.getASTType()->getReducedType(pbGenSig);
  auto interfaceSILType =
      SILType::getPrimitiveType(reducedASTType, ty.getCategory());
  // FIXME: Sometimes getPullback() doesn't have a generic environment, in which
  // case callers are apparently happy to receive an interface type.
  if (!pullback.getGenericEnvironment())
    return interfaceSILType;
  // Step 3: map the type into the pullback's generic environment.
  return pullback.mapTypeIntoContext(interfaceSILType);
}
std::optional<TangentSpace> getTangentSpace(CanType type) {
  // Reduce the type using the derivative generic signature of the witness
  // before asking for its tangent space.
  auto derivativeGenSig = getWitness()->getDerivativeGenericSignature();
  type = derivativeGenSig.getReducedType(type);
  return type->getAutoDiffTangentSpace(LookUpConformanceInModule());
}
/// Returns the tangent value category of the given value.
SILValueCategory getTangentValueCategory(SILValue v) {
  // Tangent value category table:
  //
  // Let $L be a loadable type and $*A be an address-only type.
  //
  // Original type | Tangent type loadable? | Tangent value category and type
  // --------------|------------------------|--------------------------------
  // $L            | loadable               | object, $L' (no mismatch)
  // $*A           | loadable               | address, $*L' (create a buffer)
  // $L            | address-only           | address, $*A' (no alternative)
  // $*A           | address-only           | address, $*A' (no alternative)
  // TODO(https://github.com/apple/swift/issues/55523): Make "tangent value category" depend solely on whether the tangent type is loadable or address-only.
  //
  // For loadable tangent types, using symbolic adjoint values instead of
  // concrete adjoint buffers is more efficient.
  auto origType = v->getType();
  // Fast path: address-typed originals currently always get an address
  // tangent, without consulting the tangent space at all.
  if (origType.isAddress())
    return SILValueCategory::Address;
  // Object-typed original: the tangent is an object only when the tangent
  // type lowers to a loadable type; otherwise it has to live in a buffer.
  auto tangentASTType =
      getTangentSpace(remapType(origType).getASTType())->getCanonicalType();
  bool isLoadableObject =
      origType.isObject() && getTypeLowering(tangentASTType).isLoadable();
  return isLoadableObject ? SILValueCategory::Object
                          : SILValueCategory::Address;
}
/// Assuming the given type conforms to `Differentiable` after remapping,
/// returns the associated tangent space type.
SILType getRemappedTangentType(SILType type) {
  // Remap first, then look up the tangent space. The tangent space is
  // dereferenced unconditionally, so the caller must guarantee it exists.
  auto remappedASTType = remapType(type).getASTType();
  auto tangentASTType = getTangentSpace(remappedASTType)->getCanonicalType();
  // The result keeps the value category (object/address) of the input.
  return SILType::getPrimitiveType(tangentASTType, type.getCategory());
}
/// Substitutes all replacement types of the given substitution map using the
/// pullback function's substitution map.
SubstitutionMap remapSubstitutionMap(SubstitutionMap substMap) {
  // Apply the pullback's forwarding substitutions to `substMap`.
  auto pullbackSubs = getPullback().getForwardingSubstitutionMap();
  return substMap.subst(pullbackSubs);
}
//--------------------------------------------------------------------------//
// Temporary value management
//--------------------------------------------------------------------------//
/// Record a temporary value for cleanup before its block's terminator.
SILValue recordTemporary(SILValue value) {
  // Only object-typed values that belong to the pullback may be recorded.
  assert(value->getType().isObject());
  assert(value->getFunction() == &getPullback());
  // Temporaries are tracked per pullback block, keyed by the value's parent
  // block; each value may be recorded at most once.
  auto &temporaries = blockTemporaries[value->getParentBlock()];
  bool isNew = temporaries.insert(value);
  (void)isNew;
  LLVM_DEBUG(getADDebugStream() << "Recorded temporary " << value);
  assert(isNew && "Temporary already recorded?");
  // Return the value unchanged so calls can be nested in expressions.
  return value;
}
/// Clean up all temporary values for the given pullback block.
void cleanUpTemporariesForBlock(SILBasicBlock *bb, SILLocation loc) {
  assert(bb->getParent() == &getPullback());
  LLVM_DEBUG(getADDebugStream() << "Cleaning up temporaries for pullback bb"
                                << bb->getDebugID() << '\n');
  // Destroy every temporary recorded for this block, in recording order,
  // then forget them so they are not destroyed a second time.
  auto &temporaries = blockTemporaries[bb];
  for (SILValue temporary : temporaries)
    builder.emitDestroyValueOperation(loc, temporary);
  temporaries.clear();
}
//--------------------------------------------------------------------------//
// Adjoint value factory methods
//--------------------------------------------------------------------------//
AdjointValue makeZeroAdjointValue(SILType type) {
  // Creates a zero adjoint of the remapped `type`, using `allocator`.
  auto remappedType = remapType(type);
  return AdjointValue::createZero(allocator, remappedType);
}
AdjointValue makeConcreteAdjointValue(SILValue value) {
// Wraps an already-emitted SIL value as a concrete adjoint. Unlike the zero
// and aggregate factories, no type remapping is applied here.
return AdjointValue::createConcrete(allocator, value);
}
AdjointValue makeAggregateAdjointValue(SILType type,
                                       ArrayRef<AdjointValue> elements) {
  // Creates an aggregate adjoint of the remapped `type` from the given
  // element adjoints, using `allocator`.
  auto remappedType = remapType(type);
  return AdjointValue::createAggregate(allocator, remappedType, elements);
}
AdjointValue makeAddElementAdjointValue(AdjointValue baseAdjoint,
AdjointValue eltToAdd,
FieldLocator fieldLocator) {
// Creates an adjoint that stands for "`baseAdjoint` with `eltToAdd` added
// to the field identified by `fieldLocator`". The result has the same type
// as `baseAdjoint`.
// NOTE(review): the payload is allocated with plain `new` rather than from
// `allocator`; who frees it is not visible from here -- confirm ownership.
auto *addElementValue =
new AddElementValue(baseAdjoint, eltToAdd, fieldLocator);
return AdjointValue::createAddElement(allocator, baseAdjoint.getType(),
addElementValue);
}
//--------------------------------------------------------------------------//
// Adjoint value materialization
//--------------------------------------------------------------------------//
/// Materializes an adjoint value. The type of the given adjoint value must be
/// loadable.
SILValue materializeAdjointDirect(AdjointValue val, SILLocation loc) {
// Turns the adjoint `val` into a SIL object value emitted at `loc`.
// Values created here (zero, aggregate, add-element) are registered through
// `recordTemporary`; an already-concrete value is returned as-is and is not
// registered.
assert(val.getType().isObject());
LLVM_DEBUG(getADDebugStream()
<< "Materializing adjoint for " << val << '\n');
SILValue result;
switch (val.getKind()) {
// Zero: emit a zero of the adjoint's Swift type.
case AdjointValueKind::Zero:
result = recordTemporary(builder.emitZero(loc, val.getSwiftType()));
break;
// Aggregate: materialize each element recursively and build a tuple or
// struct from copies of the materialized elements.
case AdjointValueKind::Aggregate: {
SmallVector<SILValue, 8> elements;
for (auto i : range(val.getNumAggregateElements())) {
auto eltVal = materializeAdjointDirect(val.getAggregateElement(i), loc);
elements.push_back(builder.emitCopyValueOperation(loc, eltVal));
}
if (val.getType().is<TupleType>())
result = recordTemporary(
builder.createTuple(loc, val.getType(), elements));
else
result = recordTemporary(
builder.createStruct(loc, val.getType(), elements));
break;
}
// Concrete: the value already exists; hand it back unchanged.
case AdjointValueKind::Concrete:
result = val.getConcreteValue();
break;
// AddElement: materialize indirectly into a scratch stack buffer, then take
// the value out of the buffer and release the buffer.
case AdjointValueKind::AddElement: {
auto adjointSILType = val.getAddElementValue()->baseAdjoint.getType();
auto *baseAdjAlloc = builder.createAllocStack(loc, adjointSILType);
materializeAdjointIndirect(val, baseAdjAlloc, loc);
auto baseAdjConcrete = recordTemporary(builder.emitLoadValueOperation(
loc, baseAdjAlloc, LoadOwnershipQualifier::Take));
builder.createDeallocStack(loc, baseAdjAlloc);
result = baseAdjConcrete;
break;
}
}
// If the adjoint carries debug info, attach it to the materialized value.
if (auto debugInfo = val.getDebugInfo())
builder.createDebugValue(
debugInfo->first.getLocation(), result, debugInfo->second);
return result;
}
/// Materializes an adjoint value indirectly to a SIL buffer.
void materializeAdjointIndirect(AdjointValue val, SILValue destAddress,
SILLocation loc) {
// Writes the adjoint `val` into the buffer at `destAddress`, which must be
// an address. The zero and concrete cases store with "initialization"
// semantics, so the destination is presumably expected to be uninitialized
// on entry -- TODO confirm against callers.
assert(destAddress->getType().isAddress());
switch (val.getKind()) {
/// If adjoint value is a symbolic zero, emit a call to
/// `AdditiveArithmetic.zero`.
case AdjointValueKind::Zero:
builder.emitZeroIntoBuffer(loc, destAddress, IsInitialization);
break;
/// If adjoint value is a symbolic aggregate (tuple or struct), recursively
/// materialize the symbolic tuple or struct, filling the
/// buffer.
case AdjointValueKind::Aggregate: {
if (auto *tupTy = val.getSwiftType()->getAs<TupleType>()) {
// Tuple: project each element address and fill it in turn.
for (auto idx : range(val.getNumAggregateElements())) {
auto eltTy = SILType::getPrimitiveAddressType(
tupTy->getElementType(idx)->getCanonicalType());
auto *eltBuf =
builder.createTupleElementAddr(loc, destAddress, idx, eltTy);
materializeAdjointIndirect(val.getAggregateElement(idx), eltBuf, loc);
}
} else if (auto *structDecl =
val.getSwiftType()->getStructOrBoundGenericStruct()) {
// Struct: walk the stored properties in declaration order; the i-th
// aggregate element is written to the i-th stored property.
auto fieldIt = structDecl->getStoredProperties().begin();
for (unsigned i = 0; fieldIt != structDecl->getStoredProperties().end();
++fieldIt, ++i) {
auto eltBuf =
builder.createStructElementAddr(loc, destAddress, *fieldIt);
materializeAdjointIndirect(val.getAggregateElement(i), eltBuf, loc);
}
} else {
llvm_unreachable("Not an aggregate type");
}
break;
}
/// If adjoint value is concrete, it is already materialized. Store it in
/// the destination address.
case AdjointValueKind::Concrete: {
auto concreteVal = val.getConcreteValue();
// Store a copy, so the original concrete value is not consumed.
auto copyOfConcreteVal = builder.emitCopyValueOperation(loc, concreteVal);
builder.emitStoreValueOperation(loc, copyOfConcreteVal, destAddress,
StoreOwnershipQualifier::Init);
break;
}
case AdjointValueKind::AddElement: {
auto baseAdjoint = val;
auto baseAdjointType = baseAdjoint.getType();
// Current adjoint may be made up of layers of `AddElement` adjoints.
// We can iteratively gather the list of elements to add instead of making
// recursive calls to `materializeAdjointIndirect`.
SmallVector<AddElementValue *, 4> addEltAdjValues;
do {
auto addElementValue = baseAdjoint.getAddElementValue();
addEltAdjValues.push_back(addElementValue);
baseAdjoint = addElementValue->baseAdjoint;
assert(baseAdjointType == baseAdjoint.getType());
} while (baseAdjoint.getKind() == AdjointValueKind::AddElement);
// `baseAdjoint` is now the innermost, non-`AddElement` base: write it to
// the destination first, then add each collected element in place.
materializeAdjointIndirect(baseAdjoint, destAddress, loc);
for (auto *addElementValue : addEltAdjValues) {
auto eltToAdd = addElementValue->eltToAdd;
// Address of the field inside the destination that receives the addend.
SILValue baseAdjEltAddr;
if (baseAdjoint.getType().is<TupleType>()) {
baseAdjEltAddr = builder.createTupleElementAddr(
loc, destAddress, addElementValue->getFieldIndex());
} else {
baseAdjEltAddr = builder.createStructElementAddr(
loc, destAddress, addElementValue->getFieldDecl());
}
auto eltToAddMaterialized = materializeAdjointDirect(eltToAdd, loc);
// Copy `eltToAddMaterialized` so we have a value with owned ownership
// semantics, required for using `eltToAddMaterialized` in a `store`
// instruction.
auto eltToAddMaterializedCopy =
builder.emitCopyValueOperation(loc, eltToAddMaterialized);
// Spill the addend to a scratch buffer for the in-place add, then destroy
// and deallocate the scratch buffer.
auto *eltToAddAlloc = builder.createAllocStack(loc, eltToAdd.getType());
builder.emitStoreValueOperation(loc, eltToAddMaterializedCopy,
eltToAddAlloc,
StoreOwnershipQualifier::Init);
builder.emitInPlaceAdd(loc, baseAdjEltAddr, eltToAddAlloc);
builder.createDestroyAddr(loc, eltToAddAlloc);
builder.createDeallocStack(loc, eltToAddAlloc);
}
break;
}
}
}
//--------------------------------------------------------------------------//
// Adjoint value mapping
//--------------------------------------------------------------------------//
/// Returns true if the given value in the original function has a
/// corresponding adjoint value.
bool hasAdjointValue(SILBasicBlock *origBB, SILValue originalValue) const {
  // The query is only meaningful for object-typed values in a block of the
  // original function.
  assert(origBB->getParent() == &getOriginal());
  assert(originalValue->getType().isObject());
  // Adjoint values are keyed by (original block, original value).
  auto entry = valueMap.find({origBB, originalValue});
  return entry != valueMap.end();
}
/// Sets the adjoint value for the original value. If the original value
/// already has an adjoint value, it is replaced.
void setAdjointValue(SILBasicBlock *origBB, SILValue originalValue,
                     AdjointValue adjointValue) {
  // Stores `adjointValue` as the adjoint of `originalValue` in `origBB`,
  // overwriting any adjoint already recorded for that pair.
  LLVM_DEBUG(getADDebugStream()
             << "Setting adjoint value for " << originalValue);
  assert(origBB->getParent() == &getOriginal());
  assert(originalValue->getType().isObject());
  assert(getTangentValueCategory(originalValue) == SILValueCategory::Object);
  assert(adjointValue.getType().isObject());
  assert(originalValue->getFunction() == &getOriginal());
  // The adjoint value must be in the tangent space.
  assert(adjointValue.getType() ==
         getRemappedTangentType(originalValue->getType()));
  // Try to assign a debug variable.
  if (auto debugInfo = findDebugLocationAndVariable(originalValue)) {
    LLVM_DEBUG({
      auto &s = getADDebugStream();
      s << "Found debug variable: \"" << debugInfo->second.Name
        << "\"\nLocation: ";
      debugInfo->first.getLocation().print(s, getASTContext().SourceMgr);
      s << '\n';
    });
    adjointValue.setDebugInfo(*debugInfo);
  } else {
    LLVM_DEBUG(getADDebugStream() << "No debug variable found.\n");
  }
  // Insert into dictionary, replacing the mapped value if the key is already
  // present.
  auto insertion =
      valueMap.try_emplace({origBB, originalValue}, adjointValue);
  if (!insertion.second)
    insertion.first->getSecond() = adjointValue;
  // Log only after the store, so that the message shows the value that is
  // now in the map rather than the one that was just replaced.
  LLVM_DEBUG(getADDebugStream()
             << "The new adjoint value, replacing the existing one, is: "
             << insertion.first->getSecond() << '\n');
}
/// Returns the adjoint value for a value in the original function.
///
/// This method first tries to find an existing entry in the adjoint value
/// mapping. If no entry exists, creates a zero adjoint value.
AdjointValue getAdjointValue(SILBasicBlock *origBB, SILValue originalValue) {
  assert(origBB->getParent() == &getOriginal());
  assert(originalValue->getType().isObject());
  assert(getTangentValueCategory(originalValue) == SILValueCategory::Object);
  assert(originalValue->getFunction() == &getOriginal());
  // Common case: the adjoint is already recorded. Return it without building
  // a zero adjoint, which would go through `allocator` only to be discarded.
  auto existing = valueMap.find({origBB, originalValue});
  if (existing != valueMap.end())
    return existing->getSecond();
  // No entry yet: record a zero adjoint of the remapped tangent type, so
  // later queries for this value find it.
  auto insertion = valueMap.try_emplace(
      {origBB, originalValue},
      makeZeroAdjointValue(getRemappedTangentType(originalValue->getType())));
  return insertion.first->getSecond();
}
/// Adds `newAdjointValue` to the adjoint value for `originalValue` and sets
/// the sum as the new adjoint value.
void addAdjointValue(SILBasicBlock *origBB, SILValue originalValue,
AdjointValue newAdjointValue, SILLocation loc) {
// If `originalValue` has no adjoint in `origBB` yet, `newAdjointValue`
// simply becomes its adjoint. Otherwise the existing adjoint and the new one
// are accumulated and the sum is stored back.
assert(origBB->getParent() == &getOriginal());
assert(originalValue->getType().isObject());
assert(newAdjointValue.getType().isObject());
assert(originalValue->getFunction() == &getOriginal());
LLVM_DEBUG(getADDebugStream()
<< "Adding adjoint value for " << originalValue);
// The adjoint value must be in the tangent space.
assert(newAdjointValue.getType() ==
getRemappedTangentType(originalValue->getType()));
// Try to assign a debug variable.
if (auto debugInfo = findDebugLocationAndVariable(originalValue)) {
LLVM_DEBUG({
auto &s = getADDebugStream();
s << "Found debug variable: \"" << debugInfo->second.Name
<< "\"\nLocation: ";
debugInfo->first.getLocation().print(s, getASTContext().SourceMgr);
s << '\n';
});
newAdjointValue.setDebugInfo(*debugInfo);
} else {
LLVM_DEBUG(getADDebugStream() << "No debug variable found.\n");
}
// First contribution for this value: the insertion itself is the result.
auto insertion =
valueMap.try_emplace({origBB, originalValue}, newAdjointValue);
auto inserted = insertion.second;
if (inserted)
return;
// If adjoint already exists, accumulate the adjoint onto the existing
// adjoint.
auto it = insertion.first;
// Copy the existing adjoint out before erasing its entry; the accumulated
// sum is re-inserted by `setAdjointValue` at the end.
auto existingValue = it->getSecond();
valueMap.erase(it);
auto adjVal = accumulateAdjointsDirect(existingValue, newAdjointValue, loc);
// If the original value is the `Array` result of an
// `array.uninitialized_intrinsic` application, accumulate adjoint buffers
// for the array element addresses.
accumulateArrayLiteralElementAddressAdjoints(origBB, originalValue, adjVal,
loc);
setAdjointValue(origBB, originalValue, adjVal);
}
/// Get the pullback block argument corresponding to the given original block
/// and active value.
SILArgument *getActiveValuePullbackBlockArgument(SILBasicBlock *origBB,
                                                 SILValue activeValue) {
  assert(getTangentValueCategory(activeValue) == SILValueCategory::Object);
  assert(origBB->getParent() == &getOriginal());
  // Pure query: use `lookup` rather than `operator[]`, so that a missing key
  // yields a null pointer without inserting a null entry into the map.
  auto *pullbackBBArg =
      activeValuePullbackBBArgumentMap.lookup({origBB, activeValue});
  // Every queried active value must have been given a pullback block
  // argument, and that argument must belong to the block's pullback block.
  assert(pullbackBBArg);
  assert(pullbackBBArg->getParent() == getPullbackBlock(origBB));
  return pullbackBBArg;
}
//--------------------------------------------------------------------------//
// Adjoint value accumulation
//--------------------------------------------------------------------------//
/// Given two adjoint values, accumulates them and returns their sum.
AdjointValue accumulateAdjointsDirect(AdjointValue lhs, AdjointValue rhs,
SILLocation loc);
//--------------------------------------------------------------------------//
// Adjoint buffer mapping
//--------------------------------------------------------------------------//
/// If the given original value is an address projection, returns a
/// corresponding adjoint projection to be used as its adjoint buffer.
///
/// Helper function for `getAdjointBuffer`.
SILValue getAdjointProjection(SILBasicBlock *origBB, SILValue originalValue);
/// Returns the adjoint buffer for the original value.
///
/// This method first tries to find an existing entry in the adjoint buffer
/// mapping. If no entry exists, creates a zero adjoint buffer.
SILValue getAdjointBuffer(SILBasicBlock *origBB, SILValue originalValue) {
  assert(getTangentValueCategory(originalValue) == SILValueCategory::Address);
  assert(originalValue->getFunction() == &getOriginal());
  // Claim the map slot with a null placeholder. If an entry was already
  // present, it is the answer.
  auto insertion = bufferMap.try_emplace({origBB, originalValue}, SILValue());
  if (!insertion.second) // not inserted
    return insertion.first->getSecond();
  // From here on, results are stored by re-looking up the key instead of
  // going through `insertion.first`: `getAdjointProjection` may presumably
  // request adjoint buffers for other values, which can grow `bufferMap` and
  // invalidate the iterator -- NOTE(review): confirm against its definition.
  //
  // If the original buffer is a projection, return a corresponding projection
  // into the adjoint buffer.
  if (auto adjProj = getAdjointProjection(origBB, originalValue))
    return (bufferMap[{origBB, originalValue}] = adjProj);
  LLVM_DEBUG(getADDebugStream() << "Creating new adjoint buffer for "
                                << originalValue
                                << "in bb" << origBB->getDebugID() << '\n');
  auto bufType = getRemappedTangentType(originalValue->getType());
  // Reuse the original value's debug location when one can be found;
  // otherwise fall back to an auto-generated location.
  auto debugInfo = findDebugLocationAndVariable(originalValue);
  SILLocation loc = debugInfo ? debugInfo->first.getLocation()
                              : RegularLocation::getAutoGeneratedLocation();
  // Backing storage for the synthesized variable name; it has to outlive the
  // lambda below, which points `dv.Name` at it.
  llvm::SmallString<32> adjName;
  auto *newBuf = createFunctionLocalAllocation(
      bufType, loc, /*zeroInitialize*/ true,
      swift::transform(debugInfo,
                       [&](AdjointValue::DebugInfo di) {
        // Derive the adjoint's debug variable from the original one: same
        // variable, renamed and stripped of location/type details.
        llvm::raw_svector_ostream adjNameStream(adjName);
        SILDebugVariable &dv = di.second;
        dv.ArgNo = 0;
        adjNameStream << "derivative of '" << dv.Name << "'";
        if (SILDebugLocation origBBLoc = origBB->front().getDebugLocation()) {
          adjNameStream << " in scope at ";
          origBBLoc.getLocation().print(adjNameStream, getASTContext().SourceMgr);
        }
        adjNameStream << " (scope #" << origBB->getDebugID() << ")";
        dv.Name = adjName;
        // We have no meaningful debug location, and the type is different.
        dv.Scope = nullptr;
        dv.Loc = {};
        dv.Type = {};
        dv.DIExpr = {};
        return dv;
      }));
  // Replace the placeholder with the freshly allocated buffer.
  return (bufferMap[{origBB, originalValue}] = newBuf);
}
/// Initializes the adjoint buffer for the original value. Asserts that the
/// original value does not already have an adjoint buffer.
void setAdjointBuffer(SILBasicBlock *origBB, SILValue originalValue,
                      SILValue adjointBuffer) {
  assert(getTangentValueCategory(originalValue) == SILValueCategory::Address);
  // Register the buffer for (block, value). An existing entry is left
  // untouched and trips the assertion below.
  bool isNew =
      bufferMap.try_emplace({origBB, originalValue}, adjointBuffer).second;
  assert(isNew && "Adjoint buffer already exists");
  (void)isNew;
}
/// Accumulates `rhsAddress` into the adjoint buffer corresponding to the
/// original value.
void addToAdjointBuffer(SILBasicBlock *origBB, SILValue originalValue,
                        SILValue rhsAddress, SILLocation loc) {
  // The original value must have an address tangent, and the addend must be
  // an address living in the pullback function.
  assert(getTangentValueCategory(originalValue) ==
             SILValueCategory::Address &&
         rhsAddress->getType().isAddress());
  assert(originalValue->getFunction() == &getOriginal());
  assert(rhsAddress->getFunction() == &getPullback());
  // Fetch (creating it on first use) the destination buffer, then add the
  // addend into it in place.
  SILValue destBuffer = getAdjointBuffer(origBB, originalValue);
  LLVM_DEBUG(getADDebugStream() << "Adding"
                                << rhsAddress << "to adjoint ("
                                << destBuffer << ") of "
                                << originalValue
                                << "in bb" << origBB->getDebugID() << '\n');
  builder.emitInPlaceAdd(loc, destBuffer, rhsAddress);
}
/// Returns the next insertion point for creating a local allocation: either
/// before the previous local allocation, or at the start of the pullback
/// entry if no local allocations exist.
///
/// Helper for `createFunctionLocalAllocation`.
SILBasicBlock::iterator getNextFunctionLocalAllocationInsertionPoint() {
  if (!functionLocalAllocations.empty()) {
    // Insert before the most recently created local allocation. Inserting
    // before rather than after ensures that allocation and zero
    // initialization instructions are grouped together.
    auto *newestAllocInst =
        functionLocalAllocations.back()->getDefiningInstruction();
    return newestAllocInst->getIterator();
  }
  // No local allocations yet: insert at the pullback entry start.
  return getPullback().getEntryBlock()->begin();
}
/// Creates and returns a local allocation with the given type.
///
/// Local allocations are created in the pullback entry (zero-initialized when
/// `zeroInitialize` is true, otherwise left uninitialized) and deallocated in
/// the pullback exit. All local allocations not in
/// `destroyedLocalAllocations` are also destroyed in the pullback exit.
///
/// Helper for `getAdjointBuffer`.
AllocStackInst *createFunctionLocalAllocation(
    SILType type, SILLocation loc, bool zeroInitialize = false,
    std::optional<SILDebugVariable> varInfo = std::nullopt) {
  // Position the local allocation builder in the pullback entry: before the
  // last local allocation, or at the start of the entry block if no local
  // allocations exist yet.
  auto *entryBlock = getPullback().getEntryBlock();
  auto insertionPoint = getNextFunctionLocalAllocationInsertionPoint();
  localAllocBuilder.setInsertionPoint(entryBlock, insertionPoint);
  // Create the stack slot and record it in `functionLocalAllocations`.
  AllocStackInst *stackSlot =
      localAllocBuilder.createAllocStack(loc, type, varInfo);
  functionLocalAllocations.push_back(stackSlot);
  // Callers that did not ask for zero-initialization get the raw slot.
  if (!zeroInitialize)
    return stackSlot;
  // Otherwise emit the zero value into the slot as an initialization.
  localAllocBuilder.emitZeroIntoBuffer(loc, stackSlot, IsInitialization);
  return stackSlot;
}
//--------------------------------------------------------------------------//
// Optional differentiation
//--------------------------------------------------------------------------//
/// Given a `wrappedAdjoint` value of type `T.TangentVector` and `Optional<T>`
/// type, creates an `Optional<T>.TangentVector` buffer from it.
///
/// `wrappedAdjoint` may be either an object or an address value; both cases
/// are handled.
AllocStackInst *createOptionalAdjoint(SILBasicBlock *bb,
SILValue wrappedAdjoint,
SILType optionalTy);
/// Accumulate adjoint of `wrappedAdjoint` into optionalBuffer.
void accumulateAdjointForOptionalBuffer(SILBasicBlock *bb,
SILValue optionalBuffer,
SILValue wrappedAdjoint);
/// Accumulate adjoint of `wrappedAdjoint` into optionalValue.
void accumulateAdjointValueForOptional(SILBasicBlock *bb,
SILValue optionalValue,
SILValue wrappedAdjoint);
//--------------------------------------------------------------------------//
// Array literal initialization differentiation
//--------------------------------------------------------------------------//
/// Given the adjoint value of an array initialized from an
/// `array.uninitialized_intrinsic` application and an array element index,
/// returns an `alloc_stack` containing the adjoint value of the array element
/// at the given index by applying `Array.TangentVector.subscript`.
AllocStackInst *getArrayAdjointElementBuffer(SILValue arrayAdjoint,
int eltIndex, SILLocation loc);
/// Given the adjoint value of an array initialized from an
/// `array.uninitialized_intrinsic` application, accumulates the adjoint
/// value's elements into the adjoint buffers of its element addresses.
void accumulateArrayLiteralElementAddressAdjoints(
SILBasicBlock *origBB, SILValue originalValue,
AdjointValue arrayAdjointValue, SILLocation loc);
//--------------------------------------------------------------------------//
// CFG mapping
//--------------------------------------------------------------------------//
/// Returns the pullback block recorded in `pullbackBBMap` for the given
/// original block.
// NOTE(review): presumably `lookup` yields null when no block was recorded
// (LLVM DenseMap semantics) -- confirm against the map's declaration.
SILBasicBlock *getPullbackBlock(SILBasicBlock *originalBlock) {
  SILBasicBlock *pullbackBlock = pullbackBBMap.lookup(originalBlock);
  return pullbackBlock;
}
/// Returns the pullback trampoline block recorded in
/// `pullbackTrampolineBBMap` for the (original block, successor block) pair.
// NOTE(review): presumably `lookup` yields null when no trampoline was
// recorded (LLVM DenseMap semantics) -- confirm against the map's declaration.
SILBasicBlock *getPullbackTrampolineBlock(SILBasicBlock *originalBlock,
                                          SILBasicBlock *successorBlock) {
  SILBasicBlock *trampolineBlock =
      pullbackTrampolineBBMap.lookup({originalBlock, successorBlock});
  return trampolineBlock;
}
//--------------------------------------------------------------------------//
// Debug info
//--------------------------------------------------------------------------//
/// Maps the debug scope `DS` through `scopeCloner`, returning the scope it
/// has cloned (or newly clones) for it.
const SILDebugScope *remapScope(const SILDebugScope *DS) {
  const SILDebugScope *clonedScope = scopeCloner.getOrCreateClonedScope(DS);
  return clonedScope;
}
//--------------------------------------------------------------------------//
// Debugging utilities
//--------------------------------------------------------------------------//
/// Prints the contents of `valueMap` to the AD debug stream, grouped by
/// original basic block. Only original blocks that have an entry in
/// `pullbackBBMap` are printed, in the order they appear in the original
/// function.
void printAdjointValueMapping() {
  // Group original/adjoint values by basic block.
  llvm::DenseMap<SILBasicBlock *, llvm::DenseMap<SILValue, AdjointValue>> tmp;
  // Iterate by reference: there is no need to copy every map entry.
  for (const auto &entry : valueMap) {
    auto *origBB = entry.first.first;
    tmp[origBB].insert({entry.first.second, entry.second});
  }
  // Print original/adjoint values per basic block.
  auto &s = getADDebugStream() << "Adjoint value mapping:\n";
  for (auto &origBB : getOriginal()) {
    if (!pullbackBBMap.count(&origBB))
      continue;
    // Bind a reference instead of copying the whole per-block map. The
    // reference is only used within this iteration, before `tmp` can be
    // modified again by the next `operator[]` call.
    const auto &bbValueMap = tmp[&origBB];
    s << "bb" << origBB.getDebugID();
    s << " (size " << bbValueMap.size() << "):\n";
    for (const auto &valuePair : bbValueMap) {
      s << "ORIG: " << valuePair.first;
      s << "ADJ: " << valuePair.second << '\n';
    }
    s << '\n';
  }
}
/// Prints the contents of `bufferMap` to the AD debug stream, grouped by
/// original basic block. Only original blocks that have an entry in
/// `pullbackBBMap` are printed, in the order they appear in the original
/// function.
void printAdjointBufferMapping() {
  // Group original/adjoint buffers by basic block.
  llvm::DenseMap<SILBasicBlock *, llvm::DenseMap<SILValue, SILValue>> tmp;
  // Iterate by reference: there is no need to copy every map entry.
  for (const auto &entry : bufferMap) {
    auto *origBB = entry.first.first;
    tmp[origBB][entry.first.second] = entry.second;
  }
  // Print original/adjoint buffers per basic block.
  auto &s = getADDebugStream() << "Adjoint buffer mapping:\n";
  for (auto &origBB : getOriginal()) {
    if (!pullbackBBMap.count(&origBB))
      continue;
    // Bind a reference instead of copying the whole per-block map. The
    // reference is only used within this iteration, before `tmp` can be
    // modified again by the next `operator[]` call.
    const auto &bbBufferMap = tmp[&origBB];
    s << "bb" << origBB.getDebugID();
    s << " (size " << bbBufferMap.size() << "):\n";
    for (const auto &bufferPair : bbBufferMap) {
      s << "ORIG: " << bufferPair.first;
      s << "ADJ: " << bufferPair.second << '\n';
    }
    s << '\n';
  }
}
public:
//--------------------------------------------------------------------------//
// Entry point
//--------------------------------------------------------------------------//
/// Performs pullback generation on the empty pullback function. Returns true
/// if any error occurs.
bool run();
/// Performs pullback generation on the empty pullback function, given that
/// the original function is a "semantic member accessor".
///
/// "Semantic member accessors" are attached to member properties that have a
/// corresponding tangent stored property in the parent `TangentVector` type.
/// These accessors have special-case pullback generation based on their
/// semantic behavior.
///
/// Returns true if any error occurs.
bool runForSemanticMemberAccessor();
bool runForSemanticMemberGetter();
bool runForSemanticMemberSetter();
/// If original result is non-varied, it will always have a zero derivative.
/// Skip full pullback generation and simply emit zero derivatives for wrt
/// parameters.
void emitZeroDerivativesForNonvariedResult(SILValue origNonvariedResult);
/// Public helper so that our users can get the underlying newly created
/// function.
SILFunction &getPullback() const {
  // The pullback function is reached through the VJP cloner; simply forward.
  SILFunction &pullbackFn = vjpCloner.getPullback();
  return pullbackFn;
}
using TrampolineBlockSet = SmallPtrSet<SILBasicBlock *, 4>;
/// Determines the pullback successor block for a given original block and one
/// of its predecessors. When a trampoline block is necessary, emits code into
/// the trampoline block to trampoline the original block's active values'
/// adjoint values.
///
/// Populates `pullbackTrampolineBlockMap`, which maps active values' adjoint
/// values to the pullback successor blocks in which they are used. This
/// allows us to release those values in pullback successor blocks that do not
/// use them.
SILBasicBlock *
buildPullbackSuccessor(SILBasicBlock *origBB, SILBasicBlock *origPredBB,
llvm::SmallDenseMap<SILValue, TrampolineBlockSet>
&pullbackTrampolineBlockMap);
/// Emits pullback code in the corresponding pullback block.
void visitSILBasicBlock(SILBasicBlock *bb);
/// Visits one original instruction by dispatching to
/// `SILInstructionVisitor::visit`. Does nothing once `errorOccurred` is set.
/// With debug output enabled, logs the visited instruction and every
/// instruction that ends up between the pre-visit and post-visit builder
/// insertion points.
void visit(SILInstruction *inst) {
  if (errorOccurred)
    return;
  LLVM_DEBUG(getADDebugStream()
             << "PullbackCloner visited:\n[ORIG]" << *inst);
#ifndef NDEBUG
  // Remember the position just before the insertion point so the dump below
  // can enumerate what the visitor emitted.
  auto lastExisting = std::prev(builder.getInsertionPoint());
#endif
  SILInstructionVisitor::visit(inst);
  LLVM_DEBUG({
    auto &dumpStream = llvm::dbgs();
    dumpStream << "[ADJ] Emitted in pullback (pb bb"
               << builder.getInsertionBB()->getDebugID() << "):\n";
    const auto emittedEnd = builder.getInsertionPoint();
    auto cursor = std::next(lastExisting);
    while (cursor != emittedEnd) {
      dumpStream << *cursor;
      ++cursor;
    }
  });
}
/// Fallback instruction visitor for unhandled instructions.
/// Emit a general non-differentiability diagnostic.
void visitSILInstruction(SILInstruction *inst) {
  LLVM_DEBUG(getADDebugStream()
             << "Unhandled instruction in PullbackCloner: " << *inst);
  // Report the instruction as non-differentiable against the current invoker.
  auto &context = getContext();
  context.emitNondifferentiabilityError(
      inst, getInvoker(), diag::autodiff_expression_not_differentiable_note);
  // Record the failure; `visit` returns early once this flag is set.
  errorOccurred = true;
}
/// Handle `apply` instruction.
/// Original: (y0, y1, ...) = apply @fn (x0, x1, ...)
/// Adjoint: (adj[x0], adj[x1], ...) += apply @fn_pullback (adj[y0], ...)
void visitApplyInst(ApplyInst *ai) {
  assert(getPullbackInfo().shouldDifferentiateApplySite(ai));
  // `array.uninitialized_intrinsic` applications are skipped here: they have
  // special `store` and `copy_addr` support.
  if (ArraySemanticsCall(ai, semantics::ARRAY_UNINITIALIZED_INTRINSIC))
    return;
  // Anything other than `array.finalize_intrinsic` goes through the general
  // pullback-call path.
  if (!ArraySemanticsCall(ai, semantics::ARRAY_FINALIZE_INTRINSIC)) {
    buildPullbackCall(ai);
    return;
  }
  // `array.finalize_intrinsic` semantically behaves like an identity
  // function, so the result's adjoint is accumulated directly into the
  // argument's adjoint.
  assert(ai->getNumArguments() == 1 &&
         "Expected intrinsic to have one operand");
  auto *origBB = ai->getParent();
  auto resultAdjoint = getAdjointValue(origBB, ai);
  auto identityArg = ai->getArgumentsWithoutIndirectResults().front();
  addAdjointValue(origBB, identityArg, resultAdjoint, ai->getLoc());
}
void buildPullbackCall(FullApplySite fai) {
auto loc = fai->getLoc();
auto *bb = fai->getParent();
// Replace a call to a function with a call to its pullback.
auto &nestedApplyInfo = getContext().getNestedApplyInfo();