Coverage Report

Created: 2025-11-24 06:11

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/src/cpython/Objects/unionobject.c
Line
Count
Source
1
// typing.Union -- used to represent e.g. Union[int, str], int | str
2
#include "Python.h"
3
#include "pycore_object.h"  // _PyObject_GC_TRACK/UNTRACK
4
#include "pycore_typevarobject.h"  // _PyTypeAlias_Type, _Py_typing_type_repr
5
#include "pycore_unicodeobject.h" // _PyUnicode_EqualToASCIIString
6
#include "pycore_unionobject.h"
7
#include "pycore_weakref.h"       // FT_CLEAR_WEAKREFS()
8
9
10
typedef struct {
11
    PyObject_HEAD
12
    PyObject *args;  // all args (tuple)
13
    PyObject *hashable_args;  // frozenset or NULL
14
    PyObject *unhashable_args;  // tuple or NULL
15
    PyObject *parameters;
16
    PyObject *weakreflist;
17
} unionobject;
18
19
static void
20
unionobject_dealloc(PyObject *self)
21
24
{
22
24
    unionobject *alias = (unionobject *)self;
23
24
24
    _PyObject_GC_UNTRACK(self);
25
24
    FT_CLEAR_WEAKREFS(self, alias->weakreflist);
26
27
24
    Py_XDECREF(alias->args);
28
24
    Py_XDECREF(alias->hashable_args);
29
24
    Py_XDECREF(alias->unhashable_args);
30
24
    Py_XDECREF(alias->parameters);
31
24
    Py_TYPE(self)->tp_free(self);
32
24
}
33
34
static int
35
union_traverse(PyObject *self, visitproc visit, void *arg)
36
26.7k
{
37
26.7k
    unionobject *alias = (unionobject *)self;
38
26.7k
    Py_VISIT(alias->args);
39
26.7k
    Py_VISIT(alias->hashable_args);
40
26.7k
    Py_VISIT(alias->unhashable_args);
41
26.7k
    Py_VISIT(alias->parameters);
42
26.7k
    return 0;
43
26.7k
}
44
45
static Py_hash_t
46
union_hash(PyObject *self)
47
0
{
48
0
    unionobject *alias = (unionobject *)self;
49
    // If there are any unhashable args, treat this union as unhashable.
50
    // Otherwise, two unions might compare equal but have different hashes.
51
0
    if (alias->unhashable_args) {
52
        // Attempt to get an error from one of the values.
53
0
        assert(PyTuple_CheckExact(alias->unhashable_args));
54
0
        Py_ssize_t n = PyTuple_GET_SIZE(alias->unhashable_args);
55
0
        for (Py_ssize_t i = 0; i < n; i++) {
56
0
            PyObject *arg = PyTuple_GET_ITEM(alias->unhashable_args, i);
57
0
            Py_hash_t hash = PyObject_Hash(arg);
58
0
            if (hash == -1) {
59
0
                return -1;
60
0
            }
61
0
        }
62
        // The unhashable values somehow became hashable again. Still raise
63
        // an error.
64
0
        PyErr_Format(PyExc_TypeError, "union contains %d unhashable elements", n);
65
0
        return -1;
66
0
    }
67
0
    return PyObject_Hash(alias->hashable_args);
68
0
}
69
70
static int
71
unions_equal(unionobject *a, unionobject *b)
72
0
{
73
0
    int result = PyObject_RichCompareBool(a->hashable_args, b->hashable_args, Py_EQ);
74
0
    if (result == -1) {
75
0
        return -1;
76
0
    }
77
0
    if (result == 0) {
78
0
        return 0;
79
0
    }
80
0
    if (a->unhashable_args && b->unhashable_args) {
81
0
        Py_ssize_t n = PyTuple_GET_SIZE(a->unhashable_args);
82
0
        if (n != PyTuple_GET_SIZE(b->unhashable_args)) {
83
0
            return 0;
84
0
        }
85
0
        for (Py_ssize_t i = 0; i < n; i++) {
86
0
            PyObject *arg_a = PyTuple_GET_ITEM(a->unhashable_args, i);
87
0
            int result = PySequence_Contains(b->unhashable_args, arg_a);
88
0
            if (result == -1) {
89
0
                return -1;
90
0
            }
91
0
            if (!result) {
92
0
                return 0;
93
0
            }
94
0
        }
95
0
        for (Py_ssize_t i = 0; i < n; i++) {
96
0
            PyObject *arg_b = PyTuple_GET_ITEM(b->unhashable_args, i);
97
0
            int result = PySequence_Contains(a->unhashable_args, arg_b);
98
0
            if (result == -1) {
99
0
                return -1;
100
0
            }
101
0
            if (!result) {
102
0
                return 0;
103
0
            }
104
0
        }
105
0
    }
106
0
    else if (a->unhashable_args || b->unhashable_args) {
107
0
        return 0;
108
0
    }
109
0
    return 1;
110
0
}
111
112
/*
 * tp_richcompare: supports only == and != between two unions. Every other
 * operator or operand type yields NotImplemented.
 */
