Skip to content

Commit af39b34

Browse files
authoredApr 6, 2023
[Node] Utility methods for ObjectPathPair handling (apache#14498)
* [Node] Utility methods for ObjectPathPair handling This commit adds a templated overload to `SEqualReducer::operator()` that accepts a lambda function to update the path of the LHS and RHS of the comparison. ```c++ // Usage prior to this utility function if (equal.IsPathTracingEnabled()) { const ObjectPathPair& self_paths = equal.GetCurrentObjectPaths(); ObjectPathPair attr_paths = {self_paths->lhs_path->Attr("value"), self_paths->rhs_path->Attr("value")}; if (!equal(this->value, other->value, attr_paths)) return false; } else { if (!equal(this->value, other->value)) return false; } // Usage after this utility function if (!equal(this->value, other->value, [](const auto& path) { return path->Attr("value"); })) { return false; } ``` * Unit test testing that discrepant path includes the PrimFunc's name * Updated docstring to resolve linting error * Fixed where to look for error message in unit test
1 parent 5239ec0 commit af39b34

File tree

4 files changed

+136
-67
lines changed

4 files changed

+136
-67
lines changed
 

‎include/tvm/node/structural_equal.h

+42-11
Original file line numberDiff line numberDiff line change
@@ -191,24 +191,53 @@ class SEqualReducer {
191191

192192
/*!
193193
* \brief Reduce condition to comparison of two attribute values.
194+
*
194195
* \param lhs The left operand.
196+
*
195197
* \param rhs The right operand.
198+
*
199+
* \param paths The paths to the LHS and RHS operands. If
200+
* unspecified, will attempt to identify the attribute's address
201+
* within the most recent ObjectRef. In general, the paths only
202+
* require explicit handling for computed parameters
203+
* (e.g. `array.size()`)
204+
*
196205
* \return the immediate check result.
197206
*/
198-
bool operator()(const double& lhs, const double& rhs) const;
199-
bool operator()(const int64_t& lhs, const int64_t& rhs) const;
200-
bool operator()(const uint64_t& lhs, const uint64_t& rhs) const;
201-
bool operator()(const int& lhs, const int& rhs) const;
202-
bool operator()(const bool& lhs, const bool& rhs) const;
203-
bool operator()(const std::string& lhs, const std::string& rhs) const;
204-
bool operator()(const DataType& lhs, const DataType& rhs) const;
207+
bool operator()(const double& lhs, const double& rhs,
208+
Optional<ObjectPathPair> paths = NullOpt) const;
209+
bool operator()(const int64_t& lhs, const int64_t& rhs,
210+
Optional<ObjectPathPair> paths = NullOpt) const;
211+
bool operator()(const uint64_t& lhs, const uint64_t& rhs,
212+
Optional<ObjectPathPair> paths = NullOpt) const;
213+
bool operator()(const int& lhs, const int& rhs, Optional<ObjectPathPair> paths = NullOpt) const;
214+
bool operator()(const bool& lhs, const bool& rhs, Optional<ObjectPathPair> paths = NullOpt) const;
215+
bool operator()(const std::string& lhs, const std::string& rhs,
216+
Optional<ObjectPathPair> paths = NullOpt) const;
217+
bool operator()(const DataType& lhs, const DataType& rhs,
218+
Optional<ObjectPathPair> paths = NullOpt) const;
205219

206220
template <typename ENum, typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
207-
bool operator()(const ENum& lhs, const ENum& rhs) const {
221+
bool operator()(const ENum& lhs, const ENum& rhs,
222+
Optional<ObjectPathPair> paths = NullOpt) const {
208223
using Underlying = typename std::underlying_type<ENum>::type;
209224
static_assert(std::is_same<Underlying, int>::value,
210225
"Enum must have `int` as the underlying type");
211-
return EnumAttrsEqual(static_cast<int>(lhs), static_cast<int>(rhs), &lhs, &rhs);
226+
return EnumAttrsEqual(static_cast<int>(lhs), static_cast<int>(rhs), &lhs, &rhs, paths);
227+
}
228+
229+
template <typename T, typename Callable,
230+
typename = std::enable_if_t<
231+
std::is_same_v<std::invoke_result_t<Callable, const ObjectPath&>, ObjectPath>>>
232+
bool operator()(const T& lhs, const T& rhs, const Callable& callable) {
233+
if (IsPathTracingEnabled()) {
234+
ObjectPathPair current_paths = GetCurrentObjectPaths();
235+
ObjectPathPair new_paths = {callable(current_paths->lhs_path),
236+
callable(current_paths->rhs_path)};
237+
return (*this)(lhs, rhs, new_paths);
238+
} else {
239+
return (*this)(lhs, rhs);
240+
}
212241
}
213242

214243
/*!
@@ -310,7 +339,8 @@ class SEqualReducer {
310339
void RecordMismatchPaths(const ObjectPathPair& paths) const;
311340

312341
private:
313-
bool EnumAttrsEqual(int lhs, int rhs, const void* lhs_address, const void* rhs_address) const;
342+
bool EnumAttrsEqual(int lhs, int rhs, const void* lhs_address, const void* rhs_address,
343+
Optional<ObjectPathPair> paths = NullOpt) const;
314344

315345
bool ObjectAttrsEqual(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
316346
const ObjectPathPair* paths) const;
@@ -321,7 +351,8 @@ class SEqualReducer {
321351

322352
template <typename T>
323353
static bool CompareAttributeValues(const T& lhs, const T& rhs,
324-
const PathTracingData* tracing_data);
354+
const PathTracingData* tracing_data,
355+
Optional<ObjectPathPair> paths = NullOpt);
325356

326357
/*! \brief Internal class pointer. */
327358
Handler* handler_ = nullptr;

‎src/ir/module.cc

+32-32
Original file line numberDiff line numberDiff line change
@@ -63,46 +63,46 @@ IRModule::IRModule(tvm::Map<GlobalVar, BaseFunc> functions,
6363
}
6464

6565
bool IRModuleNode::SEqualReduce(const IRModuleNode* other, SEqualReducer equal) const {
66-
if (!equal(this->attrs, other->attrs)) return false;
66+
if (!equal(this->attrs, other->attrs, [](const auto& path) { return path->Attr("attrs"); })) {
67+
return false;
68+
}
69+
70+
if (equal.IsPathTracingEnabled()) {
71+
if ((functions.size() != other->functions.size()) ||
72+
(type_definitions.size() != other->type_definitions.size())) {
73+
return false;
74+
}
75+
}
6776

68-
if (functions.size() != other->functions.size()) return false;
69-
// Update GlobalVar remap
77+
// Define remaps for GlobalVar and GlobalTypeVar based on their
78+
// string name. Early bail-out is only performed when path-tracing
79+
// is disabled, as the later equality checks on the member variables
80+
// will provide better error messages.
7081
for (const auto& gv : this->GetGlobalVars()) {
71-
if (!other->ContainGlobalVar(gv->name_hint)) return false;
72-
if (!equal.DefEqual(gv, other->GetGlobalVar(gv->name_hint))) return false;
82+
if (other->ContainGlobalVar(gv->name_hint)) {
83+
if (!equal.DefEqual(gv, other->GetGlobalVar(gv->name_hint))) return false;
84+
} else if (!equal.IsPathTracingEnabled()) {
85+
return false;
86+
}
7387
}
74-
// Checking functions
75-
for (const auto& kv : this->functions) {
76-
if (equal.IsPathTracingEnabled()) {
77-
const ObjectPathPair& obj_path_pair = equal.GetCurrentObjectPaths();
78-
ObjectPathPair func_paths = {obj_path_pair->lhs_path->Attr("functions")->MapValue(kv.first),
79-
obj_path_pair->rhs_path->Attr("functions")
80-
->MapValue(other->GetGlobalVar(kv.first->name_hint))};
81-
if (!equal(kv.second, other->Lookup(kv.first->name_hint), func_paths)) return false;
82-
} else {
83-
if (!equal(kv.second, other->Lookup(kv.first->name_hint))) return false;
88+
for (const auto& gtv : this->GetGlobalTypeVars()) {
89+
if (other->ContainGlobalTypeVar(gtv->name_hint)) {
90+
if (!equal.DefEqual(gtv, other->GetGlobalTypeVar(gtv->name_hint))) return false;
91+
} else if (!equal.IsPathTracingEnabled()) {
92+
return false;
8493
}
8594
}
8695

87-
if (type_definitions.size() != other->type_definitions.size()) return false;
88-
// Update GlobalTypeVar remap
89-
for (const auto& gtv : this->GetGlobalTypeVars()) {
90-
if (!other->ContainGlobalTypeVar(gtv->name_hint)) return false;
91-
if (!equal.DefEqual(gtv, other->GetGlobalTypeVar(gtv->name_hint))) return false;
96+
// Checking functions and type definitions
97+
if (!equal(this->functions, other->functions,
98+
[](const auto& path) { return path->Attr("functions"); })) {
99+
return false;
92100
}
93-
// Checking type_definitions
94-
for (const auto& kv : this->type_definitions) {
95-
if (equal.IsPathTracingEnabled()) {
96-
const ObjectPathPair& obj_path_pair = equal.GetCurrentObjectPaths();
97-
ObjectPathPair type_paths = {
98-
obj_path_pair->lhs_path->Attr("type_definitions")->MapValue(kv.first),
99-
obj_path_pair->rhs_path->Attr("type_definitions")
100-
->MapValue(other->GetGlobalTypeVar(kv.first->name_hint))};
101-
if (!equal(kv.second, other->LookupTypeDef(kv.first->name_hint), type_paths)) return false;
102-
} else {
103-
if (!equal(kv.second, other->LookupTypeDef(kv.first->name_hint))) return false;
104-
}
101+
if (!equal(this->type_definitions, other->type_definitions,
102+
[](const auto& path) { return path->Attr("type_definitions"); })) {
103+
return false;
105104
}
105+
106106
return true;
107107
}
108108

‎src/node/structural_equal.cc

+36-15
Original file line numberDiff line numberDiff line change
@@ -109,51 +109,72 @@ bool SEqualReducer::DefEqual(const ObjectRef& lhs, const ObjectRef& rhs) {
109109

110110
template <typename T>
111111
/* static */ bool SEqualReducer::CompareAttributeValues(const T& lhs, const T& rhs,
112-
const PathTracingData* tracing_data) {
112+
const PathTracingData* tracing_data,
113+
Optional<ObjectPathPair> paths) {
113114
if (BaseValueEqual()(lhs, rhs)) {
114115
return true;
115-
} else {
116-
GetPathsFromAttrAddressesAndStoreMismatch(&lhs, &rhs, tracing_data);
117-
return false;
118116
}
117+
118+
if (tracing_data && !tracing_data->first_mismatch->defined()) {
119+
if (paths) {
120+
*tracing_data->first_mismatch = paths.value();
121+
} else {
122+
GetPathsFromAttrAddressesAndStoreMismatch(&lhs, &rhs, tracing_data);
123+
}
124+
}
125+
return false;
119126
}
120127

121-
bool SEqualReducer::operator()(const double& lhs, const double& rhs) const {
128+
bool SEqualReducer::operator()(const double& lhs, const double& rhs,
129+
Optional<ObjectPathPair> paths) const {
122130
return CompareAttributeValues(lhs, rhs, tracing_data_);
123131
}
124132

125-
bool SEqualReducer::operator()(const int64_t& lhs, const int64_t& rhs) const {
133+
bool SEqualReducer::operator()(const int64_t& lhs, const int64_t& rhs,
134+
Optional<ObjectPathPair> paths) const {
126135
return CompareAttributeValues(lhs, rhs, tracing_data_);
127136
}
128137

129-
bool SEqualReducer::operator()(const uint64_t& lhs, const uint64_t& rhs) const {
138+
bool SEqualReducer::operator()(const uint64_t& lhs, const uint64_t& rhs,
139+
Optional<ObjectPathPair> paths) const {
130140
return CompareAttributeValues(lhs, rhs, tracing_data_);
131141
}
132142

133-
bool SEqualReducer::operator()(const int& lhs, const int& rhs) const {
143+
bool SEqualReducer::operator()(const int& lhs, const int& rhs,
144+
Optional<ObjectPathPair> paths) const {
134145
return CompareAttributeValues(lhs, rhs, tracing_data_);
135146
}
136147

137-
bool SEqualReducer::operator()(const bool& lhs, const bool& rhs) const {
148+
bool SEqualReducer::operator()(const bool& lhs, const bool& rhs,
149+
Optional<ObjectPathPair> paths) const {
138150
return CompareAttributeValues(lhs, rhs, tracing_data_);
139151
}
140152

141-
bool SEqualReducer::operator()(const std::string& lhs, const std::string& rhs) const {
153+
bool SEqualReducer::operator()(const std::string& lhs, const std::string& rhs,
154+
Optional<ObjectPathPair> paths) const {
142155
return CompareAttributeValues(lhs, rhs, tracing_data_);
143156
}
144157

145-
bool SEqualReducer::operator()(const DataType& lhs, const DataType& rhs) const {
158+
bool SEqualReducer::operator()(const DataType& lhs, const DataType& rhs,
159+
Optional<ObjectPathPair> paths) const {
146160
return CompareAttributeValues(lhs, rhs, tracing_data_);
147161
}
148162

149163
bool SEqualReducer::EnumAttrsEqual(int lhs, int rhs, const void* lhs_address,
150-
const void* rhs_address) const {
164+
const void* rhs_address, Optional<ObjectPathPair> paths) const {
151165
if (lhs == rhs) {
152166
return true;
153-
} else {
154-
GetPathsFromAttrAddressesAndStoreMismatch(lhs_address, rhs_address, tracing_data_);
155-
return false;
156167
}
168+
169+
if (tracing_data_ && !tracing_data_->first_mismatch->defined()) {
170+
if (paths) {
171+
*tracing_data_->first_mismatch = paths.value();
172+
} else {
173+
GetPathsFromAttrAddressesAndStoreMismatch(&lhs, &rhs, tracing_data_);
174+
}
175+
}
176+
177+
return false;
157178
}
158179

159180
const ObjectPathPair& SEqualReducer::GetCurrentObjectPaths() const {

‎tests/python/unittest/test_tir_structural_equal_hash.py

+26-9
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import pytest
2020
from tvm import te
2121
from tvm.runtime import ObjectPath
22+
from tvm.script import tir as T, ir as I
2223

2324

2425
def consistent_equal(x, y, map_free_vars=False):
@@ -394,13 +395,29 @@ def test_seq_length_mismatch():
394395
assert rhs_path == expected_rhs_path
395396

396397

398+
def test_ir_module_equal():
399+
def generate(n: int):
400+
@I.ir_module
401+
class module:
402+
@T.prim_func
403+
def func(A: T.Buffer(1, "int32")):
404+
for i in range(n):
405+
A[0] = A[0] + 1
406+
407+
return module
408+
409+
# Equivalent IRModules should compare as equivalent, even though
410+
# they have distinct GlobalVars, and GlobalVars usually compare by
411+
# reference equality.
412+
tvm.ir.assert_structural_equal(generate(16), generate(16))
413+
414+
# When there is a difference, the location should include the
415+
# function name that caused the failure.
416+
with pytest.raises(ValueError) as err:
417+
tvm.ir.assert_structural_equal(generate(16), generate(32))
418+
419+
assert '<root>.functions[I.GlobalVar("func")].body.extent.value' in err.value.args[0]
420+
421+
397422
if __name__ == "__main__":
398-
test_exprs()
399-
test_prim_func()
400-
test_attrs()
401-
test_array()
402-
test_env_func()
403-
test_stmt()
404-
test_buffer_storage_scope()
405-
test_buffer_load_store()
406-
test_while()
423+
tvm.testing.main()

0 commit comments

Comments
 (0)
Please sign in to comment.