Coverage Report

Created: 2026-05-16 06:46

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/cpython/Objects/unionobject.c
Line
Count
Source
1
// typing.Union -- used to represent e.g. Union[int, str], int | str
2
#include "Python.h"
3
#include "pycore_object.h"  // _PyObject_GC_TRACK/UNTRACK
4
#include "pycore_typevarobject.h"  // _PyTypeAlias_Type, _Py_typing_type_repr
5
#include "pycore_unicodeobject.h" // _PyUnicode_EqualToASCIIString
6
#include "pycore_unionobject.h"
7
#include "pycore_weakref.h"       // FT_CLEAR_WEAKREFS()
8
9
10
typedef struct {
11
    PyObject_HEAD
12
    PyObject *args;  // all args (tuple)
13
    PyObject *hashable_args;  // frozenset or NULL
14
    PyObject *unhashable_args;  // tuple or NULL
15
    PyObject *parameters;
16
    PyObject *weakreflist;
17
} unionobject;
18
19
static void
20
unionobject_dealloc(PyObject *self)
21
168
{
22
168
    unionobject *alias = (unionobject *)self;
23
24
168
    _PyObject_GC_UNTRACK(self);
25
168
    FT_CLEAR_WEAKREFS(self, alias->weakreflist);
26
27
168
    Py_XDECREF(alias->args);
28
168
    Py_XDECREF(alias->hashable_args);
29
168
    Py_XDECREF(alias->unhashable_args);
30
168
    Py_XDECREF(alias->parameters);
31
168
    Py_TYPE(self)->tp_free(self);
32
168
}
33
34
static int
35
union_traverse(PyObject *self, visitproc visit, void *arg)
36
16.1k
{
37
16.1k
    unionobject *alias = (unionobject *)self;
38
16.1k
    Py_VISIT(alias->args);
39
16.1k
    Py_VISIT(alias->hashable_args);
40
16.1k
    Py_VISIT(alias->unhashable_args);
41
16.1k
    Py_VISIT(alias->parameters);
42
16.1k
    return 0;
43
16.1k
}
44
45
static Py_hash_t
46
union_hash(PyObject *self)
47
48
{
48
48
    unionobject *alias = (unionobject *)self;
49
    // If there are any unhashable args, treat this union as unhashable.
50
    // Otherwise, two unions might compare equal but have different hashes.
51
48
    if (alias->unhashable_args) {
52
        // Attempt to get an error from one of the values.
53
0
        assert(PyTuple_CheckExact(alias->unhashable_args));
54
0
        Py_ssize_t n = PyTuple_GET_SIZE(alias->unhashable_args);
55
0
        for (Py_ssize_t i = 0; i < n; i++) {
56
0
            PyObject *arg = PyTuple_GET_ITEM(alias->unhashable_args, i);
57
0
            Py_hash_t hash = PyObject_Hash(arg);
58
0
            if (hash == -1) {
59
0
                return -1;
60
0
            }
61
0
        }
62
        // The unhashable values somehow became hashable again. Still raise
63
        // an error.
64
0
        PyErr_Format(PyExc_TypeError, "union contains %zd unhashable elements", n);
65
0
        return -1;
66
0
    }
67
48
    return PyObject_Hash(alias->hashable_args);
68
48
}
69
70
static int
71
unions_equal(unionobject *a, unionobject *b)
72
4
{
73
4
    int result = PyObject_RichCompareBool(a->hashable_args, b->hashable_args, Py_EQ);
74
4
    if (result == -1) {
75
0
        return -1;
76
0
    }
77
4
    if (result == 0) {
78
0
        return 0;
79
0
    }
80
4
    if (a->unhashable_args && b->unhashable_args) {
81
0
        Py_ssize_t n = PyTuple_GET_SIZE(a->unhashable_args);
82
0
        if (n != PyTuple_GET_SIZE(b->unhashable_args)) {
83
0
            return 0;
84
0
        }
85
0
        for (Py_ssize_t i = 0; i < n; i++) {
86
0
            PyObject *arg_a = PyTuple_GET_ITEM(a->unhashable_args, i);
87
0
            int result = PySequence_Contains(b->unhashable_args, arg_a);
88
0
            if (result == -1) {
89
0
                return -1;
90
0
            }
91
0
            if (!result) {
92
0
                return 0;
93
0
            }
94
0
        }
95
0
        for (Py_ssize_t i = 0; i < n; i++) {
96
0
            PyObject *arg_b = PyTuple_GET_ITEM(b->unhashable_args, i);
97
0
            int result = PySequence_Contains(a->unhashable_args, arg_b);
98
0
            if (result == -1) {
99
0
                return -1;
100
0
            }
101
0
            if (!result) {
102
0
                return 0;
103
0
            }
104
0
        }
105
0
    }
106
4
    else if (a->unhashable_args || b->unhashable_args) {
107
0
        return 0;
108
0
    }
109
4
    return 1;
110
4
}
111
112
static PyObject *
union_richcompare(PyObject *a, PyObject *b, int op)
{
    // Only == and != against another union are supported.
    bool is_eq_or_ne = (op == Py_EQ || op == Py_NE);
    if (!_PyUnion_Check(b) || !is_eq_or_ne) {
        Py_RETURN_NOTIMPLEMENTED;
    }

    int eq = unions_equal((unionobject *)a, (unionobject *)b);
    if (eq == -1) {
        return NULL;
    }
    return PyBool_FromLong(op == Py_EQ ? eq : !eq);
}
130
131
typedef struct {
132
    PyObject *args;  // list
133
    PyObject *hashable_args;  // set
134
    PyObject *unhashable_args;  // list or NULL
135
    bool is_checked;  // whether to call type_check()
136
} unionbuilder;
137
138
// Forward declarations for helpers defined later in this file.
static bool unionbuilder_add_tuple(unionbuilder *ub, PyObject *tuple);
static PyObject *make_union(unionbuilder *ub);
static PyObject *type_check(PyObject *arg, const char *msg);
141
142
static bool
143
unionbuilder_init(unionbuilder *ub, bool is_checked)
144
1.42k
{
145
1.42k
    ub->args = PyList_New(0);
146
1.42k
    if (ub->args == NULL) {
147
0
        return false;
148
0
    }
149
1.42k
    ub->hashable_args = PySet_New(NULL);
150
1.42k
    if (ub->hashable_args == NULL) {
151
0
        Py_DECREF(ub->args);
152
0
        return false;
153
0
    }
154
1.42k
    ub->unhashable_args = NULL;
155
1.42k
    ub->is_checked = is_checked;
156
1.42k
    return true;
157
1.42k
}
158
159
static void
160
unionbuilder_finalize(unionbuilder *ub)
161
1.42k
{
162
1.42k
    Py_DECREF(ub->args);
163
1.42k
    Py_DECREF(ub->hashable_args);
164
1.42k
    Py_XDECREF(ub->unhashable_args);
165
1.42k
}
166
167
static bool
168
unionbuilder_add_single_unchecked(unionbuilder *ub, PyObject *arg)
169
3.30k
{
170
3.30k
    Py_hash_t hash = PyObject_Hash(arg);
171
3.30k
    if (hash == -1) {
172
0
        PyErr_Clear();
173
0
        if (ub->unhashable_args == NULL) {
174
0
            ub->unhashable_args = PyList_New(0);
175
0
            if (ub->unhashable_args == NULL) {
176
0
                return false;
177
0
            }
178
0
        }
179
0
        else {
180
0
            int contains = PySequence_Contains(ub->unhashable_args, arg);
181
0
            if (contains < 0) {
182
0
                return false;
183
0
            }
184
0
            if (contains == 1) {
185
0
                return true;
186
0
            }
187
0
        }
188
0
        if (PyList_Append(ub->unhashable_args, arg) < 0) {
189
0
            return false;
190
0
        }
191
0
    }
192
3.30k
    else {
193
3.30k
        int contains = PySet_Contains(ub->hashable_args, arg);
194
3.30k
        if (contains < 0) {
195
0
            return false;
196
0
        }
197
3.30k
        if (contains == 1) {
198
0
            return true;
199
0
        }
200
3.30k
        if (PySet_Add(ub->hashable_args, arg) < 0) {
201
0
            return false;
202
0
        }
203
3.30k
    }
204
3.30k
    return PyList_Append(ub->args, arg) == 0;
205
3.30k
}
206
207
static bool
208
unionbuilder_add_single(unionbuilder *ub, PyObject *arg)
209
3.49k
{
210
3.49k
    if (Py_IsNone(arg)) {
211
1.14k
        arg = (PyObject *)&_PyNone_Type;  // immortal, so no refcounting needed
212
1.14k
    }
213
2.34k
    else if (_PyUnion_Check(arg)) {
214
188
        PyObject *args = ((unionobject *)arg)->args;
215
188
        return unionbuilder_add_tuple(ub, args);
216
188
    }
217
3.30k
    if (ub->is_checked) {
218
984
        PyObject *type = type_check(arg, "Union[arg, ...]: each arg must be a type.");
219
984
        if (type == NULL) {
220
0
            return false;
221
0
        }
222
984
        bool result = unionbuilder_add_single_unchecked(ub, type);
223
984
        Py_DECREF(type);
224
984
        return result;
225
984
    }
226
2.32k
    else {
227
2.32k
        return unionbuilder_add_single_unchecked(ub, arg);
228
2.32k
    }
229
3.30k
}
230
231
static bool
232
unionbuilder_add_tuple(unionbuilder *ub, PyObject *tuple)
233
260
{
234
260
    Py_ssize_t n = PyTuple_GET_SIZE(tuple);
235
1.05k
    for (Py_ssize_t i = 0; i < n; i++) {
236
796
        if (!unionbuilder_add_single(ub, PyTuple_GET_ITEM(tuple, i))) {
237
0
            return false;
238
0
        }
239
796
    }
240
260
    return true;
241
260
}
242
243
static int
244
is_unionable(PyObject *obj)
245
3.33k
{
246
3.33k
    if (obj == Py_None ||
247
2.28k
        PyType_Check(obj) ||
248
3.33k
        PySentinel_Check(obj) ||
249
3.33k
        _PyGenericAlias_Check(obj) ||
250
3.33k
        _PyUnion_Check(obj) ||
251
3.23k
        Py_IS_TYPE(obj, &_PyTypeAlias_Type)) {
252
3.23k
        return 1;
253
3.23k
    }
254
104
    return 0;
255
3.33k
}
256
257
// Build `self | other`, or return NotImplemented if either operand is not
// unionable.
PyObject *
_Py_union_type_or(PyObject* self, PyObject* other)
{
    if (!is_unionable(self) || !is_unionable(other)) {
        Py_RETURN_NOTIMPLEMENTED;
    }

    unionbuilder ub;
    // Unchecked mode: both operands already passed is_unionable().
    if (!unionbuilder_init(&ub, false)) {
        return NULL;
    }
    bool added = unionbuilder_add_single(&ub, self)
                 && unionbuilder_add_single(&ub, other);
    if (!added) {
        unionbuilder_finalize(&ub);
        return NULL;
    }
    // make_union() finalizes the builder on every path.
    return make_union(&ub);
}
278
279
// repr() of a union: the members' typing reprs joined with " | ".
// Returns a new str, or NULL with an exception set.
static PyObject *
union_repr(PyObject *self)
{
    unionobject *alias = (unionobject *)self;
    Py_ssize_t len = PyTuple_GET_SIZE(alias->args);

    // Pre-size the writer for the shortest plausible output: a 3-char name
    // such as "int" plus the 3-char " | " separator, i.e. 6 chars per member.
    // Fall back to `len` if the multiplication would overflow Py_ssize_t.
    Py_ssize_t estimate = (len <= PY_SSIZE_T_MAX / 6) ? len * 6 : len;
    PyUnicodeWriter *writer = PyUnicodeWriter_Create(estimate);
    if (writer == NULL) {
        return NULL;
    }

    for (Py_ssize_t i = 0; i < len; i++) {
        if (i > 0 && PyUnicodeWriter_WriteASCII(writer, " | ", 3) < 0) {
            goto error;
        }
        PyObject *p = PyTuple_GET_ITEM(alias->args, i);
        if (_Py_typing_type_repr(writer, p) < 0) {
            goto error;
        }
    }

    return PyUnicodeWriter_Finish(writer);

error:
    PyUnicodeWriter_Discard(writer);
    return NULL;
}
319
320
static PyMemberDef union_members[] = {
321
        {"__args__", _Py_T_OBJECT, offsetof(unionobject, args), Py_READONLY},
322
        {0}
323
};
324
325
// Populate __parameters__ if needed.
326
static int
327
union_init_parameters(unionobject *alias)
328
24
{
329
24
    int result = 0;
330
24
    Py_BEGIN_CRITICAL_SECTION(alias);
331
24
    if (alias->parameters == NULL) {
332
24
        alias->parameters = _Py_make_parameters(alias->args);
333
24
        if (alias->parameters == NULL) {
334
0
            result = -1;
335
0
        }
336
24
    }
337
24
    Py_END_CRITICAL_SECTION();
338
24
    return result;
339
24
}
340
341
// union[item]: substitute `item` for the union's type parameters and build a
// new union from the result.  The result may collapse to a single member
// (see make_union()).
static PyObject *
union_getitem(PyObject *self, PyObject *item)
{
    unionobject *alias = (unionobject *)self;
    if (union_init_parameters(alias) < 0) {
        return NULL;
    }

    PyObject *substituted = _Py_subs_parameters(self, alias->args,
                                                alias->parameters, item);
    if (substituted == NULL) {
        return NULL;
    }
    PyObject *result = _Py_union_from_tuple(substituted);
    Py_DECREF(substituted);
    return result;
}
358
359
static PyMappingMethods union_as_mapping = {
360
    .mp_subscript = union_getitem,
361
};
362
363
// Getter for __parameters__: a new reference to the lazily computed tuple.
static PyObject *
union_parameters(PyObject *self, void *Py_UNUSED(unused))
{
    unionobject *alias = (unionobject *)self;
    return union_init_parameters(alias) < 0
           ? NULL
           : Py_NewRef(alias->parameters);
}
372
373
// Getter shared by __name__ and __qualname__.
static PyObject *
union_name(PyObject *Py_UNUSED(self), void *Py_UNUSED(ignored))
{
    PyObject *name = PyUnicode_FromString("Union");
    return name;
}
378
379
// Getter for __origin__: every union reports the Union type object itself.
static PyObject *
union_origin(PyObject *Py_UNUSED(self), void *Py_UNUSED(ignored))
{
    PyObject *origin = (PyObject *)&_PyUnion_Type;
    return Py_NewRef(origin);
}
384
385
static PyGetSetDef union_properties[] = {
386
    {"__name__", union_name, NULL,
387
     PyDoc_STR("Name of the type"), NULL},
388
    {"__qualname__", union_name, NULL,
389
     PyDoc_STR("Qualified name of the type"), NULL},
390
    {"__origin__", union_origin, NULL,
391
     PyDoc_STR("Always returns the type"), NULL},
392
    {"__parameters__", union_parameters, NULL,
393
     PyDoc_STR("Type variables in the types.UnionType."), NULL},
394
    {0}
395
};
396
397
// nb_or slot: build `a | b`.  Checked mode, so each non-union operand is
// passed through type_check().
static PyObject *
union_nb_or(PyObject *a, PyObject *b)
{
    unionbuilder ub;
    if (!unionbuilder_init(&ub, true)) {
        return NULL;
    }
    if (unionbuilder_add_single(&ub, a) && unionbuilder_add_single(&ub, b)) {
        return make_union(&ub);  // finalizes the builder itself
    }
    unionbuilder_finalize(&ub);
    return NULL;
}
411
412
static PyNumberMethods union_as_number = {
413
        .nb_or = union_nb_or, // Add __or__ function
414
};
415
416
static const char* const cls_attrs[] = {
417
        "__module__",  // Required for compatibility with typing module
418
        NULL,
419
};
420
421
// Attribute lookup: names listed in cls_attrs are forwarded to the type;
// everything else uses generic attribute lookup.
static PyObject *
union_getattro(PyObject *self, PyObject *name)
{
    if (PyUnicode_Check(name)) {
        for (const char *const *attr = cls_attrs; *attr != NULL; attr++) {
            if (_PyUnicode_EqualToASCIIString(name, *attr)) {
                return PyObject_GetAttr((PyObject *)Py_TYPE(self), name);
            }
        }
    }
    return PyObject_GenericGetAttr(self, name);
}
437
438
// Return a borrowed reference to the members tuple of a union object.
PyObject *
_Py_union_args(PyObject *self)
{
    assert(_PyUnion_Check(self));
    unionobject *u = (unionobject *)self;
    return u->args;
}
444
445
// Call typing.<name>(*args) and return its result (new reference), or NULL
// with an exception set if importing, the attribute lookup, or the call fails.
static PyObject *
call_typing_func_object(const char *name, PyObject **args, size_t nargs)
{
    PyObject *result = NULL;
    PyObject *typing = PyImport_ImportModule("typing");
    if (typing != NULL) {
        PyObject *func = PyObject_GetAttrString(typing, name);
        if (func != NULL) {
            result = PyObject_Vectorcall(func, args, nargs, NULL);
            Py_DECREF(func);
        }
        Py_DECREF(typing);
    }
    return result;
}
462
463
// Validate `arg` as a union member and return a new reference to the value
// to store (NoneType for None), or NULL with an exception set.  `msg` is
// forwarded to typing._type_check() on the slow path.
static PyObject *
type_check(PyObject *arg, const char *msg)
{
    if (Py_IsNone(arg)) {
        // NoneType is immortal, so no INCREF is needed.
        return (PyObject *)Py_TYPE(arg);
    }
    if (is_unionable(arg)) {
        // Fast path: avoids calling into typing.py.
        return Py_NewRef(arg);
    }

    PyObject *msg_obj = PyUnicode_FromString(msg);
    if (msg_obj == NULL) {
        return NULL;
    }
    PyObject *call_args[2] = {arg, msg_obj};
    PyObject *checked = call_typing_func_object("_type_check", call_args, 2);
    Py_DECREF(msg_obj);
    return checked;
}
483
484
// Build a union from `args`: an exact tuple supplies the members, anything
// else is treated as a single member.  Members are type-checked.
PyObject *
_Py_union_from_tuple(PyObject *args)
{
    unionbuilder ub;
    if (!unionbuilder_init(&ub, true)) {
        return NULL;
    }

    bool ok = PyTuple_CheckExact(args)
              ? unionbuilder_add_tuple(&ub, args)
              : unionbuilder_add_single(&ub, args);
    if (!ok) {
        unionbuilder_finalize(&ub);
        return NULL;
    }
    return make_union(&ub);
}
505
506
// Union[...] class subscription: build a union from the subscript.
static PyObject *
union_class_getitem(PyObject *cls, PyObject *args)
{
    (void)cls;  // the class itself is not needed
    return _Py_union_from_tuple(args);
}
511
512
// __mro_entries__: using a union as a base class always raises TypeError.
static PyObject *
union_mro_entries(PyObject *self, PyObject *args)
{
    (void)args;
    PyErr_Format(PyExc_TypeError, "Cannot subclass %R", self);
    return NULL;
}
518
519
static PyMethodDef union_methods[] = {
520
    {"__mro_entries__", union_mro_entries, METH_O},
521
    {"__class_getitem__", union_class_getitem, METH_O|METH_CLASS, PyDoc_STR("See PEP 585")},
522
    {0}
523
};
524
525
PyTypeObject _PyUnion_Type = {
526
    PyVarObject_HEAD_INIT(&PyType_Type, 0)
527
    .tp_name = "typing.Union",
528
    .tp_doc = PyDoc_STR("Represent a union type\n"
529
              "\n"
530
              "E.g. for int | str"),
531
    .tp_basicsize = sizeof(unionobject),
532
    .tp_dealloc = unionobject_dealloc,
533
    .tp_alloc = PyType_GenericAlloc,
534
    .tp_free = PyObject_GC_Del,
535
    .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC,
536
    .tp_traverse = union_traverse,
537
    .tp_hash = union_hash,
538
    .tp_getattro = union_getattro,
539
    .tp_members = union_members,
540
    .tp_methods = union_methods,
541
    .tp_richcompare = union_richcompare,
542
    .tp_as_mapping = &union_as_mapping,
543
    .tp_as_number = &union_as_number,
544
    .tp_repr = union_repr,
545
    .tp_getset = union_properties,
546
    .tp_weaklistoffset = offsetof(unionobject, weakreflist),
547
};
548
549
static PyObject *
550
make_union(unionbuilder *ub)
551
1.42k
{
552
1.42k
    Py_ssize_t n = PyList_GET_SIZE(ub->args);
553
1.42k
    if (n == 0) {
554
0
        PyErr_SetString(PyExc_TypeError, "Cannot take a Union of no types.");
555
0
        unionbuilder_finalize(ub);
556
0
        return NULL;
557
0
    }
558
1.42k
    if (n == 1) {
559
0
        PyObject *result = PyList_GET_ITEM(ub->args, 0);
560
0
        Py_INCREF(result);
561
0
        unionbuilder_finalize(ub);
562
0
        return result;
563
0
    }
564
565
1.42k
    PyObject *args = NULL, *hashable_args = NULL, *unhashable_args = NULL;
566
1.42k
    args = PyList_AsTuple(ub->args);
567
1.42k
    if (args == NULL) {
568
0
        goto error;
569
0
    }
570
1.42k
    hashable_args = PyFrozenSet_New(ub->hashable_args);
571
1.42k
    if (hashable_args == NULL) {
572
0
        goto error;
573
0
    }
574
1.42k
    if (ub->unhashable_args != NULL) {
575
0
        unhashable_args = PyList_AsTuple(ub->unhashable_args);
576
0
        if (unhashable_args == NULL) {
577
0
            goto error;
578
0
        }
579
0
    }
580
581
1.42k
    unionobject *result = PyObject_GC_New(unionobject, &_PyUnion_Type);
582
1.42k
    if (result == NULL) {
583
0
        goto error;
584
0
    }
585
1.42k
    unionbuilder_finalize(ub);
586
587
1.42k
    result->parameters = NULL;
588
1.42k
    result->args = args;
589
1.42k
    result->hashable_args = hashable_args;
590
1.42k
    result->unhashable_args = unhashable_args;
591
1.42k
    result->weakreflist = NULL;
592
1.42k
    _PyObject_GC_TRACK(result);
593
1.42k
    return (PyObject*)result;
594
0
error:
595
0
    Py_XDECREF(args);
596
0
    Py_XDECREF(hashable_args);
597
0
    Py_XDECREF(unhashable_args);
598
0
    unionbuilder_finalize(ub);
599
    return NULL;
600
1.42k
}