Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/autograph/converters/directives.py: 25%

91 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Handles directives. 

16 

17This converter removes the directive functions from the code and moves the 

18information they specify into AST annotations. It is a specialized form of 

19static analysis, one that is specific to AutoGraph. 

20 

21Note that this requires that the actual directive functions are static - that 

22is, they do not change at runtime. So if you do something like this: 

23 

24 tf.autograph.set_loop_options = <new function> 

25 

26Then the directive may no longer be recognized. Furthermore, if the

27converted function is cached, such an action may be irreversible. 

28""" 

29 

30import inspect 

31 

32import gast 

33 

34from tensorflow.python.autograph.core import converter 

35from tensorflow.python.autograph.lang import directives 

36from tensorflow.python.autograph.pyct import anno 

37from tensorflow.python.util import tf_inspect 

38 

39 

# AST annotation key under which visit_Name and visit_Attribute store the
# Python value that a name (or a module attribute) statically resolves to.
STATIC_VALUE = 'static_value'

42 

43 

44class _LoopScope(object): 

45 

46 def __init__(self): 

47 self.ast_node = None 

48 self.statements_visited = 0 

49 

50 

def _map_args(call_node, function):
  """Maps AST call nodes to the actual function's arguments.

  Args:
    call_node: ast.Call
    function: Callable[..., Any], the actual function matching call_node
  Returns:
    Dict[Text, ast.AST], mapping each of the function's argument names to
    the respective AST node. Arguments left at the `directives.UNSPECIFIED`
    default are omitted.
  Raises:
    ValueError: if the default arguments are not correctly set
  """
  args = call_node.args
  kwds = {kwd.arg: kwd.value for kwd in call_node.keywords}
  # The AST nodes are bound in place of real values, so the result maps each
  # parameter name to the AST node passed for it (or to its default).
  call_args = tf_inspect.getcallargs(function, *args, **kwds)

  # Keyword arguments not specified in kwds will be mapped to their defaults,
  # which are Python values. Since we don't currently have a way to transform
  # those into AST references, we simply remove them. By convention, directives
  # use UNSPECIFIED as default value for optional arguments. No other
  # defaults should be present.
  unexpected_defaults = [
      k for k, v in call_args.items()
      if k not in kwds and v not in args and v is not directives.UNSPECIFIED
  ]
  if unexpected_defaults:
    # Materialize the (name, value) pairs: formatting a bare zip object with
    # %s in Python 3 only prints "<zip object at 0x...>", hiding the details.
    offending = [(k, call_args[k]) for k in unexpected_defaults]
    raise ValueError('Unexpected keyword argument values, %s, for function %s'
                     % (offending, function))
  return {k: v for k, v in call_args.items() if v is not directives.UNSPECIFIED}

84 

85 

class DirectivesTransformer(converter.Base):
  """Parses compiler directives and converts them into AST annotations.

  Expression statements that call `directives.set_element_type` or
  `directives.set_loop_options` are removed from the output; the arguments
  they carry are recorded as annotations instead.
  """

  def _process_symbol_directive(self, call_node, directive):
    """Records a symbol directive on each original definition of its target.

    Args:
      call_node: ast.Call, the directive call. Its first positional argument
        is the target symbol.
      directive: Callable, the directive function being called.

    Returns:
      call_node, unchanged.

    Raises:
      ValueError: if the call has no positional argument, or if _map_args
        rejects the call's arguments.
    """
    if not call_node.args:
      raise ValueError('"%s" requires a positional first argument'
                       ' as the target' % directive.__name__)
    target = call_node.args[0]
    # NOTE(review): assumes an earlier analysis pass annotated the target with
    # its original definitions; getanno is called without a default here.
    defs = anno.getanno(target, anno.Static.ORIG_DEFINITIONS)
    for def_ in defs:
      def_.directives[directive] = _map_args(call_node, directive)
    return call_node

  def _process_statement_directive(self, call_node, directive):
    """Records a statement directive on the enclosing loop node.

    The directive's mapped arguments are stored in the loop node's
    `anno.Basic.DIRECTIVES` annotation, keyed by the directive function.

    Args:
      call_node: ast.Call, the directive call.
      directive: Callable, the directive function being called.

    Returns:
      call_node, unchanged.

    Raises:
      ValueError: if the directive is not inside a loop, if it is not the
        first statement of the loop body, or if _map_args rejects the call's
        arguments.
    """
    loop_scope = self.state[_LoopScope]
    # Check nesting before position. Outside any loop the outermost frame's
    # statement counter says nothing about loops, so reporting "not the first
    # statement in the loop block" there would be misleading.
    # NOTE(review): level 1 appears to be the outermost frame; each loop
    # visited through _track_and_visit_loop pushes one more frame.
    if loop_scope.level < 2:
      raise ValueError(
          '"%s" must be used inside a statement' % directive.__name__)
    # visit_Expr has already counted the directive's own statement, so a
    # count above 1 means another statement preceded it in the loop body.
    if loop_scope.statements_visited > 1:
      raise ValueError(
          '"%s" must be the first statement in the loop block' % (
              directive.__name__))
    target = loop_scope.ast_node
    node_anno = anno.getanno(target, anno.Basic.DIRECTIVES, {})
    node_anno[directive] = _map_args(call_node, directive)
    anno.setanno(target, anno.Basic.DIRECTIVES, node_anno)
    return call_node

  def visit_Name(self, node):
    """Annotates loaded names that have no local definition with their value.

    If a name read in this code has no definitions in the analyzed code but
    exists in `self.ctx.info.namespace`, the namespace value is attached under
    the STATIC_VALUE annotation.
    """
    node = self.generic_visit(node)
    if isinstance(node.ctx, gast.Load):
      defs = anno.getanno(node, anno.Static.DEFINITIONS, ())
      is_defined = bool(defs)
      if not is_defined and node.id in self.ctx.info.namespace:
        anno.setanno(node, STATIC_VALUE, self.ctx.info.namespace[node.id])
    return node

  def visit_Attribute(self, node):
    """Propagates STATIC_VALUE through attribute access on modules.

    When the attribute's parent statically resolves to a module that has the
    attribute, the attribute's value is attached under STATIC_VALUE. This is
    how calls like `tf.autograph.set_loop_options` are recognized.
    """
    node = self.generic_visit(node)
    parent_val = anno.getanno(node.value, STATIC_VALUE, default=None)
    if parent_val is not None and inspect.ismodule(parent_val):
      if hasattr(parent_val, node.attr):
        anno.setanno(node, STATIC_VALUE, getattr(parent_val, node.attr))
    return node

  # Only Assign, AugAssign and Expr statements are counted toward
  # statements_visited.

  def visit_Assign(self, node):
    self.state[_LoopScope].statements_visited += 1
    return self.generic_visit(node)

  def visit_AugAssign(self, node):
    self.state[_LoopScope].statements_visited += 1
    return self.generic_visit(node)

  def visit_Expr(self, node):
    """Counts the statement and strips recognized directive calls.

    Returns None, which removes the statement from the output, when the
    expression is a call to a known directive. Otherwise returns the visited
    node.
    """
    self.state[_LoopScope].statements_visited += 1
    node = self.generic_visit(node)
    if isinstance(node.value, gast.Call):
      call_node = node.value
      static_val = anno.getanno(call_node.func, STATIC_VALUE, default=None)
      if static_val is not None:
        # Note: directive calls are not output in the generated code, hence
        # the removal from the code by returning None.
        if static_val is directives.set_element_type:
          self._process_symbol_directive(call_node, static_val)
          return None
        elif static_val is directives.set_loop_options:
          self._process_statement_directive(call_node, static_val)
          return None
    return node

  # TODO(mdan): This will be insufficient for other control flow.
  # That means that if we ever have a directive that affects things other than
  # loops, we'll need support for parallel scopes, or have multiple converters.
  def _track_and_visit_loop(self, node):
    """Visits a loop body inside a fresh _LoopScope frame."""
    self.state[_LoopScope].enter()
    try:
      self.state[_LoopScope].ast_node = node
      node = self.generic_visit(node)
      # Edge case: a loop with just one directive statement would become empty.
      if not node.body:
        node.body = [gast.Pass()]
    finally:
      # Keep enter()/exit() paired even if visiting the body raises.
      self.state[_LoopScope].exit()
    return node

  def visit_While(self, node):
    return self._track_and_visit_loop(node)

  def visit_For(self, node):
    return self._track_and_visit_loop(node)

174 

175 

def transform(node, ctx):
  """Runs the directives converter over an AST.

  Args:
    node: the AST to process.
    ctx: the converter context handed to DirectivesTransformer.

  Returns:
    The visited AST. Recognized directive calls have been removed and
    recorded as annotations.
  """
  transformer = DirectivesTransformer(ctx)
  return transformer.visit(node)