static PyObject *
union_richcompare(PyObject *a, PyObject *b, int op)
{
    int supported_op = (op == Py_EQ || op == Py_NE);
    if (!supported_op || !_PyUnion_Check(b)) {
        Py_RETURN_NOTIMPLEMENTED;
    }

    int equal = unions_equal((unionobject *)a, (unionobject *)b);
    if (equal < 0) {
        return NULL;
    }
    return PyBool_FromLong(op == Py_EQ ? equal : !equal);
}
130
131
typedef struct {
132
    PyObject *args;  // list
133
    PyObject *hashable_args;  // set
134
    PyObject *unhashable_args;  // list or NULL
135
    bool is_checked;  // whether to call type_check()
136
} unionbuilder;
137
138
/* Forward declarations for helpers defined further down in this file. */
static bool unionbuilder_add_tuple(unionbuilder *ub, PyObject *tuple);
static PyObject *make_union(unionbuilder *ub);
static PyObject *type_check(PyObject *arg, const char *msg);
141
142
static bool
143
unionbuilder_init(unionbuilder *ub, bool is_checked)
144
762
{
145
762
    ub->args = PyList_New(0);
146
762
    if (ub->args == NULL) {
147
0
        return false;
148
0
    }
149
762
    ub->hashable_args = PySet_New(NULL);
150
762
    if (ub->hashable_args == NULL) {
151
0
        Py_DECREF(ub->args);
152
0
        return false;
153
0
    }
154
762
    ub->unhashable_args = NULL;
155
762
    ub->is_checked = is_checked;
156
762
    return true;
157
762
}
158
159
static void
160
unionbuilder_finalize(unionbuilder *ub)
161
762
{
162
762
    Py_DECREF(ub->args);
163
762
    Py_DECREF(ub->hashable_args);
164
762
    Py_XDECREF(ub->unhashable_args);
165
762
}
166
167
static bool
168
unionbuilder_add_single_unchecked(unionbuilder *ub, PyObject *arg)
169
1.52k
{
170
1.52k
    Py_hash_t hash = PyObject_Hash(arg);
171
1.52k
    if (hash == -1) {
172
0
        PyErr_Clear();
173
0
        if (ub->unhashable_args == NULL) {
174
0
            ub->unhashable_args = PyList_New(0);
175
0
            if (ub->unhashable_args == NULL) {
176
0
                return false;
177
0
            }
178
0
        }
179
0
        else {
180
0
            int contains = PySequence_Contains(ub->unhashable_args, arg);
181
0
            if (contains < 0) {
182
0
                return false;
183
0
            }
184
0
            if (contains == 1) {
185
0
                return true;
186
0
            }
187
0
        }
188
0
        if (PyList_Append(ub->unhashable_args, arg) < 0) {
189
0
            return false;
190
0
        }
191
0
    }
192
1.52k
    else {
193
1.52k
        int contains = PySet_Contains(ub->hashable_args, arg);
194
1.52k
        if (contains < 0) {
195
0
            return false;
196
0
        }
197
1.52k
        if (contains == 1) {
198
0
            return true;
199
0
        }
200
1.52k
        if (PySet_Add(ub->hashable_args, arg) < 0) {
201
0
            return false;
202
0
        }
203
1.52k
    }
204
1.52k
    return PyList_Append(ub->args, arg) == 0;
205
1.52k
}
206
207
static bool
208
unionbuilder_add_single(unionbuilder *ub, PyObject *arg)
209
1.52k
{
210
1.52k
    if (Py_IsNone(arg)) {
211
738
        arg = (PyObject *)&_PyNone_Type;  // immortal, so no refcounting needed
212
738
    }
213
786
    else if (_PyUnion_Check(arg)) {
214
0
        PyObject *args = ((unionobject *)arg)->args;
215
0
        return unionbuilder_add_tuple(ub, args);
216
0
    }
217
1.52k
    if (ub->is_checked) {
218
0
        PyObject *type = type_check(arg, "Union[arg, ...]: each arg must be a type.");
219
0
        if (type == NULL) {
220
0
            return false;
221
0
        }
222
0
        bool result = unionbuilder_add_single_unchecked(ub, type);
223
0
        Py_DECREF(type);
224
0
        return result;
225
0
    }
226
1.52k
    else {
227
1.52k
        return unionbuilder_add_single_unchecked(ub, arg);
228
1.52k
    }
229
1.52k
}
230
231
static bool
232
unionbuilder_add_tuple(unionbuilder *ub, PyObject *tuple)
233
0
{
234
0
    Py_ssize_t n = PyTuple_GET_SIZE(tuple);
235
0
    for (Py_ssize_t i = 0; i < n; i++) {
236
0
        if (!unionbuilder_add_single(ub, PyTuple_GET_ITEM(tuple, i))) {
237
0
            return false;
238
0
        }
239
0
    }
240
0
    return true;
241
0
}
242
243
static int
244
is_unionable(PyObject *obj)
245
1.52k
{
246
1.52k
    if (obj == Py_None ||
247
786
        PyType_Check(obj) ||
248
1.52k
        _PyGenericAlias_Check(obj) ||
249
1.52k
        _PyUnion_Check(obj) ||
250
1.52k
        Py_IS_TYPE(obj, &_PyTypeAlias_Type)) {
251
1.52k
        return 1;
252
1.52k
    }
253
0
    return 0;
254
1.52k
}
255
256
/*
 * Implementation of `self | other` for unionable objects (see is_unionable).
 * Returns NotImplemented if either operand is not unionable. Otherwise it
 * returns whatever make_union() produces: a new union, or the lone member if
 * both operands deduplicate to one. Returns NULL with an exception set on
 * error.
 */
