Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/autograph/converters/variables.py: 28%
36 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Overloads all variable read operations."""
17import gast
19from tensorflow.python.autograph.core import converter
20from tensorflow.python.autograph.pyct import anno
21from tensorflow.python.autograph.pyct import templates
class VariableAccessTransformer(converter.Base):
  """Rewrites basic symbol reads.

  This transformer rewrites variable reads with a "read" operator which allows
  tracking activity.

  Example:

  For a basic statement:

    a = b + c

  This is translated to:

    a = ld(b) + ld(c)

  Augmented assignment operations also introduce a `ld` operator:

    a += b

  The assignment target also receives an operator to properly represent the
  read:

    a = ld(a)
    a += ld(b)

  Deleting a plain name rebinds it to an undefined marker instead of removing
  it:

    del a

  becomes:

    a = ag__.Undefined('a')

  Composite delete targets (e.g. `del b[0]`) are kept as `del` statements, in
  their original left-to-right position relative to the name targets.

  Only names that carry the `anno.Static.ORIG_DEFINITIONS` annotation (i.e.
  names from the original code) are wrapped in `ld`.
  """

  def visit_Name(self, node):
    """Wraps loads of original-code names in `ag__.ld`.

    Args:
      node: gast.Name

    Returns:
      The unchanged node, or an `ag__.ld(...)` call expression wrapping it
      when the name is a load that existed in the original code.
    """
    # Only the loads which existed in the original code are overloaded.
    if not anno.hasanno(node, anno.Static.ORIG_DEFINITIONS):
      return node
    if isinstance(node.ctx, gast.Load):
      node = templates.replace_as_expression('ag__.ld(var_)', var_=node)
    return node

  def visit_Delete(self, node):
    """Replaces `del name` with an assignment to `ag__.Undefined`.

    Args:
      node: gast.Delete

    Returns:
      The (visited) node itself when it has no plain-name targets; otherwise a
      list of statements that performs the deletions in the original target
      order.
    """
    node = self.generic_visit(node)

    # Don't rewrite composites like `del a[0]`.
    if not any(isinstance(tgt, gast.Name) for tgt in node.targets):
      return node

    template = """
      var_ = ag__.Undefined(var_name)
    """
    results = []
    # Consecutive composite targets are collected into a single `del`, and
    # flushed whenever a name target is reached. Python deletes targets from
    # left to right, so e.g. `del x[0], x` must still delete `x[0]` before
    # `x` is rebound to an Undefined marker.
    pending_composites = []
    for tgt in node.targets:
      if isinstance(tgt, gast.Name):
        if pending_composites:
          results.append(gast.Delete(targets=pending_composites))
          pending_composites = []
        results.extend(templates.replace(
            template, var_=tgt, var_name=gast.Constant(tgt.id, kind=None)))
      else:
        pending_composites.append(tgt)
    if pending_composites:
      results.append(gast.Delete(targets=pending_composites))

    return results

  def visit_AugAssign(self, node):
    """Makes the implicit read of an augmented assignment explicit.

    For a plain-name target, `a += b` becomes `a = ag__.ld(a)` followed by
    `a += ld(b)`. Other targets (attributes, subscripts) are visited
    normally, which overloads the reads nested inside them.

    Args:
      node: gast.AugAssign

    Returns:
      A list of statements for plain-name targets, else the visited node.
    """
    if isinstance(node.target, gast.Name):
      # This branch never visits the template output below, so the
      # right-hand side has to be processed here first; otherwise reads such
      # as `b` in `a += b` would not be overloaded (unlike the `else` branch,
      # where generic_visit covers the value).
      node.value = self.visit(node.value)
      template = """
        var_ = ag__.ld(var_)
        original
      """
      node = templates.replace(template, var_=node.target, original=node)
    else:
      node = self.generic_visit(node)
    return node
def transform(node, ctx):
  """Rewrites variable accesses in `node` with VariableAccessTransformer.

  Args:
    node: the AST to transform.
    ctx: the context object forwarded to the transformer's constructor.

  Returns:
    Whatever `VariableAccessTransformer.visit` returns for `node`.
  """
  transformer = VariableAccessTransformer(ctx)
  return transformer.visit(node)