| 1 | +# Writing DataFlow Analyses in MLIR |
| 2 | + |
| 3 | +Writing dataflow analyses in MLIR, or in any compiler for that matter, can |
| 4 | +often seem daunting and complex. A dataflow analysis generally involves |
| 5 | +propagating information about the IR across various types of control flow |
| 6 | +constructs, of which MLIR has many (Block-based branches, Region-based branches, |
| 7 | +CallGraph, etc.), and it isn't always clear how best to go about performing the |
| 8 | +propagation. To help with writing these types of analyses in MLIR, this document |
| 9 | +details several utilities that simplify the process and make it a bit more |
| 10 | +approachable. |
| 11 | + |
| 12 | +## Forward Dataflow Analysis |
| 13 | + |
| 14 | +One type of dataflow analysis is a forward propagation analysis. This type of |
| 15 | +analysis, as the name may suggest, propagates information forward (e.g. from |
| 16 | +definitions to uses). To provide a bit of concrete context, let's go over |
| 17 | +writing a simple forward dataflow analysis in MLIR. Let's say for this analysis |
| 18 | +that we want to propagate information about a special "metadata" dictionary |
| 19 | +attribute. The contents of this attribute are simply a set of metadata that |
| 20 | +describes a specific value, e.g. `metadata = { likes_pizza = true }`. We will |
| 21 | +collect the `metadata` for operations in the IR and propagate it through the IR. |
| 22 | + |
| 23 | +### Lattices |
| 24 | + |
| 25 | +Before going into how one might set up the analysis itself, it is important to |
| 26 | +first introduce the concept of a `Lattice` and how we will use it for the |
| 27 | +analysis. A lattice represents all of the possible values or results of the |
| 28 | +analysis for a given value. A lattice element holds the set of information |
| 29 | +computed by the analysis for a given value, and is what gets propagated across |
| 30 | +the IR. For our analysis, this would correspond to the `metadata` dictionary |
| 31 | +attribute. |
| 32 | + |
| 33 | +Regardless of the value held within, every type of lattice contains two special |
| 34 | +element states: |
| 35 | + |
| 36 | +* `uninitialized` |
| 37 | + |
| 38 | + - The element has not been initialized. |
| 39 | + |
| 40 | +* `top`/`overdefined`/`unknown` |
| 41 | + |
| 42 | + - The element encompasses every possible value. |
| 43 | + - This is a very conservative state, and essentially means "I can't make |
| 44 | + any assumptions about the value, it could be anything" |
| 45 | + |
| 46 | +These two states are important when merging, or `join`ing as we will refer to it |
| 47 | +further in this document, information as part of the analysis. Lattice elements |
| 48 | +are `join`ed whenever there are two different source points, such as an argument |
| 49 | +to a block with multiple predecessors. One important note about the `join` |
| 50 | +operation is that it is required to be monotonic (see the `join` method in the |
| 51 | +example below for more information). This ensures that `join`ing elements is |
| 52 | +consistent. The two special states mentioned above have unique properties during |
| 53 | +a `join`: |
| 54 | + |
| 55 | +* `uninitialized` |
| 56 | + |
| 57 | + - If one of the elements is `uninitialized`, the other element is used. |
| 58 | + - `uninitialized` in the context of a `join` essentially means "take the |
| 59 | + other thing". |
| 60 | + |
| 61 | +* `top`/`overdefined`/`unknown` |
| 62 | + |
| 63 | + - If one of the elements being joined is `overdefined`, the result is |
| 64 | + `overdefined`. |
| 65 | + |
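The `join` rules for the two special states can be sketched in standalone C++17 without any MLIR types. This is only a minimal illustration under stated assumptions: `Facts`, `Element`, `joinElements`, and `checkSpecialStates` are hypothetical names, an empty `std::optional` stands in for `uninitialized`, and an empty fact set stands in for `top`/`overdefined`.

```cpp
#include <cassert>
#include <optional>
#include <set>
#include <string>

// Hypothetical stand-ins, not MLIR types. A value is the set of facts known to
// hold. An empty optional models `uninitialized`; an empty fact set models
// `top`/`overdefined`, since nothing can be assumed.
using Facts = std::set<std::string>;
using Element = std::optional<Facts>;

Element joinElements(const Element &lhs, const Element &rhs) {
  // `uninitialized` means "take the other thing".
  if (!lhs)
    return rhs;
  if (!rhs)
    return lhs;
  // Otherwise keep only the facts both sides agree on. If either side is
  // `top` (no facts), the intersection is empty, so the result is `top` too.
  Facts common;
  for (const std::string &fact : *lhs)
    if (rhs->count(fact))
      common.insert(fact);
  return common;
}

// Returns true if the special states behave as described above.
bool checkSpecialStates() {
  Element uninitialized;
  Element top = Facts{};
  Element a = Facts{"likes_pizza", "likes_pasta"};
  Element b = Facts{"likes_pizza"};
  return joinElements(uninitialized, a) == a && // Take the other element.
         joinElements(a, top) == top &&         // `top` absorbs everything.
         joinElements(a, b) == b;               // Only shared facts survive.
}
```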
| 66 | +For our analysis in MLIR, we will need to define a class representing the value |
| 67 | +held by an element of the lattice used by our dataflow analysis: |
| 68 | + |
| 69 | +```c++ |
| 70 | +/// The value of our lattice represents the inner structure of a DictionaryAttr, |
| 71 | +/// for the `metadata`. |
| 72 | +struct MetadataLatticeValue { |
| 73 | +  MetadataLatticeValue() = default; |
| 74 | +  /// Compute a lattice value from the provided dictionary. |
| 75 | +  MetadataLatticeValue(DictionaryAttr attr) |
| 76 | +      : metadata(attr.begin(), attr.end()) {} |
| 77 | + |
| 78 | +  /// Return a pessimistic value state, i.e. the `top`/`overdefined`/`unknown` |
| 79 | +  /// state, for our value type. The resultant state should not assume any |
| 80 | +  /// information about the state of the IR. |
| 81 | +  static MetadataLatticeValue getPessimisticValueState(MLIRContext *context) { |
| 82 | +    // The `top`/`overdefined`/`unknown` state is when we know nothing about any |
| 83 | +    // metadata, i.e. an empty dictionary. |
| 84 | +    return MetadataLatticeValue(); |
| 85 | +  } |
| 86 | +  /// Return a pessimistic value state for our value type using only information |
| 87 | +  /// about the state of the provided IR. This is similar to the above method, |
| 88 | +  /// but may produce a slightly more refined result. This is okay, as the |
| 89 | +  /// information is already encoded as fact in the IR. |
| 90 | +  static MetadataLatticeValue getPessimisticValueState(Value value) { |
| 91 | +    // Check to see if the parent operation has metadata. |
| 92 | +    if (Operation *parentOp = value.getDefiningOp()) { |
| 93 | +      if (auto metadata = parentOp->getAttrOfType<DictionaryAttr>("metadata")) |
| 94 | +        return MetadataLatticeValue(metadata); |
| 95 | + |
| 96 | +      // If no metadata is present, fall back to the |
| 97 | +      // `top`/`overdefined`/`unknown` state. |
| 98 | +    } |
| 99 | +    return MetadataLatticeValue(); |
| 100 | +  } |
| 101 | + |
| 102 | +  /// This method conservatively joins the information held by `lhs` and `rhs` |
| 103 | +  /// into a new value. This method is required to be monotonic. `monotonicity` |
| 104 | +  /// is implied by the satisfaction of the following axioms: |
| 105 | +  ///   * idempotence: join(x,x) == x |
| 106 | +  ///   * commutativity: join(x,y) == join(y,x) |
| 107 | +  ///   * associativity: join(x,join(y,z)) == join(join(x,y),z) |
| 108 | +  /// |
| 109 | +  /// When the above axioms are satisfied, we achieve `monotonicity`: |
| 110 | +  ///   * monotonicity: join(x, join(x,y)) == join(x,y) |
| 111 | +  static MetadataLatticeValue join(const MetadataLatticeValue &lhs, |
| 112 | +                                   const MetadataLatticeValue &rhs) { |
| 113 | +    // To join `lhs` and `rhs` we will define a simple policy, which is that we |
| 114 | +    // only keep information that is the same. This means that we only keep |
| 115 | +    // facts that are true in both. |
| 116 | +    MetadataLatticeValue result; |
| 117 | +    for (const auto &lhsIt : lhs.metadata) { |
| 118 | +      // As noted above, we only merge if the values are the same. |
| 119 | +      auto it = rhs.metadata.find(lhsIt.first); |
| 120 | +      if (it == rhs.metadata.end() || it->second != lhsIt.second) |
| 121 | +        continue; |
| 122 | +      result.metadata.insert(lhsIt); |
| 123 | +    } |
| 124 | +    return result; |
| 125 | +  } |
| 126 | + |
| 127 | +  /// A simple comparator that checks to see if this value is equal to the one |
| 128 | +  /// provided. |
| 129 | +  bool operator==(const MetadataLatticeValue &rhs) const { |
| 130 | +    if (metadata.size() != rhs.metadata.size()) |
| 131 | +      return false; |
| 132 | +    // Check that 'rhs' maps each key to the same (non-null) attribute. |
| 133 | +    return llvm::all_of(metadata, [&](auto &it) { |
| 134 | +      return rhs.metadata.lookup(it.first) == it.second; |
| 135 | +    }); |
| 136 | +  } |
| 137 | + |
| 138 | +  /// Our value represents the combined metadata, which is originally a |
| 139 | +  /// DictionaryAttr, so we use a map. |
| 140 | +  DenseMap<Identifier, Attribute> metadata; |
| 141 | +}; |
| 142 | +``` |
| 143 | +
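To make the `join` axioms concrete, here is a small standalone sketch that applies the same "keep only what both sides agree on" policy to a plain `std::map` and checks the axioms on sample values. It is an illustration only: `Metadata`, `joinMetadata`, and `joinAxiomsHold` are hypothetical names, and plain strings stand in for `Identifier` and `Attribute` so the code compiles without MLIR.

```cpp
#include <cassert>
#include <map>
#include <string>

// Hypothetical stand-in for the `metadata` map; plain strings replace
// `Identifier` and `Attribute` so the sketch compiles without MLIR.
using Metadata = std::map<std::string, std::string>;

// Keep only the entries that are present, with equal values, on both sides.
Metadata joinMetadata(const Metadata &lhs, const Metadata &rhs) {
  Metadata result;
  for (const auto &entry : lhs) {
    auto it = rhs.find(entry.first);
    if (it != rhs.end() && it->second == entry.second)
      result.insert(entry);
  }
  return result;
}

// Returns true if the join axioms hold for three sample values.
bool joinAxiomsHold() {
  Metadata x = {{"likes_pizza", "true"}, {"likes_pasta", "true"}};
  Metadata y = {{"likes_pizza", "true"}, {"likes_pasta", "false"}};
  Metadata z = {{"likes_pizza", "true"}, {"likes_sushi", "true"}};
  bool idempotent = joinMetadata(x, x) == x;
  bool commutative = joinMetadata(x, y) == joinMetadata(y, x);
  bool associative = joinMetadata(x, joinMetadata(y, z)) ==
                     joinMetadata(joinMetadata(x, y), z);
  // The property the comment above calls monotonicity.
  bool monotonic = joinMetadata(x, joinMetadata(x, y)) == joinMetadata(x, y);
  return idempotent && commutative && associative && monotonic;
}
```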
| 144 | +One interesting thing to note above is that we don't have an explicit method for |
| 145 | +the `uninitialized` state. This state is handled by the `LatticeElement` class, |
| 146 | +which manages a lattice value for a given IR entity. A quick overview of this |
| 147 | +class, and the API that will be interesting to us while writing our analysis, is |
| 148 | +shown below: |
| 149 | +
| 150 | +```c++
| 151 | +/// This class represents a lattice element holding a specific value of type
| 152 | +/// `ValueT`.
| 153 | +template <typename ValueT>
| 154 | +class LatticeElement ... {
| 155 | +public:
| 156 | +  /// Return the value held by this element. This requires that a value is
| 157 | +  /// known, i.e. not `uninitialized`.
| 158 | +  ValueT &getValue();
| 159 | +  const ValueT &getValue() const;
| 160 | +
| 161 | +  /// Join the information contained in the 'rhs' element into this
| 162 | +  /// element. Returns if the state of the current element changed.
| 163 | +  ChangeResult join(const LatticeElement<ValueT> &rhs);
| 164 | +
| 165 | +  /// Join the information contained in the 'rhs' value into this
| 166 | +  /// lattice. Returns if the state of the current lattice changed.
| 167 | +  ChangeResult join(const ValueT &rhs);
| 168 | +
| 169 | +  /// Mark the lattice element as having reached a pessimistic fixpoint. This
| 170 | +  /// means that the lattice may potentially have conflicting value states, and
| 171 | +  /// only the conservatively known value state should be relied on.
| 172 | +  ChangeResult markPessimisticFixPoint();
| 173 | +};
| 174 | +```
| 175 | + |
| 176 | +With our lattice defined, we can now define the driver that will compute and |
| 177 | +propagate our lattice across the IR. |
| 178 | + |
| 179 | +### ForwardDataFlowAnalysis Driver |
| 180 | + |
| 181 | +The `ForwardDataFlowAnalysis` class represents the driver of the dataflow |
| 182 | +analysis, and performs all of the related analysis computation. When defining |
| 183 | +our analysis, we will inherit from this class and implement some of its hooks. |
| 184 | +Before that, let's look at a quick overview of this class and some of the |
| 185 | +important API for our analysis: |
| 186 | + |
| 187 | +```c++ |
| 188 | +/// This class represents the main driver of the forward dataflow analysis. It |
| 189 | +/// takes as a template parameter the value type of lattice being computed. |
| 190 | +template <typename ValueT> |
| 191 | +class ForwardDataFlowAnalysis : ... { |
| 192 | +public: |
| 193 | +  ForwardDataFlowAnalysis(MLIRContext *context); |
| 194 | + |
| 195 | +  /// Compute the analysis on operations rooted under the given top-level |
| 196 | +  /// operation. Note that the top-level operation is not visited. |
| 197 | +  void run(Operation *topLevelOp); |
| 198 | + |
| 199 | +  /// Return the lattice element attached to the given value. If a lattice has |
| 200 | +  /// not been added for the given value, a new 'uninitialized' value is |
| 201 | +  /// inserted and returned. |
| 202 | +  LatticeElement<ValueT> &getLatticeElement(Value value); |
| 203 | + |
| 204 | +  /// Return the lattice element attached to the given value, or nullptr if no |
| 205 | +  /// lattice element for the value has yet been created. |
| 206 | +  LatticeElement<ValueT> *lookupLatticeElement(Value value); |
| 207 | + |
| 208 | +  /// Mark all of the lattice elements for the given range of Values as having |
| 209 | +  /// reached a pessimistic fixpoint. |
| 210 | +  ChangeResult markAllPessimisticFixPoint(ValueRange values); |
| 211 | + |
| 212 | +protected: |
| 213 | +  /// Visit the given operation, and join any necessary analysis state |
| 214 | +  /// into the lattice elements for the results and block arguments owned by |
| 215 | +  /// this operation using the provided set of operand lattice elements |
| 216 | +  /// (all pointer values are guaranteed to be non-null). Returns if any result |
| 217 | +  /// or block argument value lattice elements changed during the visit. The |
| 218 | +  /// lattice element for a result or block argument value can be obtained, and |
| 219 | +  /// join'ed into, by using `getLatticeElement`. |
| 220 | +  virtual ChangeResult visitOperation( |
| 221 | +      Operation *op, ArrayRef<LatticeElement<ValueT> *> operands) = 0; |
| 222 | +}; |
| 223 | +``` |
| 224 | +
| 225 | +NOTE: Some API has been omitted for brevity. The `ForwardDataFlowAnalysis`
| 226 | +contains various other hooks that allow for injecting custom behavior when
| 227 | +applicable.
| 228 | +
| 229 | +The main API that we are responsible for defining is the `visitOperation` |
| 230 | +method. This method is responsible for computing new lattice elements for the |
| 231 | +results and block arguments owned by the given operation. This is where we will |
| 232 | +inject the lattice element computation logic, also known as the transfer |
| 233 | +function for the operation, that is specific to our analysis. A simple |
| 234 | +implementation for our example is shown below: |
| 235 | +
| 236 | +```c++
| 237 | +class MetadataAnalysis : public ForwardDataFlowAnalysis<MetadataLatticeValue> {
| 238 | +public:
| 239 | +  using ForwardDataFlowAnalysis<MetadataLatticeValue>::ForwardDataFlowAnalysis;
| 240 | +
| 241 | +  ChangeResult visitOperation(Operation *op,
| 242 | +      ArrayRef<LatticeElement<MetadataLatticeValue> *> operands) override {
| 243 | +    DictionaryAttr metadata = op->getAttrOfType<DictionaryAttr>("metadata");
| 244 | +
| 245 | +    // If we have no metadata for this operation, we will conservatively mark
| 246 | +    // all of the results as having reached a pessimistic fixpoint.
| 247 | +    if (!metadata)
| 248 | +      return markAllPessimisticFixPoint(op->getResults());
| 249 | +
| 250 | +    // Otherwise, we will compute a lattice value for the metadata and join it
| 251 | +    // into the current lattice element for all of our results.
| 252 | +    MetadataLatticeValue latticeValue(metadata);
| 253 | +    ChangeResult result = ChangeResult::NoChange;
| 254 | +    for (Value value : op->getResults()) {
| 255 | +      // We grab the lattice element for `value` via `getLatticeElement` and
| 256 | +      // then join it with the lattice value for this operation's metadata. Note
| 257 | +      // that during the analysis phase, it is fine to freely create a new
| 258 | +      // lattice element for a value. This is why we don't use the
| 259 | +      // `lookupLatticeElement` method here.
| 260 | +      result |= getLatticeElement(value).join(latticeValue);
| 261 | +    }
| 262 | +    return result;
| 263 | +  }
| 264 | +};
| 265 | +```
| 266 | + |
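To give some intuition for why `visitOperation` and `join` report a `ChangeResult`, the sketch below shows one common way a driver could use that signal: a worklist that only revisits successors while a join still changes something. This is a hypothetical standalone model in C++17 and is not taken from MLIR's implementation; the `Element` struct, the `propagate` function, and the integer node graph are all assumptions made for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <deque>
#include <optional>
#include <set>
#include <string>
#include <utility>
#include <vector>

// Everything below is a hypothetical model; it is not MLIR's driver.
enum class ChangeResult { NoChange, Change };
using Facts = std::set<std::string>;

struct Element {
  std::optional<Facts> value; // An empty optional models `uninitialized`.

  // Join `rhs` into this element and report whether the state changed.
  ChangeResult join(const Facts &rhs) {
    if (!value) {
      value = rhs;
      return ChangeResult::Change;
    }
    Facts common;
    for (const std::string &fact : *value)
      if (rhs.count(fact))
        common.insert(fact);
    if (common == *value)
      return ChangeResult::NoChange;
    value = std::move(common);
    return ChangeResult::Change;
  }
};

// Node `i` forwards its state to every node listed in `successors[i]`.
// `seeds[i]` holds the facts attached directly to node `i`, if any.
std::vector<Element>
propagate(const std::vector<std::vector<std::size_t>> &successors,
          const std::vector<std::optional<Facts>> &seeds) {
  assert(successors.size() == seeds.size());
  std::vector<Element> elements(successors.size());
  std::deque<std::size_t> worklist;
  for (std::size_t i = 0; i < seeds.size(); ++i)
    if (seeds[i] && elements[i].join(*seeds[i]) == ChangeResult::Change)
      worklist.push_back(i);

  // Revisit successors only while a join reports a change. Joins can only
  // initialize an element or shrink its fact set, so this terminates.
  while (!worklist.empty()) {
    std::size_t node = worklist.front();
    worklist.pop_front();
    for (std::size_t succ : successors[node])
      if (elements[succ].join(*elements[node].value) == ChangeResult::Change)
        worklist.push_back(succ);
  }
  return elements;
}

// Nodes 0 and 1 both feed node 2; node 3 is never reached.
bool propagationExample() {
  std::vector<std::vector<std::size_t>> successors = {{2}, {2}, {}, {}};
  std::vector<std::optional<Facts>> seeds = {
      Facts{"likes_pizza", "likes_pasta"}, Facts{"likes_pizza", "likes_sushi"},
      std::nullopt, std::nullopt};
  std::vector<Element> elements = propagate(successors, seeds);
  return elements[2].value == Facts{"likes_pizza"} && !elements[3].value;
}
```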
| 267 | +With that, we have all of the necessary components to compute our analysis. |
| 268 | +After the analysis has been computed, we can grab any computed information for |
| 269 | +values by using `lookupLatticeElement`. We use this function over |
| 270 | +`getLatticeElement` as the analysis is not guaranteed to visit all values, e.g. |
| 271 | +if the value is in an unreachable block, and we don't want to create a new |
| 272 | +uninitialized lattice element in this case. See below for a quick example: |
| 273 | + |
| 274 | +```c++ |
| 275 | +void MyPass::runOnOperation() { |
| 276 | +  MetadataAnalysis analysis(&getContext()); |
| 277 | +  analysis.run(getOperation()); |
| 278 | +  ... |
| 279 | +} |
| 280 | + |
| 281 | +void MyPass::useAnalysisOn(MetadataAnalysis &analysis, Value value) { |
| 282 | +  LatticeElement<MetadataLatticeValue> *latticeElement = analysis.lookupLatticeElement(value); |
| 283 | + |
| 284 | +  // If we don't have an element, the `value` wasn't visited during our analysis, |
| 285 | +  // meaning that it could be dead. We need to treat this conservatively. |
| 286 | +  if (!latticeElement) |
| 287 | +    return; |
| 288 | + |
| 289 | +  // Our lattice element has a value, use it: |
| 290 | +  MetadataLatticeValue &latticeValue = latticeElement->getValue(); |
| 291 | +  ... |
| 292 | +} |
| 293 | +``` |
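The difference between the two accessors can be modeled with an ordinary map: a "get" that creates entries on demand and a "lookup" that never does. The sketch below is a hypothetical standalone model, not MLIR code; `ElementTable` and `lookupDoesNotInsert` are made-up names, an `int` stands in for `Value`, and a `std::string` stands in for the lattice element.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <unordered_map>

// Hypothetical model of a per-value element table. `int` stands in for
// `Value` and `std::string` stands in for the lattice element.
class ElementTable {
public:
  // Like `getLatticeElement`: creates a default entry on demand.
  std::string &get(int value) { return elements[value]; }

  // Like `lookupLatticeElement`: returns nullptr instead of creating an entry.
  std::string *lookup(int value) {
    auto it = elements.find(value);
    return it == elements.end() ? nullptr : &it->second;
  }

  std::size_t size() const { return elements.size(); }

private:
  std::unordered_map<int, std::string> elements;
};

// Returns true if looking up an unvisited value leaves the table untouched.
bool lookupDoesNotInsert() {
  ElementTable table;
  table.get(0) = "visited";
  bool missing = table.lookup(1) == nullptr; // Value 1 was never visited.
  bool found = table.lookup(0) != nullptr && *table.lookup(0) == "visited";
  return missing && found && table.size() == 1;
}
```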