Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/pasta/base/codegen.py: 26%

102 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-06-07 06:12 +0000

1# coding=utf-8 

2"""Generate code from an annotated syntax tree.""" 

3# Copyright 2021 Google LLC 

4# 

5# Licensed under the Apache License, Version 2.0 (the "License"); 

6# you may not use this file except in compliance with the License. 

7# You may obtain a copy of the License at 

8# 

9# https://www.apache.org/licenses/LICENSE-2.0 

10# 

11# Unless required by applicable law or agreed to in writing, software 

12# distributed under the License is distributed on an "AS IS" BASIS, 

13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

14# See the License for the specific language governing permissions and 

15# limitations under the License. 

16 

17from __future__ import absolute_import 

18from __future__ import division 

19from __future__ import print_function 

20 

21import ast 

22import collections 

23import six 

24 

25from pasta.base import annotate 

26from pasta.base import formatting as fmt 

27from pasta.base import fstring_utils 

28 

29 

class PrintError(Exception):
  """Raised when the Printer cannot turn a syntax tree back into source."""

32 

33 

class Printer(annotate.BaseVisitor):
  """Traverses an AST and generates formatted python source code.

  This uses the same base visitor as annotating the AST, but instead of eating a
  token it spits one out. For special formatting information which was stored on
  the node, this is output exactly as it was read in unless one or more of the
  dependency attributes used to generate it has changed, in which case its
  default formatting is used.

  The generated source accumulates in ``self.code``.
  """

  def __init__(self):
    super(Printer, self).__init__()
    # Generated source text; every emitting method appends to it.
    self.code = ''

  def visit(self, node):
    """Print a single node (and, through the base visitor, its children).

    Arguments:
      node: (ast.AST) The node to print.

    Raises:
      PrintError: if visiting the node raises TypeError, ValueError,
        IndexError or KeyError.
    """
    # Scratch record of attribute names already emitted by `attr` for this
    # node, so each is written at most once. Unknown names read as False.
    node._printer_info = collections.defaultdict(lambda: False)
    try:
      super(Printer, self).visit(node)
    except (TypeError, ValueError, IndexError, KeyError) as e:
      raise PrintError(e)
    finally:
      # Always remove the scratch attribute -- including when printing fails --
      # so the caller's tree is not left carrying printer state.
      del node._printer_info

  def visit_Module(self, node):
    """Print a module, re-emitting a recorded byte-order mark if present."""
    self.prefix(node)
    bom = fmt.get(node, 'bom')
    if bom is not None:
      self.code += bom
    self.generic_visit(node)
    self.suffix(node)

  def visit_Num(self, node):
    """Print a legacy Num node: recorded 'content' if any, else repr(n)."""
    self.prefix(node)
    content = fmt.get(node, 'content')
    self.code += content if content is not None else repr(node.n)
    self.suffix(node)

  def visit_Str(self, node):
    """Print a legacy Str node: recorded 'content' if any, else repr(s)."""
    self.prefix(node)
    content = fmt.get(node, 'content')
    self.code += content if content is not None else repr(node.s)
    self.suffix(node)

  def visit_JoinedStr(self, node):
    """Print an f-string.

    The recorded 'content' (or, when absent, a repr of the literal parts with
    placeholders standing in for each formatted value) is passed to
    `fstring_utils.perform_replacements` together with the printed source of
    each formatted value.
    """
    self.prefix(node)
    content = fmt.get(node, 'content')

    if content is None:
      # No recorded text: rebuild one from the literal pieces, inserting a
      # numbered placeholder wherever a formatted value appears.
      parts = []
      for val in node.values:
        if isinstance(val, ast.Str):
          parts.append(val.s)
        else:
          parts.append(fstring_utils.placeholder(len(parts)))
      content = repr(''.join(parts))

    values = [to_str(v) for v in fstring_utils.get_formatted_values(node)]
    self.code += fstring_utils.perform_replacements(content, values)
    self.suffix(node)

  def visit_Bytes(self, node):
    """Print a legacy Bytes node: recorded 'content' if any, else repr(s)."""
    self.prefix(node)
    content = fmt.get(node, 'content')
    self.code += content if content is not None else repr(node.s)
    self.suffix(node)

  def visit_Constant(self, node):
    """Print a Constant node.

    Ellipsis is always written as '...'. Other constants use the recorded
    'content' if any, else the repr of the constant's value.
    """
    self.prefix(node)
    if node.value is Ellipsis:
      content = '...'
    else:
      content = fmt.get(node, 'content')
    # Read `value` directly: `Constant.s` is only a deprecated alias for it
    # (warns on 3.12+, removed in 3.14).
    self.code += content if content is not None else repr(node.value)
    self.suffix(node)

  def token(self, value):
    """Emit `value` verbatim."""
    self.code += value

  def optional_token(self, node, attr_name, token_val,
                     allow_whitespace_prefix=False, default=False):
    """Emit the recorded text for an optional token, if any.

    Arguments:
      node: (ast.AST) Node to read formatting information from.
      attr_name: (string) Name the token's text was stored under.
      token_val: (string) Text to emit when nothing was recorded and
        `default` is true.
      allow_whitespace_prefix: Unused when printing.
      default: (bool) Whether to fall back to `token_val` when nothing was
        recorded. Otherwise nothing is emitted in that case.
    """
    del allow_whitespace_prefix
    value = fmt.get(node, attr_name)
    if value is None and default:
      value = token_val
    self.code += value or ''

  def attr(self, node, attr_name, attr_vals, deps=None, default=None):
    """Add the formatted data stored for a given attribute on this node.

    If any of the dependent attributes of the node have changed since it was
    annotated, then the stored formatted data for this attr_name is no longer
    valid, and we must use the default instead.

    Nothing is emitted if the node is not currently being visited (it has no
    `_printer_info`), or if this attr_name was already emitted for it.

    Arguments:
      node: (ast.AST) An AST node to retrieve formatting information from.
      attr_name: (string) Name to load the formatting information from.
      attr_vals: (list of functions/strings) Unused here.
      deps: (optional, set of strings) Attributes of the node which the stored
        formatting data depends on.
      default: (string) Default formatted data for this attribute.
    """
    del attr_vals
    if not hasattr(node, '_printer_info') or node._printer_info[attr_name]:
      return
    node._printer_info[attr_name] = True
    val = fmt.get(node, attr_name)
    # Stored data is stale if any dependency differs from the source value
    # recorded alongside it under '<dep>__src'.
    if (val is None or deps and
        any(getattr(node, dep, None) != fmt.get(node, dep + '__src')
            for dep in deps)):
      val = default
    self.code += val if val is not None else ''

  def check_is_elif(self, node):
    """Return the recorded 'is_elif' flag for `node`, or False on error."""
    try:
      return fmt.get(node, 'is_elif')
    except AttributeError:
      return False

  def check_is_continued_try(self, node):
    """Return the node's `is_continued` attribute, defaulting to False."""
    # TODO: Don't set extra attributes on nodes
    return getattr(node, 'is_continued', False)

  def check_is_continued_with(self, node):
    """Return the node's `is_continued` attribute, defaulting to False."""
    # TODO: Don't set extra attributes on nodes
    return getattr(node, 'is_continued', False)

158 

159 

def to_str(tree):
  """Convenient function to get the python source for an AST.

  Arguments:
    tree: (ast.AST) The tree to print. Stored formatting data on its nodes is
      reproduced; nodes without it get default formatting.

  Returns:
    The generated python source code as a string.

  Raises:
    PrintError: if the Printer fails on some node.
  """
  p = Printer()

  # Detect the most prevalent indentation style in the file and use it when
  # printing indented nodes which don't have formatting data. Nodes without a
  # (non-empty) 'indent_diff' are not counted.
  seen_indent_diffs = collections.Counter()
  for node in ast.walk(tree):
    indent_diff = fmt.get(node, 'indent_diff', '')
    if indent_diff:
      seen_indent_diffs[indent_diff] += 1
  if seen_indent_diffs:
    # On ties, most_common keeps the style that was encountered first.
    indent_diff, _ = seen_indent_diffs.most_common(1)[0]
    p.set_default_indent_diff(indent_diff)

  p.visit(tree)
  return p.code