PyObject *
_Py_union_type_or(PyObject* self, PyObject* other)
{
    if (!is_unionable(self) || !is_unionable(other)) {
        Py_RETURN_NOTIMPLEMENTED;
    }

    unionbuilder builder;
    // unchecked because we already checked is_unionable()
    if (!unionbuilder_init(&builder, false)) {
        return NULL;
    }

    bool ok = unionbuilder_add_single(&builder, self)
              && unionbuilder_add_single(&builder, other);
    if (!ok) {
        unionbuilder_finalize(&builder);
        return NULL;
    }
    /* make_union() releases the builder on every path. */
    return make_union(&builder);
}
277
278
/*
 * tp_repr: render each member with _Py_typing_type_repr() and join the
 * pieces with " | " (e.g. "int | str").
 * Returns a new string, or NULL with an exception set.
 */
static PyObject *
union_repr(PyObject *self)
{
    unionobject *alias = (unionobject *)self;
    Py_ssize_t len = PyTuple_GET_SIZE(alias->args);

    // Size hint: shortest type name "int" (3 chars) + " | " (3 chars)
    // separator per member. Fall back to `len` if len * 6 would overflow.
    Py_ssize_t estimate = (len <= PY_SSIZE_T_MAX / 6) ? len * 6 : len;
    PyUnicodeWriter *writer = PyUnicodeWriter_Create(estimate);
    if (writer == NULL) {
        return NULL;
    }

    for (Py_ssize_t i = 0; i < len; i++) {
        if (i > 0 && PyUnicodeWriter_WriteASCII(writer, " | ", 3) < 0) {
            goto error;
        }
        PyObject *p = PyTuple_GET_ITEM(alias->args, i);
        if (_Py_typing_type_repr(writer, p) < 0) {
            goto error;
        }
    }
    return PyUnicodeWriter_Finish(writer);

error:
    PyUnicodeWriter_Discard(writer);
    return NULL;
}
318
319
static PyMemberDef union_members[] = {
320
        {"__args__", _Py_T_OBJECT, offsetof(unionobject, args), Py_READONLY},
321
        {0}
322
};
323
324
// Populate __parameters__ if needed.
325
static int
326
union_init_parameters(unionobject *alias)
327
0
{
328
0
    int result = 0;
329
0
    Py_BEGIN_CRITICAL_SECTION(alias);
330
0
    if (alias->parameters == NULL) {
331
0
        alias->parameters = _Py_make_parameters(alias->args);
332
0
        if (alias->parameters == NULL) {
333
0
            result = -1;
334
0
        }
335
0
    }
336
0
    Py_END_CRITICAL_SECTION();
337
0
    return result;
338
0
}
339
340
/*
 * mp_subscript: substitute `item` for the union's __parameters__ (via
 * _Py_subs_parameters), then rebuild a union from the substituted arguments
 * with _Py_union_from_tuple(). Returns NULL with an exception set on error.
 */
