Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/autograph/converters/lists.py: 27%
101 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Converter for list operations.
17This includes converting Python lists to TensorArray/TensorList.
18"""
20# TODO(mdan): Elaborate the logic here.
21# TODO(mdan): Does it even make sense to try to use TAs (TensorArrays)?
22# The current rule (always convert to TensorArray) is naive and insufficient.
23# In general, a better mechanism could look like:
24# * convert to TensorList by default
25# * leave as Python list if the user explicitly forbids it
26# * convert to TensorArray only when complete write-once behavior can be
27# guaranteed (e.g. list comprehensions)
29import gast
31from tensorflow.python.autograph.core import converter
32from tensorflow.python.autograph.lang import directives
33from tensorflow.python.autograph.pyct import anno
34from tensorflow.python.autograph.pyct import parser
35from tensorflow.python.autograph.pyct import qual_names
36from tensorflow.python.autograph.pyct import templates
37from tensorflow.python.autograph.pyct.static_analysis import activity
38from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
41class _Statement(object):
43 def __init__(self):
44 self.pop_uses = None
class ListTransformer(converter.Base):
  """Converts lists and related operations to their TF counterpart.

  The rewrites performed are:
    * list literals are wrapped as `ag__.new_list(...)`;
    * `target.append(x)` becomes `target = ag__.list_append(target, x)`;
    * `target.pop()` / `target.pop(i)` is replaced by a fresh variable, and a
      `target, var = ag__.list_pop(...)` statement is inserted before the
      enclosing statement;
    * `<expr>.stack(x)` becomes `ag__.list_stack(x, opts=...)`.
  Calls carrying keyword arguments that these operations do not accept are
  left untouched, so that keywords are never silently dropped.
  """

  def visit_List(self, node):
    """Wraps a list literal in a call to `ag__.new_list`."""
    node = self.generic_visit(node)
    template = """
      ag__.new_list(elements)
    """
    return templates.replace_as_expression(template, elements=node)

  def _replace_append_call(self, node):
    """Rewrites `target.append(element)` into a reassignment of `target`.

    Args:
      node: a gast.Call whose func is an Attribute and which has exactly one
        positional argument.

    Returns:
      The statement(s) built by `templates.replace` for
      `target = ag__.list_append(target, element)`.
    """
    assert len(node.args) == 1
    assert isinstance(node.func, gast.Attribute)
    template = """
      target = ag__.list_append(target, element)
    """
    return templates.replace(
        template,
        target=node.func.value,
        element=node.args[0])

  def _replace_pop_call(self, node):
    """Replaces a `target.pop(...)` call with a fresh variable name.

    The call itself is recorded in the current `_Statement` state so that
    `_postprocess_statement` can emit the matching `ag__.list_pop` statement.

    Args:
      node: a gast.Call whose func is an Attribute.

    Returns:
      An expression node referencing the newly generated variable.
    """
    # Expressions that use pop() are converted to a statement + expression.
    #
    # For example:
    #
    #   print(target.pop())
    #
    # ... is converted to:
    #
    #   target, target_pop = ag__.list_pop(target)
    #   print(target_pop)
    #
    # Here, we just generate the variable name and swap it in,
    # and _generate_pop_operation will handle the rest.
    #
    # Multiple uses of pop() are allowed:
    #
    #   print(target.pop(), target.pop())
    #   print(target.pop().pop())
    #
    assert isinstance(node.func, gast.Attribute)
    scope = anno.getanno(node, NodeAnno.ARGS_SCOPE)
    target_node = node.func.value

    # Attempt to use a related name if one exists. Otherwise use something
    # generic.
    if anno.hasanno(target_node, anno.Basic.QN):
      target_name = anno.getanno(target_node, anno.Basic.QN).ssf()
    else:
      target_name = 'list_'
    # The names referenced by the call's arguments are handed to the namer,
    # presumably so the new symbol cannot collide with them.
    pop_var_name = self.ctx.namer.new_symbol(target_name, scope.referenced)

    stmt = self.state[_Statement]
    if stmt.pop_uses is None:
      stmt.pop_uses = []
    stmt.pop_uses.append((node, pop_var_name))

    return templates.replace_as_expression('var_name', var_name=pop_var_name)

  def _replace_stack_call(self, node):
    """Rewrites `<expr>.stack(x)` as `ag__.list_stack(x, opts=...)`.

    The element dtype is taken from a `set_element_type` directive defined for
    `x`, defaulting to None. The original callee (`node.func`) is passed along
    as `original_call`. Keyword arguments of the call are not forwarded.

    Args:
      node: a gast.Call with exactly one positional argument.

    Returns:
      The replacement expression node.
    """
    assert len(node.args) == 1
    dtype = self.get_definition_directive(
        node.args[0],
        directives.set_element_type,
        'dtype',
        default=templates.replace_as_expression('None'))
    template = """
      ag__.list_stack(
          target,
          opts=ag__.ListStackOpts(
              element_dtype=dtype,
              original_call=orig_call))
    """
    return templates.replace_as_expression(
        template,
        dtype=dtype,
        target=node.args[0],
        orig_call=node.func)

  def visit_Call(self, node):
    """Dispatches `append`, `pop` and `stack` attribute calls to rewriters."""
    node = self.generic_visit(node)

    # TODO(mdan): This is insufficient if target is a function argument.
    # In the case of function arguments, we need to add the list to the
    # function's return value, because it is being modified.
    # TODO(mdan): Checking just the name is brittle, can it be improved?
    if isinstance(node.func, gast.Attribute):
      func_name = node.func.attr
      # list.append and list.pop accept no keyword arguments, so a call that
      # passes any cannot be a list operation; converting it would silently
      # discard those keywords.
      if func_name == 'append' and len(node.args) == 1 and not node.keywords:
        node = self._replace_append_call(node)
      elif func_name == 'pop' and len(node.args) <= 1 and not node.keywords:
        node = self._replace_pop_call(node)
      elif (func_name == 'stack' and len(node.args) == 1 and
            all(kw.arg == 'strict' for kw in node.keywords)):
        # This avoids false positives with keyword args: only `strict` is
        # tolerated. Every keyword is checked (not just the first), so e.g.
        # `stack(x, strict=True, axis=1)` or `stack(x, strict=True, **kw)`
        # (whose keyword has arg=None) is left unconverted instead of losing
        # its extra arguments.
        # TODO(mdan): handle kwargs properly.
        node = self._replace_stack_call(node)

    return node

  def _generate_pop_operation(self, original_call_node, pop_var_name):
    """Builds the `ag__.list_pop` statement for one recorded pop() call.

    Args:
      original_call_node: the `target.pop(...)` gast.Call recorded by
        `_replace_pop_call`.
      pop_var_name: the variable name that replaced the call's result.

    Returns:
      The statement(s) built by `templates.replace` for
      `target, pop_var_name = ag__.list_pop(target, element, opts=...)`, where
      `element` is the pop() argument, or None when pop() had no argument.
    """
    assert isinstance(original_call_node.func, gast.Attribute)

    if original_call_node.args:
      pop_element = original_call_node.args[0]
    else:
      pop_element = parser.parse_expression('None')

    # The call will be something like "target.pop()", and the dtype is hooked to
    # target, hence the func.value.
    # TODO(mdan): For lists of lists, this won't work.
    # The reason why it won't work is because it's unclear how to annotate
    # the list as a "list of lists with a certain element type" when using
    # operations like `l.pop().pop()`.
    dtype = self.get_definition_directive(
        original_call_node.func.value,
        directives.set_element_type,
        'dtype',
        default=templates.replace_as_expression('None'))
    shape = self.get_definition_directive(
        original_call_node.func.value,
        directives.set_element_type,
        'shape',
        default=templates.replace_as_expression('None'))

    template = """
      target, pop_var_name = ag__.list_pop(
          target, element,
          opts=ag__.ListPopOpts(element_dtype=dtype, element_shape=shape))
    """
    return templates.replace(
        template,
        target=original_call_node.func.value,
        pop_var_name=pop_var_name,
        element=pop_element,
        dtype=dtype,
        shape=shape)

  def _postprocess_statement(self, node):
    """Inserts any separate pop() calls that node may use.

    Used as the after_visit hook of `_visit_and_process_block`. The
    `list_pop` statements for all pop() calls recorded while visiting `node`
    are placed before it, in recording order. The statement's `_Statement`
    state is then exited.

    Returns:
      A `(node, None)` tuple, where `node` is either the original statement or
      a list of the generated pop statements followed by it.
    """
    pop_uses = self.state[_Statement].pop_uses
    if pop_uses:
      replacements = []
      for original_call_node, pop_var_name in pop_uses:
        replacements.extend(
            self._generate_pop_operation(original_call_node, pop_var_name))
      replacements.append(node)
      node = replacements
    self.state[_Statement].exit()
    return node, None

  def _visit_and_process_block(self, block):
    """Visits a list of statements, hoisting pop() calls out of each one.

    `visit_block` receives `_Statement.enter` as its before_visit hook and
    `_postprocess_statement` as its after_visit hook.
    """
    return self.visit_block(
        block,
        before_visit=self.state[_Statement].enter,
        after_visit=self._postprocess_statement)

  def visit_FunctionDef(self, node):
    """Visits arguments, decorators and the body of a function."""
    node.args = self.generic_visit(node.args)
    node.decorator_list = self.visit_block(node.decorator_list)
    node.body = self._visit_and_process_block(node.body)
    return node

  def visit_For(self, node):
    """Visits the target, body and else-block of a for loop."""
    # NOTE(review): node.iter is never visited, so list literals or pop()
    # calls inside the iterable expression are left unconverted. Confirm this
    # is intentional.
    node.target = self.visit(node.target)
    node.body = self._visit_and_process_block(node.body)
    node.orelse = self._visit_and_process_block(node.orelse)
    return node

  def visit_While(self, node):
    """Visits the test, body and else-block of a while loop."""
    # NOTE(review): a pop() in the loop test appears to be recorded in the
    # enclosing statement's state and hoisted ahead of the loop. That would
    # make it run once rather than once per iteration. Confirm.
    node.test = self.visit(node.test)
    node.body = self._visit_and_process_block(node.body)
    node.orelse = self._visit_and_process_block(node.orelse)
    return node

  def visit_If(self, node):
    """Visits the test, body and else-block of an if statement."""
    node.test = self.visit(node.test)
    node.body = self._visit_and_process_block(node.body)
    node.orelse = self._visit_and_process_block(node.orelse)
    return node

  def visit_With(self, node):
    """Visits the context items and the body of a with statement."""
    node.items = self.visit_block(node.items)
    node.body = self._visit_and_process_block(node.body)
    return node
def transform(node, ctx):
  """Converts list operations in `node` to their TF counterparts.

  Runs qualified-name resolution and activity analysis first, then applies
  ListTransformer to the annotated tree.

  Args:
    node: the AST to transform.
    ctx: the converter context, forwarded to activity analysis and to
      ListTransformer.

  Returns:
    The result of `ListTransformer(ctx).visit` on the annotated tree.
  """
  annotated = activity.resolve(qual_names.resolve(node), ctx, None)
  return ListTransformer(ctx).visit(annotated)