Skip to content

Commit 152d29c

Browse files
committed
[mlir][Transforms] Add pass to perform sparse conditional constant propagation
This revision adds the initial pass for performing SCCP generically in MLIR. SCCP is an algorithm for propagating constants across control flow, and optimistically assumes all values to be constant unless proven otherwise. It currently supports branching control, with support for regions and inter-procedural propagation being added in followups. Differential Revision: https://reviews.llvm.org/D78397
1 parent 4ccafab commit 152d29c

File tree

10 files changed

+790
-0
lines changed

10 files changed

+790
-0
lines changed

mlir/include/mlir/Dialect/StandardOps/IR/Ops.td

+8
Original file line numberDiff line numberDiff line change
@@ -596,6 +596,10 @@ def BranchOp : Std_Op<"br",
596596

597597
/// Erase the operand at 'index' from the operand list.
598598
void eraseOperand(unsigned index);
599+
600+
/// Returns the successor that would be chosen with the given constant
601+
/// operands. Returns nullptr if a single successor could not be chosen.
602+
Block *getSuccessorForOperands(ArrayRef<Attribute>);
599603
}];
600604

601605
let hasCanonicalizer = 1;
@@ -1092,6 +1096,10 @@ def CondBranchOp : Std_Op<"cond_br",
10921096
eraseSuccessorOperand(falseIndex, index);
10931097
}
10941098

1099+
/// Returns the successor that would be chosen with the given constant
1100+
/// operands. Returns nullptr if a single successor could not be chosen.
1101+
Block *getSuccessorForOperands(ArrayRef<Attribute> operands);
1102+
10951103
private:
10961104
/// Get the index of the first true destination operand.
10971105
unsigned getTrueDestOperandIndex() { return 1; }

mlir/include/mlir/Interfaces/ControlFlowInterfaces.td

+8
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,14 @@ def BranchOpInterface : OpInterface<"BranchOpInterface"> {
6868
}
6969
return llvm::None;
7070
}]
71+
>,
72+
InterfaceMethod<[{
73+
Returns the successor that would be chosen with the given constant
74+
operands. Returns nullptr if a single successor could not be chosen.
75+
}],
76+
"Block *", "getSuccessorForOperands",
77+
(ins "ArrayRef<Attribute>":$operands), [{}],
78+
/*defaultImplementation=*/[{ return nullptr; }]
7179
>
7280
];
7381

mlir/include/mlir/Transforms/FoldUtils.h

+5
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,11 @@ class OperationFolder {
119119
/// Clear out any constants cached inside of the folder.
120120
void clear();
121121

122+
/// Get or create a constant using the given builder. On success this returns
123+
/// the constant operation, nullptr otherwise.
124+
Value getOrCreateConstant(OpBuilder &builder, Dialect *dialect,
125+
Attribute value, Type type, Location loc);
126+
122127
private:
123128
/// This map keeps track of uniqued constants by dialect, attribute, and type.
124129
/// A constant operation materializes an attribute with a type. Dialects may

mlir/include/mlir/Transforms/Passes.h

+4
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,10 @@ std::unique_ptr<OperationPass<ModuleOp>> createPrintOpStatsPass();
7676
/// the CallGraph.
7777
std::unique_ptr<Pass> createInlinerPass();
7878

79+
/// Creates a pass which performs sparse conditional constant propagation over
80+
/// nested operations.
81+
std::unique_ptr<Pass> createSCCPPass();
82+
7983
/// Creates a pass which delete symbol operations that are unreachable. This
8084
/// pass may *only* be scheduled on an operation that defines a SymbolTable.
8185
std::unique_ptr<Pass> createSymbolDCEPass();

mlir/include/mlir/Transforms/Passes.td

+14
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,20 @@ def PrintOp : Pass<"print-op-graph", "ModuleOp"> {
273273
let constructor = "mlir::createPrintOpGraphPass()";
274274
}
275275

276+
def SCCP : Pass<"sccp"> {
277+
let summary = "Sparse Conditional Constant Propagation";
278+
let description = [{
279+
This pass implements a general algorithm for sparse conditional constant
280+
propagation. This algorithm detects values that are known to be constant and
281+
optimistically propagates this throughout the IR. Any values proven to be
282+
constant are replaced, and removed if possible.
283+
284+
This implementation is based on the algorithm described by Wegman and Zadeck
285+
in [“Constant Propagation with Conditional Branches”](https://dl.acm.org/doi/10.1145/103135.103136) (1991).
286+
}];
287+
let constructor = "mlir::createSCCPPass()";
288+
}
289+
276290
def StripDebugInfo : Pass<"strip-debuginfo"> {
277291
let summary = "Strip debug info from all operations";
278292
let description = [{

mlir/lib/Dialect/StandardOps/IR/Ops.cpp

+10
Original file line numberDiff line numberDiff line change
@@ -597,6 +597,8 @@ Optional<OperandRange> BranchOp::getSuccessorOperands(unsigned index) {
597597

598598
bool BranchOp::canEraseSuccessorOperand() { return true; }
599599

600+
Block *BranchOp::getSuccessorForOperands(ArrayRef<Attribute>) { return dest(); }
601+
600602
//===----------------------------------------------------------------------===//
601603
// CallOp
602604
//===----------------------------------------------------------------------===//
@@ -863,6 +865,14 @@ Optional<OperandRange> CondBranchOp::getSuccessorOperands(unsigned index) {
863865

864866
bool CondBranchOp::canEraseSuccessorOperand() { return true; }
865867

868+
Block *CondBranchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
869+
if (BoolAttr condAttr = operands.front().dyn_cast_or_null<BoolAttr>())
870+
return condAttr.getValue() ? trueDest() : falseDest();
871+
if (IntegerAttr condAttr = operands.front().dyn_cast_or_null<IntegerAttr>())
872+
return condAttr.getValue().isOneValue() ? trueDest() : falseDest();
873+
return nullptr;
874+
}
875+
866876
//===----------------------------------------------------------------------===//
867877
// Constant*Op
868878
//===----------------------------------------------------------------------===//

mlir/lib/Transforms/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ add_mlir_library(MLIRTransforms
1313
OpStats.cpp
1414
ParallelLoopCollapsing.cpp
1515
PipelineDataTransfer.cpp
16+
SCCP.cpp
1617
StripDebugInfo.cpp
1718
SymbolDCE.cpp
1819
ViewOpGraph.cpp

0 commit comments

Comments
 (0)