static PyObject *
union_getitem(PyObject *self, PyObject *item)
{
    unionobject *alias = (unionobject *)self;
    if (union_init_parameters(alias) < 0) {
        return NULL;
    }

    PyObject *substituted = _Py_subs_parameters(self, alias->args,
                                                alias->parameters, item);
    if (substituted == NULL) {
        return NULL;
    }

    PyObject *rebuilt = _Py_union_from_tuple(substituted);
    Py_DECREF(substituted);
    return rebuilt;
}
357
358
static PyMappingMethods union_as_mapping = {
359
    .mp_subscript = union_getitem,
360
};
361
362
/* Getter for __parameters__: computed lazily, returned as a new reference. */
static PyObject *
union_parameters(PyObject *self, void *Py_UNUSED(unused))
{
    unionobject *alias = (unionobject *)self;
    return (union_init_parameters(alias) < 0)
           ? NULL
           : Py_NewRef(alias->parameters);
}
371
372
/* Getter shared by __name__ and __qualname__: always the string "Union". */
static PyObject *
union_name(PyObject *Py_UNUSED(self), void *Py_UNUSED(ignored))
{
    PyObject *name = PyUnicode_FromString("Union");
    return name;
}
377
378
/* Getter for __origin__: a new reference to the typing.Union type object. */
static PyObject *
union_origin(PyObject *Py_UNUSED(self), void *Py_UNUSED(ignored))
{
    PyObject *origin = (PyObject *)&_PyUnion_Type;
    return Py_NewRef(origin);
}
383
384
static PyGetSetDef union_properties[] = {
385
    {"__name__", union_name, NULL,
386
     PyDoc_STR("Name of the type"), NULL},
387
    {"__qualname__", union_name, NULL,
388
     PyDoc_STR("Qualified name of the type"), NULL},
389
    {"__origin__", union_origin, NULL,
390
     PyDoc_STR("Always returns the type"), NULL},
391
    {"__parameters__", union_parameters, NULL,
392
     PyDoc_STR("Type variables in the types.UnionType."), NULL},
393
    {0}
394
};
395
396
/*
 * nb_or for union objects: build a checked union of `a` and `b`. Union
 * operands are flattened, and every other operand goes through type_check().
 * Returns the make_union() result, or NULL with an exception set.
 */
