@@ -601,8 +601,31 @@ class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
601
601
static void bindDerived (ClassTy &c) {
602
602
c.def_static (
603
603
" get" ,
604
- [](PyType &type, int64_t value) {
605
- MlirAttribute attr = mlirIntegerAttrGet (type, value);
604
+ [](PyType &type, py::int_ value) {
605
+ apint_interop_t interop;
606
+ if (mlirTypeIsAIndex (type))
607
+ interop.numbits = 64 ;
608
+ else
609
+ interop.numbits = mlirIntegerTypeGetWidth ((MlirType)type);
610
+
611
+ py::object to_bytes = value.attr (" to_bytes" );
612
+ int numbytes = (interop.numbits + 7 ) / 8 ;
613
+ bool Signed = mlirTypeIsAIndex (type) || mlirIntegerTypeIsSigned (type);
614
+ py::bytes bytes_obj =
615
+ to_bytes (numbytes, " little" , py::arg (" signed" ) = Signed);
616
+ const char *data = bytes_obj.data ();
617
+
618
+ if (interop.numbits <= 64 ) {
619
+ memcpy ((char *)&(interop.data .VAL ), data, numbytes);
620
+ } else {
621
+ int numdoublewords = (interop.numbits + 63 ) / 64 ;
622
+ interop.data .pVAL =
623
+ (uint64_t *)malloc (numdoublewords, sizeof (uint64_t ));
624
+ memcpy ((char *)interop.data .pVAL , data, numbytes);
625
+ }
626
+ MlirAttribute attr = mlirIntegerAttrFromInterop (type, &interop);
627
+ if (interop.numbits <= 64 )
628
+ free (interop.data .pVAL );
606
629
return PyIntegerAttribute (type.getContext (), attr);
607
630
},
608
631
nb::arg (" type" ), nb::arg (" value" ),
@@ -620,11 +643,48 @@ class PyIntegerAttribute : public PyConcreteAttribute<PyIntegerAttribute> {
620
643
private:
621
644
static int64_t toPyInt (PyIntegerAttribute &self) {
622
645
MlirType type = mlirAttributeGetType (self);
623
- if (mlirTypeIsAIndex (type) || mlirIntegerTypeIsSignless (type))
624
- return mlirIntegerAttrGetValueInt (self);
625
- if (mlirIntegerTypeIsSigned (type))
626
- return mlirIntegerAttrGetValueSInt (self);
627
- return mlirIntegerAttrGetValueUInt (self);
646
+ apint_interop_t interop;
647
+ if (mlirTypeIsAIndex (type))
648
+ interop.numbits = 64 ;
649
+ else
650
+ interop.numbits = mlirIntegerTypeGetWidth ((MlirType)type);
651
+ if (interop.numbits > 64 ) {
652
+ size_t required_doublewords = (interop.numbits + 63 ) / 64 ;
653
+ interop.data .pVAL =
654
+ (uint64_t *)malloc (required_doublewords, sizeof (uint64_t ));
655
+ }
656
+ mlirIntegerAttrGetValueInterop (self, &interop);
657
+
658
+ // Need to sign extend the last byte for conversion to py::bytes
659
+ bool Signed = mlirTypeIsAIndex (type) || mlirIntegerTypeIsSigned (type);
660
+ if (Signed) {
661
+ size_t last_doubleword = (interop.numbits - 1 ) / 64 ;
662
+ size_t last_bit = interop.numbits - 1 - (64 * last_doubleword);
663
+ uint64_t sext_mask = -1 << last_bit;
664
+
665
+ if (interop.numbits > 64 ) {
666
+ if ((interop.data .pVAL [last_doubleword] >> last_bit) & 1 ) {
667
+ interop.data .pVAL [last_doubleword] |= sext_mask;
668
+ }
669
+ } else {
670
+ if ((interop.data .VAL >> last_bit) & 1 ) {
671
+ interop.data .VAL |= sext_mask;
672
+ }
673
+ }
674
+ }
675
+
676
+ py::int_ int_obj;
677
+ py::object from_bytes = int_obj.attr (" from_bytes" );
678
+ size_t numbytes = (interop.numbits + 7 ) / 8 ;
679
+ py::bytes bytes_obj;
680
+ if (interop.numbits > 64 ) {
681
+ bytes_obj = py::bytes ((const char *)interop.data .pVAL , numbytes);
682
+ free (interop.data .pVAL );
683
+ } else {
684
+ bytes_obj = py::bytes ((const char *)&interop.data .VAL , numbytes);
685
+ }
686
+ int_obj = from_bytes (bytes_obj, " little" , py::arg (" signed" ) = Signed);
687
+ return int_obj;
628
688
}
629
689
};
630
690
0 commit comments