forked from llvm/llvm-project
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathByteCode.h
214 lines (174 loc) · 8.62 KB
/
ByteCode.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
//===- ByteCode.h - Pattern byte-code interpreter ---------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file declares a byte-code and interpreter for pattern rewrites in MLIR.
// The byte-code is constructed from the PDL Interpreter dialect.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_REWRITE_BYTECODE_H_
#define MLIR_REWRITE_BYTECODE_H_
#include "mlir/IR/PatternMatch.h"
namespace mlir {
namespace pdl_interp {
class RecordMatchOp;
} // namespace pdl_interp
namespace detail {
class PDLByteCode;
/// Use generic bytecode types. ByteCodeField refers to the actual bytecode
/// entries. ByteCodeAddr refers to size of indices into the bytecode.
using ByteCodeField = uint16_t;
using ByteCodeAddr = uint32_t;
using OwningOpRange = llvm::OwningArrayRef<Operation *>;
//===----------------------------------------------------------------------===//
// PDLByteCodePattern
//===----------------------------------------------------------------------===//
/// All of the data pertaining to a specific pattern within the bytecode.
class PDLByteCodePattern : public Pattern {
public:
static PDLByteCodePattern create(pdl_interp::RecordMatchOp matchOp,
ByteCodeAddr rewriterAddr);
/// Return the bytecode address of the rewriter for this pattern.
ByteCodeAddr getRewriterAddr() const { return rewriterAddr; }
private:
template <typename... Args>
PDLByteCodePattern(ByteCodeAddr rewriterAddr, Args &&...patternArgs)
: Pattern(std::forward<Args>(patternArgs)...),
rewriterAddr(rewriterAddr) {}
/// The address of the rewriter for this pattern.
ByteCodeAddr rewriterAddr;
};
//===----------------------------------------------------------------------===//
// PDLByteCodeMutableState
//===----------------------------------------------------------------------===//
/// This class contains the mutable state of a bytecode instance. This allows
/// for a bytecode instance to be cached and reused across various different
/// threads/drivers.
class PDLByteCodeMutableState {
public:
/// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
/// to the position of the pattern within the range returned by
/// `PDLByteCode::getPatterns`.
void updatePatternBenefit(unsigned patternIndex, PatternBenefit benefit);
/// Cleanup any allocated state after a match/rewrite has been completed. This
/// method should be called irregardless of whether the match+rewrite was a
/// success or not.
void cleanupAfterMatchAndRewrite();
private:
/// Allow access to data fields.
friend class PDLByteCode;
/// The mutable block of memory used during the matching and rewriting phases
/// of the bytecode.
std::vector<const void *> memory;
/// A mutable block of memory used during the matching and rewriting phase of
/// the bytecode to store ranges of operations. These are always stored by
/// owning references, because at no point in the execution of the byte code
/// we get an indexed range (view) of operations.
std::vector<OwningOpRange> opRangeMemory;
/// A mutable block of memory used during the matching and rewriting phase of
/// the bytecode to store ranges of types.
std::vector<TypeRange> typeRangeMemory;
/// A set of type ranges that have been allocated by the byte code interpreter
/// to provide a guaranteed lifetime.
std::vector<llvm::OwningArrayRef<Type>> allocatedTypeRangeMemory;
/// A mutable block of memory used during the matching and rewriting phase of
/// the bytecode to store ranges of values.
std::vector<ValueRange> valueRangeMemory;
/// A set of value ranges that have been allocated by the byte code
/// interpreter to provide a guaranteed lifetime.
std::vector<llvm::OwningArrayRef<Value>> allocatedValueRangeMemory;
/// The current index of ranges being iterated over for each level of nesting.
/// These are always maintained at 0 for the loops that are not active, so we
/// do not need to have a separate initialization phase for each loop.
std::vector<unsigned> loopIndex;
/// The up-to-date benefits of the patterns held by the bytecode. The order
/// of this array corresponds 1-1 with the array of patterns in `PDLByteCode`.
std::vector<PatternBenefit> currentPatternBenefits;
};
//===----------------------------------------------------------------------===//
// PDLByteCode
//===----------------------------------------------------------------------===//
/// The bytecode class is also the interpreter. Contains the bytecode itself,
/// the static info, addresses of the rewriter functions, the interpreter
/// memory buffer, and the execution context.
class PDLByteCode {
public:
/// Each successful match returns a MatchResult, which contains information
/// necessary to execute the rewriter and indicates the originating pattern.
struct MatchResult {
MatchResult(Location loc, const PDLByteCodePattern &pattern,
PatternBenefit benefit)
: location(loc), pattern(&pattern), benefit(benefit) {}
MatchResult(const MatchResult &) = delete;
MatchResult &operator=(const MatchResult &) = delete;
MatchResult(MatchResult &&other) = default;
MatchResult &operator=(MatchResult &&) = default;
/// The location of operations to be replaced.
Location location;
/// Memory values defined in the matcher that are passed to the rewriter.
SmallVector<const void *> values;
/// Memory used for the range input values.
SmallVector<TypeRange, 0> typeRangeValues;
SmallVector<ValueRange, 0> valueRangeValues;
/// The originating pattern that was matched. This is always non-null, but
/// represented with a pointer to allow for assignment.
const PDLByteCodePattern *pattern;
/// The current benefit of the pattern that was matched.
PatternBenefit benefit;
};
/// Create a ByteCode instance from the given module containing operations in
/// the PDL interpreter dialect.
PDLByteCode(ModuleOp module,
llvm::StringMap<PDLConstraintFunction> constraintFns,
llvm::StringMap<PDLRewriteFunction> rewriteFns);
/// Return the patterns held by the bytecode.
ArrayRef<PDLByteCodePattern> getPatterns() const { return patterns; }
/// Initialize the given state such that it can be used to execute the current
/// bytecode.
void initializeMutableState(PDLByteCodeMutableState &state) const;
/// Run the pattern matcher on the given root operation, collecting the
/// matched patterns in `matches`.
void match(Operation *op, PatternRewriter &rewriter,
SmallVectorImpl<MatchResult> &matches,
PDLByteCodeMutableState &state) const;
/// Run the rewriter of the given pattern that was previously matched in
/// `match`.
void rewrite(PatternRewriter &rewriter, const MatchResult &match,
PDLByteCodeMutableState &state) const;
private:
/// Execute the given byte code starting at the provided instruction `inst`.
/// `matches` is an optional field provided when this function is executed in
/// a matching context.
void executeByteCode(const ByteCodeField *inst, PatternRewriter &rewriter,
PDLByteCodeMutableState &state,
SmallVectorImpl<MatchResult> *matches) const;
/// A vector containing pointers to uniqued data. The storage is intentionally
/// opaque such that we can store a wide range of data types. The types of
/// data stored here include:
/// * Attribute, OperationName, Type
std::vector<const void *> uniquedData;
/// A vector containing the generated bytecode for the matcher.
SmallVector<ByteCodeField, 64> matcherByteCode;
/// A vector containing the generated bytecode for all of the rewriters.
SmallVector<ByteCodeField, 64> rewriterByteCode;
/// The set of patterns contained within the bytecode.
SmallVector<PDLByteCodePattern, 32> patterns;
/// A set of user defined functions invoked via PDL.
std::vector<PDLConstraintFunction> constraintFunctions;
std::vector<PDLRewriteFunction> rewriteFunctions;
/// The maximum memory index used by a value.
ByteCodeField maxValueMemoryIndex = 0;
/// The maximum number of different types of ranges.
ByteCodeField maxOpRangeCount = 0;
ByteCodeField maxTypeRangeCount = 0;
ByteCodeField maxValueRangeCount = 0;
/// The maximum number of nested loops.
ByteCodeField maxLoopLevel = 0;
};
} // namespace detail
} // namespace mlir
#endif // MLIR_REWRITE_BYTECODE_H_