static PyObject *
union_nb_or(PyObject *a, PyObject *b)
{
    unionbuilder ub;
    if (!unionbuilder_init(&ub, true)) {
        return NULL;
    }
    if (unionbuilder_add_single(&ub, a) && unionbuilder_add_single(&ub, b)) {
        return make_union(&ub);  /* releases the builder itself */
    }
    unionbuilder_finalize(&ub);
    return NULL;
}
410
411
static PyNumberMethods union_as_number = {
412
        .nb_or = union_nb_or, // Add __or__ function
413
};
414
415
static const char* const cls_attrs[] = {
416
        "__module__",  // Required for compatibility with typing module
417
        NULL,
418
};
419
420
/*
 * tp_getattro: names listed in cls_attrs are resolved on the union's type.
 * Everything else uses generic attribute lookup.
 */
static PyObject *
union_getattro(PyObject *self, PyObject *name)
{
    if (PyUnicode_Check(name)) {
        for (size_t i = 0; cls_attrs[i] != NULL; i++) {
            if (_PyUnicode_EqualToASCIIString(name, cls_attrs[i])) {
                return PyObject_GetAttr((PyObject *)Py_TYPE(self), name);
            }
        }
    }
    return PyObject_GenericGetAttr(self, name);
}
436
437
/* Borrowed reference to the member tuple of union `self` (must be a union). */
PyObject *
_Py_union_args(PyObject *self)
{
    assert(_PyUnion_Check(self));
    unionobject *u = (unionobject *)self;
    return u->args;
}
443
444
/*
 * Import the typing module, fetch its attribute `name`, and call it
 * (vectorcall) with the `nargs` positional arguments in `args`.
 * Returns the call result as a new reference, or NULL with an exception set.
 */
static PyObject *
call_typing_func_object(const char *name, PyObject **args, size_t nargs)
{
    PyObject *result = NULL;
    PyObject *typing = PyImport_ImportModule("typing");
    if (typing != NULL) {
        PyObject *func = PyObject_GetAttrString(typing, name);
        if (func != NULL) {
            result = PyObject_Vectorcall(func, args, nargs, NULL);
            Py_DECREF(func);
        }
        Py_DECREF(typing);
    }
    return result;
}
461
462
/*
 * Validate a would-be union member. Returns NoneType for None (immortal, no
 * new reference taken), or a new reference to `arg` if it is_unionable().
 * Otherwise it defers to typing._type_check(arg, msg) and returns that
 * call's result. Returns NULL with an exception set on failure.
 */
static PyObject *
type_check(PyObject *arg, const char *msg)
{
    if (Py_IsNone(arg)) {
        // NoneType is immortal, so don't need an INCREF
        return (PyObject *)Py_TYPE(arg);
    }
    // Fast path to avoid calling into typing.py
    if (is_unionable(arg)) {
        return Py_NewRef(arg);
    }

    PyObject *message = PyUnicode_FromString(msg);
    if (message == NULL) {
        return NULL;
    }
    PyObject *call_args[2] = {arg, message};
    PyObject *checked = call_typing_func_object("_type_check", call_args, 2);
    Py_DECREF(message);
    return checked;
}
482
483
/*
 * Build a checked union from `args`. An exact tuple contributes each of its
 * items, and any other object is added as a single member. Returns the
 * make_union() result, or NULL with an exception set.
 */
