1# Copyright (c) Meta Platforms, Inc. and affiliates.
2#
3# This source code is licensed under the MIT license found in the
4# LICENSE file in the root directory of this source tree.
5
6"""
7Provides the implementation of `CSTNode.deep_equals`.
8"""
9
10from dataclasses import fields
11from typing import Sequence
12
13from libcst._nodes.base import CSTNode
14
15
def deep_equals(a: object, b: object) -> bool:
    """
    Structurally compare two values.

    Two CST nodes are compared field by field. Two non-string sequences are
    compared element by element, regardless of their concrete sequence type.
    Any other pair of values falls back to ordinary ``==`` comparison.
    """
    if isinstance(a, CSTNode) and isinstance(b, CSTNode):
        return _deep_equals_cst_node(a, b)

    # ``str`` and ``bytes`` are Sequences too, but they are compared atomically
    # with ``==``. Walking a ``str`` element-wise would never terminate, since
    # indexing a one-character string yields that same string again.
    atomic_sequences = (str, bytes)
    both_sequences = (
        isinstance(a, Sequence)
        and not isinstance(a, atomic_sequences)
        and isinstance(b, Sequence)
        and not isinstance(b, atomic_sequences)
    )
    if both_sequences:
        return _deep_equals_sequence(a, b)

    return a == b
28
29
30def _deep_equals_sequence(a: Sequence[object], b: Sequence[object]) -> bool:
31 """
32 A helper function for `CSTNode.deep_equals`.
33
34 Normalizes and compares sequences. Because we only ever expose `Sequence[]`
35 types, and not `List[]`, `Tuple[]`, or `Iterable[]` values, all sequences should
36 be treated as equal if they have the same values.
37 """
38 if a is b: # short-circuit
39 return True
40 if len(a) != len(b):
41 return False
42 return all(deep_equals(a_el, b_el) for (a_el, b_el) in zip(a, b))
43
44
45def _deep_equals_cst_node(a: "CSTNode", b: "CSTNode") -> bool:
46 if type(a) is not type(b):
47 return False
48 if a is b: # short-circuit
49 return True
50 # Ignore metadata and other hidden fields
51 for field in (f for f in fields(a) if f.compare is True):
52 a_value = getattr(a, field.name)
53 b_value = getattr(b, field.name)
54 if not deep_equals(a_value, b_value):
55 return False
56 return True