Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/autograph/converters/variables.py: 28%

36 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Overloads all variable read operations.""" 

16 

17import gast 

18 

19from tensorflow.python.autograph.core import converter 

20from tensorflow.python.autograph.pyct import anno 

21from tensorflow.python.autograph.pyct import templates 

22 

23 

class VariableAccessTransformer(converter.Base):
  """Wraps reads of plain symbols in an explicit `ld` operator.

  Routing variable reads through `ag__.ld` makes each read an explicit call.

  Example:

  A simple statement such as:

    a = b + c

  is rewritten into:

    a = ld(b) + ld(c)

  Deleting a plain name is replaced by rebinding that name to an
  `ag__.Undefined` marker carrying the symbol's name; composite targets
  (for example `del a[0]`) remain in a regular `del` statement.

  An augmented assignment whose target is a plain name, such as:

    a += b

  is preceded by an explicit read of its target, and the original statement
  is kept after it:

    a = ld(a)
    a += ...
  """

  def visit_Name(self, node):
    """Replaces an original `Load` of a name with `ag__.ld(<name>)`."""
    # Only the loads which existed in the original code are overloaded; names
    # lacking the ORIG_DEFINITIONS annotation are returned untouched.
    is_original = anno.hasanno(node, anno.Static.ORIG_DEFINITIONS)
    if is_original and isinstance(node.ctx, gast.Load):
      return templates.replace_as_expression('ag__.ld(var_)', var_=node)
    return node

  def visit_Delete(self, node):
    """Turns `del name` into an assignment of `ag__.Undefined(name)`.

    Returns the visited node unchanged when it deletes no plain names,
    otherwise a list of replacement statements.
    """
    node = self.generic_visit(node)

    # Composites like `del a[0]` are not rewritten; only bare names are.
    name_targets = [t for t in node.targets if isinstance(t, gast.Name)]
    if not name_targets:
      return node

    statements = []
    for name in name_targets:
      template = """
        var_ = ag__.Undefined(var_name)
      """
      symbol = gast.Constant(name.id, kind=None)
      statements += templates.replace(template, var_=name, var_name=symbol)

    # Whatever was not rewritten is still deleted, in one trailing `del`.
    leftover = [t for t in node.targets if t not in name_targets]
    if leftover:
      statements.append(gast.Delete(targets=leftover))

    return statements

  def visit_AugAssign(self, node):
    """Inserts an explicit read of a plain-name target before the statement.

    Other targets (anything that is not a `gast.Name`) are handled by the
    generic visitor.
    """
    if isinstance(node.target, gast.Name):
      # NOTE(review): in this branch the original statement is passed to the
      # template as-is and is not visited here - confirm that is intended.
      template = """
        var_ = ag__.ld(var_)
        original
      """
      return templates.replace(template, var_=node.target, original=node)
    return self.generic_visit(node)

94 

95 

def transform(node, ctx):
  """Runs `VariableAccessTransformer` over `node` and returns the result.

  Args:
    node: the AST node to convert.
    ctx: passed through to the `VariableAccessTransformer` constructor.

  Returns:
    Whatever the transformer's `visit` returns for `node`.
  """
  transformer = VariableAccessTransformer(ctx)
  return transformer.visit(node)