1# Copyright (c) Meta Platforms, Inc. and affiliates.
2#
3# This source code is licensed under the MIT license found in the
4# LICENSE file in the root directory of this source tree.
5
6from __future__ import annotations
7
8import dataclasses
9from typing import TYPE_CHECKING
10
11from libcst import IndentedBlock, Module
12from libcst._nodes.deep_equals import deep_equals
13
14if TYPE_CHECKING:
15 from typing import Sequence
16
17 from libcst import CSTNode
18
19
def get_node_fields(node: CSTNode) -> Sequence[dataclasses.Field[CSTNode]]:
    """
    Returns every dataclass field declared on the given CST-node, exactly as
    reported by ``dataclasses.fields`` (no filtering is applied here).
    """
    node_fields = dataclasses.fields(node)
    return node_fields
25
26
def is_whitespace_node_field(node: CSTNode, field: dataclasses.Field[CSTNode]) -> bool:
    """
    Tells whether a CST-node's field only carries whitespace/layout information.

    A field qualifies when its name contains ``whitespace``, ``leading_lines``
    or ``lines_after_decorators``, when it is the ``header``/``footer`` of an
    ``IndentedBlock`` or ``Module``, or when it is the ``indent`` of an
    ``IndentedBlock``.
    """
    name = field.name

    # Name-based markers apply to every node type, so check them first.
    name_markers = ("whitespace", "leading_lines", "lines_after_decorators")
    if any(marker in name for marker in name_markers):
        return True

    # The remaining whitespace fields are specific to a couple of node types.
    if isinstance(node, IndentedBlock):
        return name in ("header", "footer", "indent")
    return isinstance(node, Module) and name in ("header", "footer")
46
47
def is_syntax_node_field(node: CSTNode, field: dataclasses.Field[CSTNode]) -> bool:
    """
    Tells whether a CST-node's field is a syntax-related field
    (colon, semicolon, dot, encoding, etc.).

    The decision is based on the node type, the field name and the textual
    representation of the field's declared type.
    """
    field_name = field.name

    # Module-level settings describing the file rather than its content.
    module_syntax_fields = (
        "encoding",
        "default_indent",
        "default_newline",
        "has_trailing_newline",
    )
    if isinstance(node, Module) and field_name in module_syntax_fields:
        return True

    annotation = repr(field.type)

    if "Sentinel" in annotation:
        # A sentinel-typed field is a value that can optionally be specified,
        # so it is treated as syntax -- except for a few named fields and
        # anything whitespace-related.
        sentinel_exempt = field_name in ("star_arg", "star", "posonly_ind")
        if not sentinel_exempt and "whitespace" not in field_name:
            return True

    # Node types that exist purely as separation syntax.
    separator_types = ("Semicolon", "Colon", "Comma", "Dot", "AssignEqual")
    return any(type_name in annotation for type_name in separator_types)
76
77
def get_field_default_value(field: dataclasses.Field[CSTNode]) -> object:
    """
    Returns the default value declared for a CST-node's field.

    When the field declares a ``default_factory`` the factory is invoked and
    its result returned; otherwise ``field.default`` is returned as-is (which
    is ``dataclasses.MISSING`` when the field declares no default at all).
    """
    if field.default_factory is dataclasses.MISSING:
        return field.default
    # pyre-fixme[29]: `Union[dataclasses._MISSING_TYPE,
    #  dataclasses._DefaultFactory[object]]` is not a function.
    return field.default_factory()
87
88
def is_default_node_field(node: CSTNode, field: dataclasses.Field[CSTNode]) -> bool:
    """
    Tells whether the value currently stored in a CST-node's field is equal
    (per ``deep_equals``) to that field's declared default value.
    """
    current_value = getattr(node, field.name)
    default_value = get_field_default_value(field)
    return deep_equals(current_value, default_value)
94
95
def filter_node_fields(
    node: CSTNode,
    *,
    show_defaults: bool,
    show_syntax: bool,
    show_whitespace: bool,
) -> Sequence[dataclasses.Field[CSTNode]]:
    """
    Returns the list of a CST-node's fields that should be displayed.

    Fields whose name starts with ``_`` are always dropped. On top of that:

    * ``show_whitespace=False`` drops whitespace fields;
    * ``show_defaults=False`` drops the remaining fields whose value equals
      their default value;
    * ``show_syntax=False`` drops the remaining syntax fields.

    The checks are applied in that order, so a later check is only evaluated
    for fields that survived the earlier ones.
    """

    def _is_visible(candidate: dataclasses.Field[CSTNode]) -> bool:
        # Private fields are never shown.
        if candidate.name[0] == "_":
            return False
        # Whitespace fields, unless explicitly requested.
        if not show_whitespace and is_whitespace_node_field(node, candidate):
            return False
        # Values left unchanged from their defaults, unless requested.
        if not show_defaults and is_default_node_field(node, candidate):
            return False
        # Uninteresting syntax fields, unless requested.
        if not show_syntax and is_syntax_node_field(node, candidate):
            return False
        return True

    return [candidate for candidate in dataclasses.fields(node) if _is_visible(candidate)]