Skip to content

Commit 2c45148

Browse files
neoneneerlend-aaslandencukou
authored
gh-117578: Introduce _PyType_GetModuleByDef2 private function (GH-117661)
Co-authored-by: Erlend E. Aasland <erlend.aasland@protonmail.com> Co-authored-by: Petr Viktorin <encukou@gmail.com>
1 parent f180b31 commit 2c45148

File tree

3 files changed

+49
-12
lines changed

3 files changed

+49
-12
lines changed

Include/internal/pycore_typeobject.h

+1
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ extern PyObject * _PyType_GetBases(PyTypeObject *type);
164164
extern PyObject * _PyType_GetMRO(PyTypeObject *type);
165165
extern PyObject* _PyType_GetSubclasses(PyTypeObject *);
166166
extern int _PyType_HasSubclasses(PyTypeObject *);
167+
PyAPI_FUNC(PyObject *) _PyType_GetModuleByDef2(PyTypeObject *, PyTypeObject *, PyModuleDef *);
167168

168169
// PyType_Ready() must be called if _PyType_IsReady() is false.
169170
// See also the Py_TPFLAGS_READY flag.

Modules/_decimal/_decimal.c

+3-5
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#include <Python.h>
3333
#include "pycore_long.h" // _PyLong_IsZero()
3434
#include "pycore_pystate.h" // _PyThreadState_GET()
35+
#include "pycore_typeobject.h"
3536
#include "complexobject.h"
3637
#include "mpdecimal.h"
3738

@@ -120,11 +121,8 @@ get_module_state_by_def(PyTypeObject *tp)
120121
static inline decimal_state *
121122
find_state_left_or_right(PyObject *left, PyObject *right)
122123
{
123-
PyObject *mod = PyType_GetModuleByDef(Py_TYPE(left), &_decimal_module);
124-
if (mod == NULL) {
125-
PyErr_Clear();
126-
mod = PyType_GetModuleByDef(Py_TYPE(right), &_decimal_module);
127-
}
124+
PyObject *mod = _PyType_GetModuleByDef2(Py_TYPE(left), Py_TYPE(right),
125+
&_decimal_module);
128126
assert(mod != NULL);
129127
return get_module_state(mod);
130128
}

Objects/typeobject.c

+45-7
Original file line numberDiff line numberDiff line change
@@ -4825,24 +4825,39 @@ PyType_GetModuleState(PyTypeObject *type)
48254825
/* Get the module of the first superclass where the module has the
48264826
* given PyModuleDef.
48274827
*/
4828-
PyObject *
4829-
PyType_GetModuleByDef(PyTypeObject *type, PyModuleDef *def)
4828+
static inline PyObject *
4829+
get_module_by_def(PyTypeObject *type, PyModuleDef *def)
48304830
{
48314831
assert(PyType_Check(type));
48324832

4833+
if (!_PyType_HasFeature(type, Py_TPFLAGS_HEAPTYPE)) {
4834+
// type_ready_mro() ensures that no heap type is
4835+
// contained in a static type MRO.
4836+
return NULL;
4837+
}
4838+
else {
4839+
PyHeapTypeObject *ht = (PyHeapTypeObject*)type;
4840+
PyObject *module = ht->ht_module;
4841+
if (module && _PyModule_GetDef(module) == def) {
4842+
return module;
4843+
}
4844+
}
4845+
48334846
PyObject *res = NULL;
48344847
BEGIN_TYPE_LOCK()
48354848

48364849
PyObject *mro = lookup_tp_mro(type);
48374850
// The type must be ready
48384851
assert(mro != NULL);
48394852
assert(PyTuple_Check(mro));
4840-
// mro_invoke() ensures that the type MRO cannot be empty, so we don't have
4841-
// to check i < PyTuple_GET_SIZE(mro) at the first loop iteration.
4853+
// mro_invoke() ensures that the type MRO cannot be empty.
48424854
assert(PyTuple_GET_SIZE(mro) >= 1);
4855+
// Also, the first item in the MRO is the type itself, which
4856+
// we already checked above. We skip it in the loop.
4857+
assert(PyTuple_GET_ITEM(mro, 0) == (PyObject *)type);
48434858

48444859
Py_ssize_t n = PyTuple_GET_SIZE(mro);
4845-
for (Py_ssize_t i = 0; i < n; i++) {
4860+
for (Py_ssize_t i = 1; i < n; i++) {
48464861
PyObject *super = PyTuple_GET_ITEM(mro, i);
48474862
if(!_PyType_HasFeature((PyTypeObject *)super, Py_TPFLAGS_HEAPTYPE)) {
48484863
// Static types in the MRO need to be skipped
@@ -4857,14 +4872,37 @@ PyType_GetModuleByDef(PyTypeObject *type, PyModuleDef *def)
48574872
}
48584873
}
48594874
END_TYPE_LOCK()
4875+
return res;
4876+
}
48604877

4861-
if (res == NULL) {
4878+
PyObject *
4879+
PyType_GetModuleByDef(PyTypeObject *type, PyModuleDef *def)
4880+
{
4881+
PyObject *module = get_module_by_def(type, def);
4882+
if (module == NULL) {
48624883
PyErr_Format(
48634884
PyExc_TypeError,
48644885
"PyType_GetModuleByDef: No superclass of '%s' has the given module",
48654886
type->tp_name);
48664887
}
4867-
return res;
4888+
return module;
4889+
}
4890+
4891+
PyObject *
4892+
_PyType_GetModuleByDef2(PyTypeObject *left, PyTypeObject *right,
4893+
PyModuleDef *def)
4894+
{
4895+
PyObject *module = get_module_by_def(left, def);
4896+
if (module == NULL) {
4897+
module = get_module_by_def(right, def);
4898+
if (module == NULL) {
4899+
PyErr_Format(
4900+
PyExc_TypeError,
4901+
"PyType_GetModuleByDef: No superclass of '%s' nor '%s' has "
4902+
"the given module", left->tp_name, right->tp_name);
4903+
}
4904+
}
4905+
return module;
48684906
}
48694907

48704908
void *

0 commit comments

Comments
 (0)