Skip to content

Commit 4ba791e

Browse files
committed
Handle large integers (> 64 bits) for the IntegerAttr C-API
Fixes issue llvm#128072. Allows for arbitrarily sized integers to be requested via Python.
1 parent 9cab82f commit 4ba791e

File tree

4 files changed

+165
-7
lines changed

4 files changed

+165
-7
lines changed

mlir/include/mlir-c/BuiltinAttributes.h

+19
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,25 @@ MLIR_CAPI_EXPORTED uint64_t mlirIntegerAttrGetValueUInt(MlirAttribute attr);
158158
/// Returns the typeID of an Integer attribute.
159159
MLIR_CAPI_EXPORTED MlirTypeID mlirIntegerAttrGetTypeID(void);
160160

161+
// Used to create large IntegerAttr's (>64 bits) via the CAPI
162+
// See
163+
// https://github.com/llvm/llvm-project/issues/128072#issuecomment-2672767777
164+
typedef struct {
165+
size_t numbits;
166+
union {
167+
uint64_t *pVAL;
168+
uint64_t VAL;
169+
} data;
170+
} apint_interop_t;
171+
172+
// Creates an APInt interop from an IntegerAttr
173+
MLIR_CAPI_EXPORTED int mlirIntegerAttrGetValueInterop(MlirAttribute attr,
174+
apint_interop_t *interop);
175+
176+
// Creates an integer attribute of the given type from an APInt interop
177+
MLIR_CAPI_EXPORTED MlirAttribute
178+
mlirIntegerAttrFromInterop(MlirType type, apint_interop_t *interop);
179+
161180
//===----------------------------------------------------------------------===//
162181
// Bool attribute.
163182
//===----------------------------------------------------------------------===//

mlir/lib/Bindings/Python/IRAttributes.cpp

+67-7
Original file line numberDiff line numberDiff line change
@@ -601,8 +601,31 @@ class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
601601
static void bindDerived(ClassTy &c) {
602602
c.def_static(
603603
"get",
604-
[](PyType &type, int64_t value) {
605-
MlirAttribute attr = mlirIntegerAttrGet(type, value);
604+
[](PyType &type, py::int_ value) {
605+
apint_interop_t interop;
606+
if (mlirTypeIsAIndex(type))
607+
interop.numbits = 64;
608+
else
609+
interop.numbits = mlirIntegerTypeGetWidth((MlirType)type);
610+
611+
py::object to_bytes = value.attr("to_bytes");
612+
int numbytes = (interop.numbits + 7) / 8;
613+
bool Signed = mlirTypeIsAIndex(type) || mlirIntegerTypeIsSigned(type);
614+
py::bytes bytes_obj =
615+
to_bytes(numbytes, "little", py::arg("signed") = Signed);
616+
const char *data = bytes_obj.data();
617+
618+
if (interop.numbits <= 64) {
619+
memcpy((char *)&(interop.data.VAL), data, numbytes);
620+
} else {
621+
int numdoublewords = (interop.numbits + 63) / 64;
622+
interop.data.pVAL =
623+
(uint64_t *)malloc(numdoublewords, sizeof(uint64_t));
624+
memcpy((char *)interop.data.pVAL, data, numbytes);
625+
}
626+
MlirAttribute attr = mlirIntegerAttrFromInterop(type, &interop);
627+
if (interop.numbits <= 64)
628+
free(interop.data.pVAL);
606629
return PyIntegerAttribute(type.getContext(), attr);
607630
},
608631
nb::arg("type"), nb::arg("value"),
@@ -620,11 +643,48 @@ class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
620643
private:
621644
static int64_t toPyInt(PyIntegerAttribute &self) {
622645
MlirType type = mlirAttributeGetType(self);
623-
if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type))
624-
return mlirIntegerAttrGetValueInt(self);
625-
if (mlirIntegerTypeIsSigned(type))
626-
return mlirIntegerAttrGetValueSInt(self);
627-
return mlirIntegerAttrGetValueUInt(self);
646+
apint_interop_t interop;
647+
if (mlirTypeIsAIndex(type))
648+
interop.numbits = 64;
649+
else
650+
interop.numbits = mlirIntegerTypeGetWidth((MlirType)type);
651+
if (interop.numbits > 64) {
652+
size_t required_doublewords = (interop.numbits + 63) / 64;
653+
interop.data.pVAL =
654+
(uint64_t *)malloc(required_doublewords, sizeof(uint64_t));
655+
}
656+
mlirIntegerAttrGetValueInterop(self, &interop);
657+
658+
// Need to sign extend the last byte for conversion to py::bytes
659+
bool Signed = mlirTypeIsAIndex(type) || mlirIntegerTypeIsSigned(type);
660+
if (Signed) {
661+
size_t last_doubleword = (interop.numbits - 1) / 64;
662+
size_t last_bit = interop.numbits - 1 - (64 * last_doubleword);
663+
uint64_t sext_mask = -1 << last_bit;
664+
665+
if (interop.numbits > 64) {
666+
if ((interop.data.pVAL[last_doubleword] >> last_bit) & 1) {
667+
interop.data.pVAL[last_doubleword] |= sext_mask;
668+
}
669+
} else {
670+
if ((interop.data.VAL >> last_bit) & 1) {
671+
interop.data.VAL |= sext_mask;
672+
}
673+
}
674+
}
675+
676+
py::int_ int_obj;
677+
py::object from_bytes = int_obj.attr("from_bytes");
678+
size_t numbytes = (interop.numbits + 7) / 8;
679+
py::bytes bytes_obj;
680+
if (interop.numbits > 64) {
681+
bytes_obj = py::bytes((const char *)interop.data.pVAL, numbytes);
682+
free(interop.data.pVAL);
683+
} else {
684+
bytes_obj = py::bytes((const char *)&interop.data.VAL, numbytes);
685+
}
686+
int_obj = from_bytes(bytes_obj, "little", py::arg("signed") = Signed);
687+
return int_obj;
628688
}
629689
};
630690