PyObject *
_Py_union_from_tuple(PyObject *args)
{
    unionbuilder ub;
    if (!unionbuilder_init(&ub, true)) {
        return NULL;
    }

    bool ok = PyTuple_CheckExact(args)
              ? unionbuilder_add_tuple(&ub, args)
              : unionbuilder_add_single(&ub, args);
    if (!ok) {
        unionbuilder_finalize(&ub);
        return NULL;
    }
    return make_union(&ub);
}
504
505
/* Union[...] on the class: build a checked union from the subscript. */
static PyObject *
union_class_getitem(PyObject *Py_UNUSED(cls), PyObject *args)
{
    PyObject *result = _Py_union_from_tuple(args);
    return result;
}
510
511
/* __mro_entries__: unions cannot appear among base classes; always raises. */
static PyObject *
union_mro_entries(PyObject *self, PyObject *Py_UNUSED(args))
{
    PyErr_Format(PyExc_TypeError, "Cannot subclass %R", self);
    return NULL;
}
517
518
static PyMethodDef union_methods[] = {
519
    {"__mro_entries__", union_mro_entries, METH_O},
520
    {"__class_getitem__", union_class_getitem, METH_O|METH_CLASS, PyDoc_STR("See PEP 585")},
521
    {0}
522
};
523
524
PyTypeObject _PyUnion_Type = {
525
    PyVarObject_HEAD_INIT(&PyType_Type, 0)
526
    .tp_name = "typing.Union",
527
    .tp_doc = PyDoc_STR("Represent a union type\n"
528
              "\n"
529
              "E.g. for int | str"),
530
    .tp_basicsize = sizeof(unionobject),
531
    .tp_dealloc = unionobject_dealloc,
532
    .tp_alloc = PyType_GenericAlloc,
533
    .tp_free = PyObject_GC_Del,
534
    .tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC,
535
    .tp_traverse = union_traverse,
536
    .tp_hash = union_hash,
537
    .tp_getattro = union_getattro,
538
    .tp_members = union_members,
539
    .tp_methods = union_methods,
540
    .tp_richcompare = union_richcompare,
541
    .tp_as_mapping = &union_as_mapping,
542
    .tp_as_number = &union_as_number,
543
    .tp_repr = union_repr,
544
    .tp_getset = union_properties,
545
    .tp_weaklistoffset = offsetof(unionobject, weakreflist),
546
};
547
548
static PyObject *
549
make_union(unionbuilder *ub)
550
762
{
551
762
    Py_ssize_t n = PyList_GET_SIZE(ub->args);
552
762
    if (n == 0) {
553
0
        PyErr_SetString(PyExc_TypeError, "Cannot take a Union of no types.");
554
0
        unionbuilder_finalize(ub);
555
0
        return NULL;
556
0
    }
557
762
    if (n == 1) {
558
0
        PyObject *result = PyList_GET_ITEM(ub->args, 0);
559
0
        Py_INCREF(result);
560
0
        unionbuilder_finalize(ub);
561
0
        return result;
562
0
    }
563
564
762
    PyObject *args = NULL, *hashable_args = NULL, *unhashable_args = NULL;
565
762
    args = PyList_AsTuple(ub->args);
566
762
    if (args == NULL) {
567
0
        goto error;
568
0
    }
569
762
    hashable_args = PyFrozenSet_New(ub->hashable_args);
570
762
    if (hashable_args == NULL) {
571
0
        goto error;
572
0
    }
573
762
    if (ub->unhashable_args != NULL) {
574
0
        unhashable_args = PyList_AsTuple(ub->unhashable_args);
575
0
        if (unhashable_args == NULL) {
576
0
            goto error;
577
0
        }
578
0
    }
579
580
762
    unionobject *result = PyObject_GC_New(unionobject, &_PyUnion_Type);
581
762
    if (result == NULL) {
582
0
        goto error;
583
0
    }
584
762
    unionbuilder_finalize(ub);
585
586
762
    result->parameters = NULL;
587
762
    result->args = args;
588
762
    result->hashable_args = hashable_args;
589
762
    result->unhashable_args = unhashable_args;
590
762
    result->weakreflist = NULL;
591
762
    _PyObject_GC_TRACK(result);
592
762
    return (PyObject*)result;
593
0
error:
594
0
    Py_XDECREF(args);
595
0
    Py_XDECREF(hashable_args);
596
0
    Py_XDECREF(unhashable_args);
597
0
    unionbuilder_finalize(ub);
598
    return NULL;
599
762
}