Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/autograph/converters/call_trees.py: 28%

95 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Handles function calls, by generating compiled function names and calls. 

16 

17Note: this transformer does not rename the top level object being converted; 

18that is the caller's responsibility. 

19 

20Requires function_scopes. 

21""" 

22 

23import gast 

24 

25from tensorflow.python.autograph.core import converter 

26from tensorflow.python.autograph.pyct import anno 

27from tensorflow.python.autograph.pyct import parser 

28from tensorflow.python.autograph.pyct import qual_names 

29from tensorflow.python.autograph.pyct import templates 

30from tensorflow.python.autograph.utils import ag_logging 

31 

32 

33# TODO(mdan): Rename to FunctionCallsTransformer. 

34 

35 

36class _Function(object): 

37 

38 no_root = True 

39 

40 def __init__(self): 

41 self.context_name = None 

42 

43 

44set_trace_warned = False 

45 

46 

47class _ArgTemplateBuilder(object): 

48 """Constructs a tuple representing the positional arguments in a call. 

49 

50 Example (yes, it's legal Python 3): 

51 

52 f(*args1, b, *args2, c, d) -> args1 + (b,) + args2 + (c, d) 

53 """ 

54 

55 def __init__(self): 

56 self._arg_accumulator = [] 

57 self._argspec = [] 

58 self._finalized = False 

59 

60 def _consume_args(self): 

61 if self._arg_accumulator: 

62 self._argspec.append( 

63 gast.Tuple(elts=self._arg_accumulator, ctx=gast.Load())) 

64 self._arg_accumulator = [] 

65 

66 def add_arg(self, a): 

67 self._arg_accumulator.append(a) 

68 

69 def add_stararg(self, a): 

70 self._consume_args() 

71 self._argspec.append( 

72 gast.Call( 

73 gast.Name( 

74 'tuple', ctx=gast.Load(), annotation=None, type_comment=None), 

75 args=[a], 

76 keywords=())) 

77 

78 def finalize(self): 

79 self._consume_args() 

80 self._finalized = True 

81 

82 def to_ast(self): 

83 assert self._finalized 

84 if self._argspec: 

85 result = self._argspec[0] 

86 for i in range(1, len(self._argspec)): 

87 result = gast.BinOp(result, gast.Add(), self._argspec[i]) 

88 return result 

89 return gast.Tuple([], gast.Load()) 

90 

91 

class CallTreeTransformer(converter.Base):
  """Rewrites function calls into `ag__.converted_call(...)` invocations.

  Some calls are left in place: calls into `ag__`, calls on the enclosing
  function's context manager, debugger hooks, and `print` (unless the
  BUILTIN_FUNCTIONS feature is enabled). Their arguments are still visited.
  """

  def visit_Lambda(self, node):
    if anno.hasanno(node, 'function_context_name'):
      with self.state[_Function] as fn_scope:
        fn_scope.context_name = anno.getanno(node, 'function_context_name')
        return self.generic_visit(node)
    # Lambdas created during the conversion process have no context manager,
    # so no new _Function scope is entered for them.
    return self.generic_visit(node)

  def visit_FunctionDef(self, node):
    # Decorators and argument defaults belong to the outer scope, so they are
    # visited before this function's scope is entered.
    node.decorator_list = self.visit_block(node.decorator_list)
    node.args.defaults = self.visit_block(node.args.defaults)
    kw_defaults = node.args.kw_defaults
    for idx, default in enumerate(kw_defaults):
      # A None entry is a keyword-only argument that has no default.
      if default is None:
        continue
      kw_defaults[idx] = self.visit(default)

    with self.state[_Function] as fn_scope:
      # Note: if the conversion process ever creates helper functions, this
      # assumption will no longer hold.
      assert anno.hasanno(node, 'function_context_name'), (
          'The function_scopes converter always creates a scope for functions.')
      fn_scope.context_name = anno.getanno(node, 'function_context_name')
      node.body = self.visit_block(node.body)
      if node.returns:
        node.returns = self.visit(node.returns)
    return node

  def visit_With(self, node):
    # Only the body is visited; the context manager calls in `node.items`
    # are deliberately left unconverted.
    node.body = self.visit_block(node.body)
    return node

  def _args_to_tuple(self, node):
    """Packs all positional and *starred arguments of `node` into one tuple.

    TODO(mdan): Consider emitting a single tuple display such as
    `(a, b, *args)` instead of the concatenation `(a, b) + args`.
    """
    builder = _ArgTemplateBuilder()
    for arg in node.args:
      if isinstance(arg, gast.Starred):
        builder.add_stararg(arg.value)
      else:
        builder.add_arg(arg)
    builder.finalize()
    return builder.to_ast()

  def _kwargs_to_dict(self, node):
    """Packs all keyword and **kwargs arguments of `node` into one dict.

    Returns a `dict(...)` call expression, or a `None` expression when the
    call has no keyword arguments.
    """
    if not node.keywords:
      return parser.parse_expression('None')
    dict_name = gast.Name(
        'dict', ctx=gast.Load(), annotation=None, type_comment=None)
    return gast.Call(dict_name, args=(), keywords=node.keywords)

  def _is_exempt_call(self, full_name, function_context_name):
    """Returns True when a call to `full_name` must be left unconverted."""
    global set_trace_warned

    # Calls to the internal 'ag__' module are never converted.
    if full_name.startswith('ag__.'):
      return True

    # Calls to the function context manager (inserted by function_scopes).
    if full_name.startswith(function_context_name + '.'):
      return True

    # Debugger hooks are sensitive to the frame they are called from, so they
    # bypass the normal conversion mechanisms entirely.
    # TODO(mdan): Generalize this to a "static allowlist" config.
    if full_name in ('pdb.set_trace', 'ipdb.set_trace', 'breakpoint'):
      if not set_trace_warned:
        # TODO(mdan): Update and shorten once available on tensorflow.org.
        ag_logging.warning(
            'Detected `pdb.set_trace()` in user code. The code'
            ' generated by AutoGraph is not optimized for step-by-step'
            ' debugging. See https://github.com/tensorflow/tensorflow/'
            'blob/master/tensorflow/python/autograph/g3doc/reference/'
            'debugging.md.')
        set_trace_warned = True
      return True

    if full_name == 'print':
      return not self.ctx.user.options.uses(
          converter.Feature.BUILTIN_FUNCTIONS)

    return False

  def visit_Call(self, node):
    full_name = str(anno.getanno(node.func, anno.Basic.QN, default=''))
    function_context_name = self.state[_Function].context_name
    # Arguments and nested calls are visited even when this call is exempt.
    node = self.generic_visit(node)

    # TODO(mdan): Refactor converted_call as a 'Call' operator.
    if self._is_exempt_call(full_name, function_context_name):
      return node

    template = """
      ag__.converted_call(func, args, kwargs, function_ctx)
    """
    return templates.replace_as_expression(
        template,
        func=node.func,
        args=self._args_to_tuple(node),
        kwargs=self._kwargs_to_dict(node),
        function_ctx=function_context_name)

205 

206 

def transform(node, ctx):
  """Transform function calls to the compiled counterparts.

  Each eligible call is rewritten into an `ag__.converted_call(...)` call by
  CallTreeTransformer.

  Args:
    node: AST
    ctx: EntityContext

  Returns:
    The transformed AST node. Only the node is returned; no set of new names
    is produced.
  """
  # Qualified names must be resolved first: CallTreeTransformer.visit_Call
  # reads them through the anno.Basic.QN annotation.
  node = qual_names.resolve(node)
  return CallTreeTransformer(ctx).visit(node)