mlir/lib/CAPI/IR/BuiltinAttributes.cpp

+34
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,40 @@ uint64_t mlirIntegerAttrGetValueUInt(MlirAttribute attr) {
161161
return llvm::cast<IntegerAttr>(unwrap(attr)).getUInt();
162162
}
163163

164+
int mlirIntegerAttrGetValueInterop(MlirAttribute attr,
165+
apint_interop_t *interop) {
166+
size_t needed_bit_width =
167+
llvm::cast<IntegerAttr>(unwrap(attr)).getValue().getBitWidth();
168+
if (interop->numbits < needed_bit_width) {
169+
interop->numbits = needed_bit_width;
170+
return 1;
171+
}
172+
if (interop->numbits <= 64) {
173+
interop->data.VAL =
174+
llvm::cast<IntegerAttr>(unwrap(attr)).getValue().getRawData()[0];
175+
return 0;
176+
}
177+
int memcpy_bytes = (interop->numbits + 7) / 8;
178+
memcpy((void *)interop->data.pVAL,
179+
(const void *)llvm::cast<IntegerAttr>(unwrap(attr))
180+
.getValue()
181+
.getRawData(),
182+
memcpy_bytes);
183+
return 0;
184+
}
185+
186+
MlirAttribute mlirIntegerAttrFromInterop(MlirType type,
187+
apint_interop_t *interop) {
188+
if (interop->numbits <= 64) {
189+
return wrap(IntegerAttr::get(unwrap(type), interop->data.VAL));
190+
}
191+
APInt apInt(interop->numbits,
192+
llvm::ArrayRef<uint64_t>(interop->data.pVAL,
193+
(interop->numbits + 63) / 64));
194+
IntegerAttr value = IntegerAttr::get(unwrap(type), apInt);
195+
return wrap(value);
196+
}
197+
164198
MlirTypeID mlirIntegerAttrGetTypeID(void) {
165199
return wrap(IntegerAttr::getTypeID());
166200
}

mlir/test/python/ir/attributes.py

+45
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,51 @@ def testIntegerAttr():
239239
print("default_get:", IntegerAttr.get(IntegerType.get_signless(32), 42))
240240

241241

242+
@run
243+
def testLargeIntegerAttr():
244+
with Context() as ctx:
245+
max_positive_64_val = 0x7fffffffffffffff
246+
max_positive_64 = IntegerAttr.get(IntegerType.get_signed(64), max_positive_64_val)
247+
# CHECK: max_positive_64: 9223372036854775807 : si64
248+
print("max_positive_64:", max_positive_64)
249+
assert(int(max_positive_64) == max_positive_64_val)
250+
251+
neg_one_64_val = -1
252+
neg_one_64 = IntegerAttr.get(IntegerType.get_signed(64), neg_one_64_val)
253+
# CHECK: neg_one_64: -1 : si64
254+
print("neg_one_64:", neg_one_64)
255+
assert(int(neg_one_64) == neg_one_64_val)
256+
257+
max_unsigned_64_val = 0xffffffffffffffff
258+
max_unsigned_64 = IntegerAttr.get(IntegerType.get_signless(64), max_unsigned_64_val)
259+
# CHECK: max_unsigned_64: -1 : i64
260+
print("max_unsigned_64:", max_unsigned_64)
261+
assert(int(max_unsigned_64) == max_unsigned_64_val)
262+
263+
random_64_val = 0x0123456789ABCDEF
264+
random_64 = IntegerAttr.get(IntegerType.get_signless(64), random_64_val)
265+
# CHECK: random_64: 81985529216486895 : i64
266+
print("random_64:", random_64)
267+
assert(int(random_64) == random_64_val)
268+
269+
max_unsigned_65_val = 0x1FFFFFFFFFFFFFFFF
270+
max_unsigned_65 = IntegerAttr.get(IntegerType.get_unsigned(65), max_unsigned_65_val)
271+
# CHECK: max_unsigned_65: 36893488147419103231 : ui65
272+
print("max_unsigned_65:", max_unsigned_65)
273+
assert(int(max_unsigned_65) == max_unsigned_65_val)
274+
275+
random_128_val = 0x0123456789ABCDEF0123456789ABCDEF
276+
random_128 = IntegerAttr.get(IntegerType.get_signless(128), random_128_val)
277+
# CHECK: random_128: 1512366075204170929049582354406559215 : i128
278+
print("random_128:", random_128)
279+
assert(int(random_128) == random_128_val)
280+
281+
random_92_val = 0x9ABCDEF0123456789ABCDEF
282+
random_92 = IntegerAttr.get(IntegerType.get_signless(92), random_92_val)
283+
# CHECK: random_92: -1958696259612506469130580497 : i92
284+
print("random_92:", random_92)
285+
assert(int(random_92) == random_92_val)
286+
242287
# CHECK-LABEL: TEST: testBoolAttr
243288
@run
244289
def testBoolAttr():

0 commit comments

Comments
 (0)