@@ -137,93 +137,80 @@ union_richcompare(PyObject *a, PyObject *b, int op)
137
137
return result ;
138
138
}
139
139
140
- static PyObject *
141
- flatten_args (PyObject * args )
140
+ static int
141
+ is_same (PyObject * left , PyObject * right )
142
+ {
143
+ int is_ga = _PyGenericAlias_Check (left ) && _PyGenericAlias_Check (right );
144
+ return is_ga ? PyObject_RichCompareBool (left , right , Py_EQ ) : left == right ;
145
+ }
146
+
147
+ static int
148
+ contains (PyObject * * items , Py_ssize_t size , PyObject * obj )
142
149
{
143
- Py_ssize_t arg_length = PyTuple_GET_SIZE (args );
144
- Py_ssize_t total_args = 0 ;
145
- // Get number of total args once it's flattened.
146
- for (Py_ssize_t i = 0 ; i < arg_length ; i ++ ) {
147
- PyObject * arg = PyTuple_GET_ITEM (args , i );
148
- if (_PyUnion_Check (arg )) {
149
- total_args += PyTuple_GET_SIZE (((unionobject * ) arg )-> args );
150
- } else {
151
- total_args ++ ;
150
+ for (int i = 0 ; i < size ; i ++ ) {
151
+ int is_duplicate = is_same (items [i ], obj );
152
+ if (is_duplicate ) { // -1 or 1
153
+ return is_duplicate ;
152
154
}
153
155
}
154
- // Create new tuple of flattened args.
155
- PyObject * flattened_args = PyTuple_New (total_args );
156
- if (flattened_args == NULL ) {
157
- return NULL ;
158
- }
156
+ return 0 ;
157
+ }
158
+
159
+ static PyObject *
160
+ merge (PyObject * * items1 , Py_ssize_t size1 ,
161
+ PyObject * * items2 , Py_ssize_t size2 )
162
+ {
163
+ PyObject * tuple = NULL ;
159
164
Py_ssize_t pos = 0 ;
160
- for (Py_ssize_t i = 0 ; i < arg_length ; i ++ ) {
161
- PyObject * arg = PyTuple_GET_ITEM (args , i );
162
- if (_PyUnion_Check (arg )) {
163
- PyObject * nested_args = ((unionobject * )arg )-> args ;
164
- Py_ssize_t nested_arg_length = PyTuple_GET_SIZE (nested_args );
165
- for (Py_ssize_t j = 0 ; j < nested_arg_length ; j ++ ) {
166
- PyObject * nested_arg = PyTuple_GET_ITEM (nested_args , j );
167
- Py_INCREF (nested_arg );
168
- PyTuple_SET_ITEM (flattened_args , pos , nested_arg );
169
- pos ++ ;
165
+
166
+ for (int i = 0 ; i < size2 ; i ++ ) {
167
+ PyObject * arg = items2 [i ];
168
+ int is_duplicate = contains (items1 , size1 , arg );
169
+ if (is_duplicate < 0 ) {
170
+ Py_XDECREF (tuple );
171
+ return NULL ;
172
+ }
173
+ if (is_duplicate ) {
174
+ continue ;
175
+ }
176
+
177
+ if (tuple == NULL ) {
178
+ tuple = PyTuple_New (size1 + size2 - i );
179
+ if (tuple == NULL ) {
180
+ return NULL ;
170
181
}
171
- } else {
172
- if (arg == Py_None ) {
173
- arg = (PyObject * )& _PyNone_Type ;
182
+ for (; pos < size1 ; pos ++ ) {
183
+ PyObject * a = items1 [pos ];
184
+ Py_INCREF (a );
185
+ PyTuple_SET_ITEM (tuple , pos , a );
174
186
}
175
- Py_INCREF (arg );
176
- PyTuple_SET_ITEM (flattened_args , pos , arg );
177
- pos ++ ;
178
187
}
188
+ Py_INCREF (arg );
189
+ PyTuple_SET_ITEM (tuple , pos , arg );
190
+ pos ++ ;
191
+ }
192
+
193
+ if (tuple ) {
194
+ (void ) _PyTuple_Resize (& tuple , pos );
179
195
}
180
- assert (pos == total_args );
181
- return flattened_args ;
196
+ return tuple ;
182
197
}
183
198
184
- static PyObject *
185
- dedup_and_flatten_args (PyObject * args )
199
+ static PyObject * *
200
+ get_types (PyObject * * obj , Py_ssize_t * size )
186
201
{
187
- args = flatten_args (args );
188
- if (args == NULL ) {
189
- return NULL ;
202
+ if (* obj == Py_None ) {
203
+ * obj = (PyObject * )& _PyNone_Type ;
190
204
}
191
- Py_ssize_t arg_length = PyTuple_GET_SIZE (args );
192
- PyObject * new_args = PyTuple_New (arg_length );
193
- if (new_args == NULL ) {
194
- Py_DECREF (args );
195
- return NULL ;
205
+ if (_PyUnion_Check (* obj )) {
206
+ PyObject * args = ((unionobject * ) * obj )-> args ;
207
+ * size = PyTuple_GET_SIZE (args );
208
+ return & PyTuple_GET_ITEM (args , 0 );
196
209
}
197
- // Add unique elements to an array.
198
- Py_ssize_t added_items = 0 ;
199
- for (Py_ssize_t i = 0 ; i < arg_length ; i ++ ) {
200
- int is_duplicate = 0 ;
201
- PyObject * i_element = PyTuple_GET_ITEM (args , i );
202
- for (Py_ssize_t j = 0 ; j < added_items ; j ++ ) {
203
- PyObject * j_element = PyTuple_GET_ITEM (new_args , j );
204
- int is_ga = _PyGenericAlias_Check (i_element ) &&
205
- _PyGenericAlias_Check (j_element );
206
- // RichCompare to also deduplicate GenericAlias types (slower)
207
- is_duplicate = is_ga ? PyObject_RichCompareBool (i_element , j_element , Py_EQ )
208
- : i_element == j_element ;
209
- // Should only happen if RichCompare fails
210
- if (is_duplicate < 0 ) {
211
- Py_DECREF (args );
212
- Py_DECREF (new_args );
213
- return NULL ;
214
- }
215
- if (is_duplicate )
216
- break ;
217
- }
218
- if (!is_duplicate ) {
219
- Py_INCREF (i_element );
220
- PyTuple_SET_ITEM (new_args , added_items , i_element );
221
- added_items ++ ;
222
- }
210
+ else {
211
+ * size = 1 ;
212
+ return obj ;
223
213
}
224
- Py_DECREF (args );
225
- _PyTuple_Resize (& new_args , added_items );
226
- return new_args ;
227
214
}
228
215
229
216
static int
@@ -242,9 +229,16 @@ _Py_union_type_or(PyObject* self, PyObject* other)
242
229
Py_RETURN_NOTIMPLEMENTED ;
243
230
}
244
231
245
- PyObject * tuple = PyTuple_Pack (2 , self , other );
232
+ Py_ssize_t size1 , size2 ;
233
+ PyObject * * items1 = get_types (& self , & size1 );
234
+ PyObject * * items2 = get_types (& other , & size2 );
235
+ PyObject * tuple = merge (items1 , size1 , items2 , size2 );
246
236
if (tuple == NULL ) {
247
- return NULL ;
237
+ if (PyErr_Occurred ()) {
238
+ return NULL ;
239
+ }
240
+ Py_INCREF (self );
241
+ return self ;
248
242
}
249
243
250
244
PyObject * new_union = make_union (tuple );
@@ -468,23 +462,12 @@ make_union(PyObject *args)
468
462
{
469
463
assert (PyTuple_CheckExact (args ));
470
464
471
- args = dedup_and_flatten_args (args );
472
- if (args == NULL ) {
473
- return NULL ;
474
- }
475
- if (PyTuple_GET_SIZE (args ) == 1 ) {
476
- PyObject * result1 = PyTuple_GET_ITEM (args , 0 );
477
- Py_INCREF (result1 );
478
- Py_DECREF (args );
479
- return result1 ;
480
- }
481
-
482
465
unionobject * result = PyObject_GC_New (unionobject , & _PyUnion_Type );
483
466
if (result == NULL ) {
484
- Py_DECREF (args );
485
467
return NULL ;
486
468
}
487
469
470
+ Py_INCREF (args );
488
471
result -> parameters = NULL ;
489
472
result -> args = args ;
490
473
_PyObject_GC_TRACK (result );
0 commit comments