Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/autograph/converters/lists.py: 27%

101 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Converter for list operations. 

16 

17This includes converting Python lists to TensorArray/TensorList. 

18""" 

19 

20# TODO(mdan): Elaborate the logic here. 

21# TODO(mdan): Does it even make sense to try to use TensorArrays (TAs)?

22# The current rule (always convert to TensorArray) is naive and insufficient. 

23# In general, a better mechanism could look like: 

24# * convert to TensorList by default 

25# * leave as Python list if the user explicitly forbids it 

26# * convert to TensorArray only when complete write once behavior can be 

27# guaranteed (e.g. list comprehensions) 

28 

29import gast 

30 

31from tensorflow.python.autograph.core import converter 

32from tensorflow.python.autograph.lang import directives 

33from tensorflow.python.autograph.pyct import anno 

34from tensorflow.python.autograph.pyct import parser 

35from tensorflow.python.autograph.pyct import qual_names 

36from tensorflow.python.autograph.pyct import templates 

37from tensorflow.python.autograph.pyct.static_analysis import activity 

38from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno 

39 

40 

41class _Statement(object): 

42 

43 def __init__(self): 

44 self.pop_uses = None 

45 

46 

class ListTransformer(converter.Base):
  """Converts lists and related operations to their TF counterpart.

  List literals are wrapped in `ag__.new_list(...)`. Method calls named
  `append`, `pop` and `stack` are rewritten to `ag__.list_append`,
  `ag__.list_pop` and `ag__.list_stack` respectively. Calls are recognized
  only by attribute name and argument count (see the TODOs in `visit_Call`).

  Because `pop()` is an expression that also mutates its target, each use is
  replaced by a fresh variable and recorded in the `_Statement` state; the
  statement that actually performs the pop is emitted in front of the
  enclosing statement by `_postprocess_statement`.
  """

  def visit_List(self, node):
    """Wraps a list literal in a call to `ag__.new_list`."""
    visited = self.generic_visit(node)
    new_list_template = """
      ag__.new_list(elements)
    """
    return templates.replace_as_expression(new_list_template, elements=visited)

  def _replace_append_call(self, node):
    """Rewrites `target.append(element)` as an assignment to `target`.

    Args:
      node: a call node whose function is an attribute access and which has
        exactly one positional argument.

    Returns:
      The result of `templates.replace` for
      `target = ag__.list_append(target, element)`.
    """
    assert len(node.args) == 1
    assert isinstance(node.func, gast.Attribute)
    append_template = """
      target = ag__.list_append(target, element)
    """
    list_expr = node.func.value
    new_element = node.args[0]
    return templates.replace(
        append_template, target=list_expr, element=new_element)

  def _replace_pop_call(self, node):
    """Replaces a `pop()` call with a variable and records the pending pop.

    Expressions that use pop() are converted to a statement + expression.

    For example:

      print(target.pop())

    ... is converted to:

      target, target_pop = ag__.list_pop(target)
      print(target_pop)

    Here, only the variable name is generated and swapped in;
    `_generate_pop_operation` produces the statement later.

    Multiple uses of pop() are allowed:

      print(target.pop(), target.pop())
      print(target.pop().pop())

    Args:
      node: a call node whose function is an attribute access.

    Returns:
      An expression node referring to the newly generated variable.
    """
    assert isinstance(node.func, gast.Attribute)
    args_scope = anno.getanno(node, NodeAnno.ARGS_SCOPE)
    list_expr = node.func.value

    # Attempt to use a related name if one exists. Otherwise use something
    # generic.
    base_name = (
        anno.getanno(list_expr, anno.Basic.QN).ssf()
        if anno.hasanno(list_expr, anno.Basic.QN) else 'list_')
    # The scope's referenced symbols are handed to the namer, presumably so
    # the new symbol does not collide with them.
    pop_var_name = self.ctx.namer.new_symbol(base_name, args_scope.referenced)

    pending = self.state[_Statement]
    if pending.pop_uses is None:
      pending.pop_uses = []
    pending.pop_uses.append((node, pop_var_name))

    return templates.replace_as_expression('var_name', var_name=pop_var_name)

  def _replace_stack_call(self, node):
    """Rewrites `<func>.stack(target)` as a call to `ag__.list_stack`.

    The element dtype is taken from a `directives.set_element_type` directive
    attached to the argument's definition, falling back to a `None`
    expression. The original callee is passed along as `original_call`.

    Args:
      node: a call node with exactly one positional argument.

    Returns:
      The replacement expression node.
    """
    assert len(node.args) == 1
    stacked = node.args[0]
    element_dtype = self.get_definition_directive(
        stacked,
        directives.set_element_type,
        'dtype',
        default=templates.replace_as_expression('None'))
    stack_template = """
      ag__.list_stack(
          target,
          opts=ag__.ListStackOpts(
              element_dtype=dtype,
              original_call=orig_call))
    """
    return templates.replace_as_expression(
        stack_template,
        dtype=element_dtype,
        target=stacked,
        orig_call=node.func)

  def visit_Call(self, node):
    """Dispatches `append`/`pop`/`stack` method calls to their rewriters.

    Any other call is returned as produced by `generic_visit`.
    """
    node = self.generic_visit(node)

    # TODO(mdan): This is insufficient if target is a function argument.
    # In the case of function arguments, we need to add the list to the
    # function's return value, because it is being modified.
    # TODO(mdan): Checking just the name is brittle, can it be improved?
    callee = node.func
    if not isinstance(callee, gast.Attribute):
      return node

    method_name = callee.attr
    arg_count = len(node.args)

    if method_name == 'append' and arg_count == 1:
      return self._replace_append_call(node)

    if method_name == 'pop' and arg_count <= 1:
      return self._replace_pop_call(node)

    if method_name == 'stack' and arg_count == 1:
      # This avoids false positives with keyword args.
      # TODO(mdan): handle kwargs properly.
      keywords = node.keywords
      if not keywords or keywords[0].arg == 'strict':
        return self._replace_stack_call(node)

    return node

  def _generate_pop_operation(self, original_call_node, pop_var_name):
    """Builds the statement that performs a previously recorded `pop()`.

    Args:
      original_call_node: the original `target.pop(...)` call node.
      pop_var_name: name of the variable that replaced the call expression.

    Returns:
      The result of `templates.replace` for
      `target, pop_var_name = ag__.list_pop(target, element, opts=...)`.
    """
    callee = original_call_node.func
    assert isinstance(callee, gast.Attribute)

    # With no explicit argument, a `None` expression is passed as the element.
    call_args = original_call_node.args
    pop_element = (
        call_args[0] if call_args else parser.parse_expression('None'))

    # The call will be something like "target.pop()", and the dtype is hooked to
    # target, hence the func.value.
    # TODO(mdan): For lists of lists, this won't work.
    # The reason why it won't work is because it's unclear how to annotate
    # the list as a "list of lists with a certain element type" when using
    # operations like `l.pop().pop()`.
    list_expr = callee.value

    def element_type_directive(arg_name):
      # Looks up one argument of the set_element_type directive, defaulting
      # to a `None` expression when it is absent.
      return self.get_definition_directive(
          list_expr,
          directives.set_element_type,
          arg_name,
          default=templates.replace_as_expression('None'))

    element_dtype = element_type_directive('dtype')
    element_shape = element_type_directive('shape')

    pop_template = """
      target, pop_var_name = ag__.list_pop(
          target, element,
          opts=ag__.ListPopOpts(element_dtype=dtype, element_shape=shape))
    """
    return templates.replace(
        pop_template,
        target=list_expr,
        pop_var_name=pop_var_name,
        element=pop_element,
        dtype=element_dtype,
        shape=element_shape)

  def _postprocess_statement(self, node):
    """Inserts any separate pop() calls that node may use.

    When pops were recorded for the statement, the generated pop statements
    are placed before `node`, in the order they were recorded. The
    `_Statement` state is exited in every case.

    Returns:
      A `(node_or_list, None)` tuple. NOTE(review): the tuple shape is
      presumably what `visit_block` expects from `after_visit` -- confirm
      against converter.Base.
    """
    recorded_pops = self.state[_Statement].pop_uses
    if recorded_pops:
      expanded = []
      for call_node, temp_name in recorded_pops:
        expanded.extend(self._generate_pop_operation(call_node, temp_name))
      expanded.append(node)
      node = expanded
    self.state[_Statement].exit()
    return node, None

  def _visit_and_process_block(self, block):
    """Visits a statement block with the `_Statement` state hooks installed.

    The state's `enter` is passed as `before_visit` and
    `_postprocess_statement` as `after_visit`.
    """
    statement_state = self.state[_Statement]
    return self.visit_block(
        block,
        before_visit=statement_state.enter,
        after_visit=self._postprocess_statement)

  def visit_FunctionDef(self, node):
    """Visits arguments and decorators, then the body with pop handling."""
    node.args = self.generic_visit(node.args)
    node.decorator_list = self.visit_block(node.decorator_list)
    node.body = self._visit_and_process_block(node.body)
    return node

  def visit_For(self, node):
    """Visits the loop target, then body and orelse with pop handling.

    NOTE(review): `node.iter` is not visited here, so list operations inside
    the iterable expression are left untransformed -- confirm this is
    intentional.
    """
    node.target = self.visit(node.target)
    node.body = self._visit_and_process_block(node.body)
    node.orelse = self._visit_and_process_block(node.orelse)
    return node

  def visit_While(self, node):
    """Visits the loop test, then body and orelse with pop handling."""
    node.test = self.visit(node.test)
    node.body = self._visit_and_process_block(node.body)
    node.orelse = self._visit_and_process_block(node.orelse)
    return node

  def visit_If(self, node):
    """Visits the condition, then body and orelse with pop handling."""
    node.test = self.visit(node.test)
    node.body = self._visit_and_process_block(node.body)
    node.orelse = self._visit_and_process_block(node.orelse)
    return node

  def visit_With(self, node):
    """Visits the with-items, then the body with pop handling."""
    node.items = self.visit_block(node.items)
    node.body = self._visit_and_process_block(node.body)
    return node

233 

234 

def transform(node, ctx):
  """Applies the list conversion to `node`.

  Qualified names and activity information are resolved on the node first,
  then a `ListTransformer` built from `ctx` visits the result.

  Args:
    node: the AST node to convert.
    ctx: the context passed to `activity.resolve` and `ListTransformer`.

  Returns:
    Whatever `ListTransformer.visit` returns for the annotated node.
  """
  with_qual_names = qual_names.resolve(node)
  with_activity = activity.resolve(with_qual_names, ctx, None)

  transformer = ListTransformer(ctx)
  return transformer.visit(with_activity)