Coverage Report

Created: 2025-07-04 06:49

/src/cpython/Objects/unionobject.c
Line
Count
Source (jump to first uncovered line)
1
// typing.Union -- used to represent e.g. Union[int, str], int | str
2
#include "Python.h"
3
#include "pycore_object.h"  // _PyObject_GC_TRACK/UNTRACK
4
#include "pycore_typevarobject.h"  // _PyTypeAlias_Type, _Py_typing_type_repr
5
#include "pycore_unicodeobject.h" // _PyUnicode_EqualToASCIIString
6
#include "pycore_unionobject.h"
7
#include "pycore_weakref.h"       // FT_CLEAR_WEAKREFS()
8
9
10
typedef struct {
11
    PyObject_HEAD
12
    PyObject *args;  // all args (tuple)
13
    PyObject *hashable_args;  // frozenset or NULL
14
    PyObject *unhashable_args;  // tuple or NULL
15
    PyObject *parameters;
16
    PyObject *weakreflist;
17
} unionobject;
18
19
static void
20
unionobject_dealloc(PyObject *self)
21
11
{
22
11
    unionobject *alias = (unionobject *)self;
23
24
11
    _PyObject_GC_UNTRACK(self);
25
11
    FT_CLEAR_WEAKREFS(self, alias->weakreflist);
26
27
11
    Py_XDECREF(alias->args);
28
11
    Py_XDECREF(alias->hashable_args);
29
11
    Py_XDECREF(alias->unhashable_args);
30
11
    Py_XDECREF(alias->parameters);
31
11
    Py_TYPE(self)->tp_free(self);
32
11
}
33
34
static int
35
union_traverse(PyObject *self, visitproc visit, void *arg)
36
81.2k
{
37
81.2k
    unionobject *alias = (unionobject *)self;
38
81.2k
    Py_VISIT(alias->args);
39
81.2k
    Py_VISIT(alias->hashable_args);
40
81.2k
    Py_VISIT(alias->unhashable_args);
41
81.2k
    Py_VISIT(alias->parameters);
42
81.2k
    return 0;
43
81.2k
}
44
45
static Py_hash_t
46
union_hash(PyObject *self)
47
0
{
48
0
    unionobject *alias = (unionobject *)self;
49
    // If there are any unhashable args, treat this union as unhashable.
50
    // Otherwise, two unions might compare equal but have different hashes.
51
0
    if (alias->unhashable_args) {
52
        // Attempt to get an error from one of the values.
53
0
        assert(PyTuple_CheckExact(alias->unhashable_args));
54
0
        Py_ssize_t n = PyTuple_GET_SIZE(alias->unhashable_args);
55
0
        for (Py_ssize_t i = 0; i < n; i++) {
56
0
            PyObject *arg = PyTuple_GET_ITEM(alias->unhashable_args, i);
57
0
            Py_hash_t hash = PyObject_Hash(arg);
58
0
            if (hash == -1) {
59
0
                return -1;
60
0
            }
61
0
        }
62
        // The unhashable values somehow became hashable again. Still raise
63
        // an error.
64
0
        PyErr_Format(PyExc_TypeError, "union contains %d unhashable elements", n);
65
0
        return -1;
66
0
    }
67
0
    return PyObject_Hash(alias->hashable_args);
68
0
}
69
70
static int
71
unions_equal(unionobject *a, unionobject *b)
72
0
{
73
0
    int result = PyObject_RichCompareBool(a->hashable_args, b->hashable_args, Py_EQ);
74
0
    if (result == -1) {
75
0
        return -1;
76
0
    }
77
0
    if (result == 0) {
78
0
        return 0;
79
0
    }
80
0
    if (a->unhashable_args && b->unhashable_args) {
81
0
        Py_ssize_t n = PyTuple_GET_SIZE(a->unhashable_args);
82
0
        if (n != PyTuple_GET_SIZE(b->unhashable_args)) {
83
0
            return 0;
84
0
        }
85
0
        for (Py_ssize_t i = 0; i < n; i++) {
86
0
            PyObject *arg_a = PyTuple_GET_ITEM(a->unhashable_args, i);
87
0
            int result = PySequence_Contains(b->unhashable_args, arg_a);
88
0
            if (result == -1) {
89
0
                return -1;
90
0
            }
91
0
            if (!result) {
92
0
                return 0;
93
0
            }
94
0
        }
95
0
        for (Py_ssize_t i = 0; i < n; i++) {
96
0
            PyObject *arg_b = PyTuple_GET_ITEM(b->unhashable_args, i);
97
0
            int result = PySequence_Contains(a->unhashable_args, arg_b);
98
0
            if (result == -1) {
99
0
                return -1;
100
0
            }
101
0
            if (!result) {
102
0
                return 0;
103
0
            }
104
0
        }
105
0
    }
106
0
    else if (a->unhashable_args || b->unhashable_args) {
107
0
        return 0;
108
0
    }
109
0
    return 1;
110
0
}
111
112
// tp_richcompare slot: only == and != against another union are supported;
// everything else yields NotImplemented.
static PyObject *
union_richcompare(PyObject *a, PyObject *b, int op)
{
    if (!(_PyUnion_Check(b) && (op == Py_EQ || op == Py_NE))) {
        Py_RETURN_NOTIMPLEMENTED;
    }

    int equal = unions_equal((unionobject *)a, (unionobject *)b);
    if (equal == -1) {
        return NULL;
    }
    // For != the equality result is simply inverted.
    return PyBool_FromLong(op == Py_EQ ? equal : !equal);
}
130
131
typedef struct {
132
    PyObject *args;  // list
133
    PyObject *hashable_args;  // set
134
    PyObject *unhashable_args;  // list or NULL
135
    bool is_checked;  // whether to call type_check()
136
} unionbuilder;
137
138
// Forward declarations for helpers that are defined further down.
static bool unionbuilder_add_tuple(unionbuilder *ub, PyObject *tuple);
static PyObject *make_union(unionbuilder *ub);
static PyObject *type_check(PyObject *arg, const char *msg);
141
142
static bool
143
unionbuilder_init(unionbuilder *ub, bool is_checked)
144
421
{
145
421
    ub->args = PyList_New(0);
146
421
    if (ub->args == NULL) {
147
0
        return false;
148
0
    }
149
421
    ub->hashable_args = PySet_New(NULL);
150
421
    if (ub->hashable_args == NULL) {
151
0
        Py_DECREF(ub->args);
152
0
        return false;
153
0
    }
154
421
    ub->unhashable_args = NULL;
155
421
    ub->is_checked = is_checked;
156
421
    return true;
157
421
}
158
159
static void
160
unionbuilder_finalize(unionbuilder *ub)
161
421
{
162
421
    Py_DECREF(ub->args);
163
421
    Py_DECREF(ub->hashable_args);
164
421
    Py_XDECREF(ub->unhashable_args);
165
421
}
166
167
static bool
168
unionbuilder_add_single_unchecked(unionbuilder *ub, PyObject *arg)
169
842
{
170
842
    Py_hash_t hash = PyObject_Hash(arg);
171
842
    if (hash == -1) {
172
0
        PyErr_Clear();
173
0
        if (ub->unhashable_args == NULL) {
174
0
            ub->unhashable_args = PyList_New(0);
175
0
            if (ub->unhashable_args == NULL) {
176
0
                return false;
177
0
            }
178
0
        }
179
0
        else {
180
0
            int contains = PySequence_Contains(ub->unhashable_args, arg);
181
0
            if (contains < 0) {
182
0
                return false;
183
0
            }
184
0
            if (contains == 1) {
185
0
                return true;
186
0
            }
187
0
        }
188
0
        if (PyList_Append(ub->unhashable_args, arg) < 0) {
189
0
            return false;
190
0
        }
191
0
    }
192
842
    else {
193
842
        int contains = PySet_Contains(ub->hashable_args, arg);
194
842
        if (contains < 0) {
195
0
            return false;
196
0
        }
197
842
        if (contains == 1) {
198
0
            return true;
199
0
        }
200
842
        if (PySet_Add(ub->hashable_args, arg) < 0) {
201
0
            return false;
202
0
        }
203
842
    }
204
842
    return PyList_Append(ub->args, arg) == 0;
205
842
}
206
207
static bool
208
unionbuilder_add_single(unionbuilder *ub, PyObject *arg)
209
842
{
210
842
    if (Py_IsNone(arg)) {
211
410
        arg = (PyObject *)&_PyNone_Type;  // immortal, so no refcounting needed
212
410
    }
213
432
    else if (_PyUnion_Check(arg)) {
214
0
        PyObject *args = ((unionobject *)arg)->args;
215
0
        return unionbuilder_add_tuple(ub, args);
216
0
    }
217
842
    if (ub->is_checked) {
218
0
        PyObject *type = type_check(arg, "Union[arg, ...]: each arg must be a type.");
219
0
        if (type == NULL) {
220
0
            return false;
221
0
        }
222
0
        bool result = unionbuilder_add_single_unchecked(ub, type);
223
0
        Py_DECREF(type);
224
0
        return result;
225
0
    }
226
842
    else {
227
842
        return unionbuilder_add_single_unchecked(ub, arg);
228
842
    }
229
842
}
230
231
static bool
232
unionbuilder_add_tuple(unionbuilder *ub, PyObject *tuple)
233
0
{
234
0
    Py_ssize_t n = PyTuple_GET_SIZE(tuple);
235
0
    for (Py_ssize_t i = 0; i < n; i++) {
236
0
        if (!unionbuilder_add_single(ub, PyTuple_GET_ITEM(tuple, i))) {
237
0
            return false;
238
0
        }
239
0
    }
240
0
    return true;
241
0
}
242
243
static int
244
is_unionable(PyObject *obj)
245
842
{
246
842
    if (obj == Py_None ||
247
842
        PyType_Check(obj) ||
248
842
        _PyGenericAlias_Check(obj) ||
249
842
        _PyUnion_Check(obj) ||
250
842
        Py_IS_TYPE(obj, &_PyTypeAlias_Type)) {
251
842
        return 1;
252
842
    }
253
0
    return 0;
254
842
}
255
256
// Implementation of `self | other` for unionable objects.
// Returns NotImplemented when either operand is not unionable, NULL with an
// exception set on error, and the result of make_union() otherwise.
PyObject *
_Py_union_type_or(PyObject* self, PyObject* other)
{
    if (!(is_unionable(self) && is_unionable(other))) {
        Py_RETURN_NOTIMPLEMENTED;
    }

    unionbuilder ub;
    // unchecked because we already checked is_unionable()
    if (!unionbuilder_init(&ub, false)) {
        return NULL;
    }

    bool added = unionbuilder_add_single(&ub, self)
                 && unionbuilder_add_single(&ub, other);
    if (!added) {
        unionbuilder_finalize(&ub);
        return NULL;
    }

    // make_union() finalizes the builder on every path.
    return make_union(&ub);
}
277
278
// tp_repr slot: render the members joined by " | ".
static PyObject *
union_repr(PyObject *self)
{
    unionobject *u = (unionobject *)self;
    Py_ssize_t nargs = PyTuple_GET_SIZE(u->args);

    // Shortest type name "int" (3 chars) + " | " (3 chars) separator
    Py_ssize_t estimate;
    if (nargs <= PY_SSIZE_T_MAX / 6) {
        estimate = nargs * 6;
    }
    else {
        // Multiplying would overflow; fall back to the plain count.
        estimate = nargs;
    }

    PyUnicodeWriter *writer = PyUnicodeWriter_Create(estimate);
    if (writer == NULL) {
        return NULL;
    }

    for (Py_ssize_t idx = 0; idx < nargs; idx++) {
        // Separator goes before every member except the first.
        if (idx > 0 && PyUnicodeWriter_WriteASCII(writer, " | ", 3) < 0) {
            goto error;
        }
        PyObject *member = PyTuple_GET_ITEM(u->args, idx);
        if (_Py_typing_type_repr(writer, member) < 0) {
            goto error;
        }
    }

#if 0
    PyUnicodeWriter_WriteASCII(writer, "|args=", 6);
    PyUnicodeWriter_WriteRepr(writer, alias->args);
    PyUnicodeWriter_WriteASCII(writer, "|h=", 3);
    PyUnicodeWriter_WriteRepr(writer, alias->hashable_args);
    if (alias->unhashable_args) {
        PyUnicodeWriter_WriteASCII(writer, "|u=", 3);
        PyUnicodeWriter_WriteRepr(writer, alias->unhashable_args);
    }
#endif

    return PyUnicodeWriter_Finish(writer);

error:
    PyUnicodeWriter_Discard(writer);
    return NULL;
}
318
319
static PyMemberDef union_members[] = {
320
        {"__args__", _Py_T_OBJECT, offsetof(unionobject, args), Py_READONLY},
321
        {0}
322
};
323
324
// Populate __parameters__ if needed.
325
static int
326
union_init_parameters(unionobject *alias)
327
0
{
328
0
    int result = 0;
329
0
    Py_BEGIN_CRITICAL_SECTION(alias);
330
0
    if (alias->parameters == NULL) {
331
0
        alias->parameters = _Py_make_parameters(alias->args);
332
0
        if (alias->parameters == NULL) {
333
0
            result = -1;
334
0
        }
335
0
    }
336
0
    Py_END_CRITICAL_SECTION();
337
0
    return result;
338
0
}
339
340
// mp_subscript slot: substitute `item` for the union's parameters and build
// a new union from the substituted args.
static PyObject *
union_getitem(PyObject *self, PyObject *item)
{
    unionobject *u = (unionobject *)self;

    // __parameters__ is computed lazily; make sure it exists first.
    if (union_init_parameters(u) < 0) {
        return NULL;
    }

    PyObject *substituted = _Py_subs_parameters(self, u->args, u->parameters, item);
    if (substituted == NULL) {
        return NULL;
    }

    PyObject *new_union = _Py_union_from_tuple(substituted);
    Py_DECREF(substituted);
    return new_union;
}
357
358
static PyMappingMethods union_as_mapping = {
359
    .mp_subscript = union_getitem,
360
};
361
362
// Getter for __parameters__: computes the value on first access and returns
// a new reference to it.
static PyObject *
union_parameters(PyObject *self, void *Py_UNUSED(unused))
{
    unionobject *u = (unionobject *)self;
    if (union_init_parameters(u) < 0) {
        return NULL;
    }
    return Py_NewRef(u->parameters);
}
371
372
// Getter shared by __name__ and __qualname__: always the string "Union".
static PyObject *
union_name(PyObject *Py_UNUSED(self), void *Py_UNUSED(ignored))
{
    PyObject *name = PyUnicode_FromString("Union");
    return name;
}
377
378
// Getter for __origin__: a new reference to the union type itself.
static PyObject *
union_origin(PyObject *Py_UNUSED(self), void *Py_UNUSED(ignored))
{
    PyObject *origin = Py_NewRef(&_PyUnion_Type);
    return origin;
}
383
384
static PyGetSetDef union_properties[] = {
385
    {"__name__", union_name, NULL,
386
     PyDoc_STR("Name of the type"), NULL},
387
    {"__qualname__", union_name, NULL,
388
     PyDoc_STR("Qualified name of the type"), NULL},
389
    {"__origin__", union_origin, NULL,
390
     PyDoc_STR("Always returns the type"), NULL},
391
    {"__parameters__", union_parameters, NULL,
392
     PyDoc_STR("Type variables in the types.UnionType."), NULL},
393
    {0}
394
};
395
396
static PyNumberMethods union_as_number = {
397
        .nb_or = _Py_union_type_or, // Add __or__ function
398
};
399
400
static const char* const cls_attrs[] = {
401
        "__module__",  // Required for compatibility with typing module
402
        NULL,
403
};
404
405
// tp_getattro slot: names listed in cls_attrs are fetched from the type;
// everything else goes through the generic attribute lookup.
static PyObject *
union_getattro(PyObject *self, PyObject *name)
{
    unionobject *u = (unionobject *)self;
    if (PyUnicode_Check(name)) {
        for (const char * const *attr = cls_attrs; *attr != NULL; attr++) {
            if (_PyUnicode_EqualToASCIIString(name, *attr)) {
                return PyObject_GetAttr((PyObject *) Py_TYPE(u), name);
            }
        }
    }
    return PyObject_GenericGetAttr(self, name);
}
421
422
// Return the args tuple of a union object. No reference is added, so the
// result is borrowed from `self`.
PyObject *
_Py_union_args(PyObject *self)
{
    assert(_PyUnion_Check(self));
    unionobject *u = (unionobject *) self;
    return u->args;
}
428
429
static PyObject *
430
call_typing_func_object(const char *name, PyObject **args, size_t nargs)
431
0
{
432
0
    PyObject *typing = PyImport_ImportModule("typing");
433
0
    if (typing == NULL) {
434
0
        return NULL;
435
0
    }
436
0
    PyObject *func = PyObject_GetAttrString(typing, name);
437
0
    if (func == NULL) {
438
0
        Py_DECREF(typing);
439
0
        return NULL;
440
0
    }
441
0
    PyObject *result = PyObject_Vectorcall(func, args, nargs, NULL);
442
0
    Py_DECREF(func);
443
0
    Py_DECREF(typing);
444
0
    return result;
445
0
}
446
447
// Validate `arg` as a union member.
//
// None maps to its type and unionable objects are returned as-is (new
// reference); anything else is passed to typing._type_check(arg, msg).
// Returns NULL with an exception set on failure.
static PyObject *
type_check(PyObject *arg, const char *msg)
{
    if (Py_IsNone(arg)) {
        // NoneType is immortal, so don't need an INCREF
        return (PyObject *)Py_TYPE(arg);
    }

    // Fast path to avoid calling into typing.py
    if (is_unionable(arg)) {
        return Py_NewRef(arg);
    }

    PyObject *msg_obj = PyUnicode_FromString(msg);
    if (msg_obj == NULL) {
        return NULL;
    }
    PyObject *call_args[2] = {arg, msg_obj};
    PyObject *checked = call_typing_func_object("_type_check", call_args, 2);
    Py_DECREF(msg_obj);
    return checked;
}
467
468
// Build a union from `args` with type checking enabled.
//
// `args` is either an exact tuple, whose items become the members, or a
// single object that becomes the only member. Returns the result of
// make_union(), or NULL with an exception set on failure.
PyObject *
_Py_union_from_tuple(PyObject *args)
{
    unionbuilder ub;
    if (!unionbuilder_init(&ub, true)) {
        return NULL;
    }

    bool ok;
    if (PyTuple_CheckExact(args)) {
        ok = unionbuilder_add_tuple(&ub, args);
    }
    else {
        ok = unionbuilder_add_single(&ub, args);
    }
    if (!ok) {
        // The builder owns its list/set once initialized; release them
        // here, since make_union() (which would otherwise do so) is not
        // reached on this path.
        unionbuilder_finalize(&ub);
        return NULL;
    }

    // make_union() finalizes the builder on every path.
    return make_union(&ub);
}
487
488
// __class_getitem__: Union[args] builds a union from `args`.
static PyObject *
union_class_getitem(PyObject *cls, PyObject *args)
{
    PyObject *new_union = _Py_union_from_tuple(args);
    return new_union;
}
493
494
// __mro_entries__: always raises TypeError, so a union instance cannot be
// used as a base class.
static PyObject *
union_mro_entries(PyObject *self, PyObject *args)
{
    return PyErr_Format(PyExc_TypeError, "Cannot subclass %R", self);
}
500
501
static PyMethodDef union_methods[] = {
502
    {"__mro_entries__", union_mro_entries, METH_O},
503
    {"__class_getitem__", union_class_getitem, METH_O|METH_CLASS, PyDoc_STR("See PEP 585")},
504
    {0}
505
};
506
507
PyTypeObject _PyUnion_Type = {
508
    PyVarObject_HEAD_INIT(&PyType_Type, 0)
509
    .tp_name = "typing.Union",
510
    .tp_doc = PyDoc_STR("Represent a union type\n"
511
              "\n"
512
              "E.g. for int | str"),
513
    .tp_basicsize = sizeof(unionobject),
514
    .tp_dealloc = unionobject_dealloc,
515
    .tp_alloc = PyType_GenericAlloc,
516
    .tp_free = PyObject_GC_Del,
517
    .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC,
518
    .tp_traverse = union_traverse,
519
    .tp_hash = union_hash,
520
    .tp_getattro = union_getattro,
521
    .tp_members = union_members,
522
    .tp_methods = union_methods,
523
    .tp_richcompare = union_richcompare,
524
    .tp_as_mapping = &union_as_mapping,
525
    .tp_as_number = &union_as_number,
526
    .tp_repr = union_repr,
527
    .tp_getset = union_properties,
528
    .tp_weaklistoffset = offsetof(unionobject, weakreflist),
529
};
530
531
static PyObject *
532
make_union(unionbuilder *ub)
533
421
{
534
421
    Py_ssize_t n = PyList_GET_SIZE(ub->args);
535
421
    if (n == 0) {
536
0
        PyErr_SetString(PyExc_TypeError, "Cannot take a Union of no types.");
537
0
        unionbuilder_finalize(ub);
538
0
        return NULL;
539
0
    }
540
421
    if (n == 1) {
541
0
        PyObject *result = PyList_GET_ITEM(ub->args, 0);
542
0
        Py_INCREF(result);
543
0
        unionbuilder_finalize(ub);
544
0
        return result;
545
0
    }
546
547
421
    PyObject *args = NULL, *hashable_args = NULL, *unhashable_args = NULL;
548
421
    args = PyList_AsTuple(ub->args);
549
421
    if (args == NULL) {
550
0
        goto error;
551
0
    }
552
421
    hashable_args = PyFrozenSet_New(ub->hashable_args);
553
421
    if (hashable_args == NULL) {
554
0
        goto error;
555
0
    }
556
421
    if (ub->unhashable_args != NULL) {
557
0
        unhashable_args = PyList_AsTuple(ub->unhashable_args);
558
0
        if (unhashable_args == NULL) {
559
0
            goto error;
560
0
        }
561
0
    }
562
563
421
    unionobject *result = PyObject_GC_New(unionobject, &_PyUnion_Type);
564
421
    if (result == NULL) {
565
0
        goto error;
566
0
    }
567
421
    unionbuilder_finalize(ub);
568
569
421
    result->parameters = NULL;
570
421
    result->args = args;
571
421
    result->hashable_args = hashable_args;
572
421
    result->unhashable_args = unhashable_args;
573
421
    result->weakreflist = NULL;
574
421
    _PyObject_GC_TRACK(result);
575
421
    return (PyObject*)result;
576
0
error:
577
0
    Py_XDECREF(args);
578
0
    Py_XDECREF(hashable_args);
579
0
    Py_XDECREF(unhashable_args);
580
0
    unionbuilder_finalize(ub);
581
0
    return NULL;
582
421
}