Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/autograph/core/converter.py: 47%

95 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Converter construction support. 

16 

17This module contains a base class for all converters, as well as supporting 

18structures. These structures are referred to as contexts. 

19 

20The class hierarchy is as follows: 

21 

22 <your converter> 

23 [extends] converter.Base 

24 [extends] transformer.Base 

25 [extends] gast.NodeTransformer 

26 [uses] transformer.SourceInfo 

27 [uses] converter.EntityContext 

28 [uses] converter.ProgramContext 

29 [uses] transformer.SourceInfo 

30 

31converter.Base is a specialization of transformer.Base for AutoGraph. It's a 

32very lightweight subclass that adds a `ctx` attribute holding the corresponding 

33EntityContext object (see below). Note that converters are not reusable, and 

34`visit` will raise an error if called more than once. 

35 

36converter.EntityContext contains mutable state associated with an entity that 

37the converter processes. 

38 

39converter.ProgramContext contains mutable state across related entities. For 

40example, when converting several functions that call one another, the 

41ProgramContext should be shared across these entities. 

42 

43Below is the overall flow at conversion: 

44 

45 program_ctx = ProgramContext(<entities to convert>, <global settings>, ...) 

46 while <program_ctx has more entities to convert>: 

47 entity, source_info = <get next entity from program_ctx> 

48 entity_ctx = EntityContext(program_ctx, source_info) 

49 for <each ConverterClass>: 

50 converter = ConverterClass(entity_ctx) 

51 

52 # May update entity_ctx and program_ctx 

53 entity = converter.visit(entity) 

54 

55 <add entity's dependencies to program_ctx> 

56 

57Note that pyct contains a small number of transformers used for static analysis. 

58These extend transformer.Base, rather than converter.Base, to avoid a 

59dependency on AutoGraph. 

60""" 

61 

62import enum 

63 

64from tensorflow.python.autograph.pyct import anno 

65from tensorflow.python.autograph.pyct import ast_util 

66from tensorflow.python.autograph.pyct import parser 

67from tensorflow.python.autograph.pyct import templates 

68from tensorflow.python.autograph.pyct import transformer 

69from tensorflow.python.util.tf_export import tf_export 

70 

71# TODO(mdan): These contexts can be refactored into first class objects. 

72# For example, we could define Program and Entity abstractions that hold on 

73# to the actual entity and have conversion methods. 

74 

75# TODO(mdan): Add a test specific to this converter. 

76 

77 

@tf_export('autograph.experimental.Feature')
class Feature(enum.Enum):
  """This enumeration represents optional conversion options.

  These conversion options are experimental. They are subject to change without
  notice and offer no guarantees.

  _Example Usage_

  ```python
  optionals = tf.autograph.experimental.Feature.EQUALITY_OPERATORS
  @tf.function(experimental_autograph_options=optionals)
  def f(i):
    if i == 0:  # EQUALITY_OPERATORS allows the use of == here.
      tf.print('i is zero')
  ```

  Attributes:
    ALL: Enable all features.
    AUTO_CONTROL_DEPS: Insert control dependencies in the generated code.
    ASSERT_STATEMENTS: Convert Tensor-dependent assert statements to tf.Assert.
    BUILTIN_FUNCTIONS: Convert builtin functions applied to Tensors to
      their TF counterparts.
    EQUALITY_OPERATORS: Whether to convert the equality operator ('==') to
      tf.math.equal.
    LISTS: Convert list idioms, like initializers, slices, append, etc.
    NAME_SCOPES: Insert name scopes that name ops according to context, like the
      function they were defined in.
  """

  ALL = 'ALL'

  AUTO_CONTROL_DEPS = 'AUTO_CONTROL_DEPS'
  ASSERT_STATEMENTS = 'ASSERT_STATEMENTS'
  BUILTIN_FUNCTIONS = 'BUILTIN_FUNCTIONS'
  EQUALITY_OPERATORS = 'EQUALITY_OPERATORS'
  LISTS = 'LISTS'
  NAME_SCOPES = 'NAME_SCOPES'

  @classmethod
  def all(cls):
    """Returns a tuple that enables all options."""
    return tuple(cls.__members__.values())

  @classmethod
  def all_but(cls, exclude):
    """Returns a tuple that enables all but the excluded options.

    Args:
      exclude: a single Feature, or a list, tuple, set or frozenset of
        Features, to leave out of the result.

    Returns:
      A tuple of Feature members in declaration order, omitting every member
      in `exclude` as well as `Feature.ALL`.
    """
    # frozenset is accepted too (e.g. ConversionOptions.optional_features);
    # otherwise it would be wrapped as one opaque element and exclude nothing.
    if not isinstance(exclude, (list, tuple, set, frozenset)):
      exclude = (exclude,)
    excluded = set(exclude)
    # Iterate in declaration order so the result is deterministic, unlike
    # iterating a set whose order depends on the string hash seed.
    return tuple(
        feature for feature in cls.all()
        if feature is not cls.ALL and feature not in excluded)

128 

129 

# Placeholder so the module-level name exists before ConversionOptions is
# defined; it is rebound to a ConversionOptions instance further down.
# ConversionOptions.to_ast reads this global at call time to decide whether to
# emit the short `ag__.STD` reference.
STANDARD_OPTIONS = None  # Forward definition.

131 

132 

class ConversionOptions(object):
  """Container for global conversion flags.

  Instances are hashed and compared by value (see `as_tuple`), so they should
  be treated as immutable: do not modify their attributes after construction.

  Attributes:
    recursive: bool, whether to recursively convert any user functions or
      classes that the converted function may use.
    user_requested: bool, whether the conversion was explicitly requested by
      the user, as opposed to being performed as a result of other logic. This
      value always auto-resets to False in child conversions.
    internal_convert_user_code: bool, internal flag. For child conversions
      (see `call_options`) it is set to the parent's `recursive` value.
    optional_features: Union[Feature, Set[Feature]], controls the use of
      optional features in the conversion process. See Feature for available
      options. Stored as a frozenset; None is treated as "no features".
  """

  def __init__(self,
               recursive=False,
               user_requested=False,
               internal_convert_user_code=True,
               optional_features=Feature.ALL):
    self.recursive = recursive
    self.user_requested = user_requested
    # TODO(mdan): Rename to conversion_recursion_depth?
    self.internal_convert_user_code = internal_convert_user_code

    # Normalize to a frozenset so the value is hashable (see __hash__).
    if optional_features is None:
      optional_features = ()
    elif isinstance(optional_features, Feature):
      optional_features = (optional_features,)
    optional_features = frozenset(optional_features)
    self.optional_features = optional_features

  def as_tuple(self):
    """Returns all option values as a tuple; the basis of __hash__/__eq__."""
    return (self.recursive, self.user_requested,
            self.internal_convert_user_code, self.optional_features)

  def __hash__(self):
    return hash(self.as_tuple())

  def __eq__(self, other):
    # Returning NotImplemented (rather than asserting) lets Python fall back to
    # its default comparison, so e.g. `options == None` is simply False.
    if not isinstance(other, ConversionOptions):
      return NotImplemented
    return self.as_tuple() == other.as_tuple()

  def __str__(self):
    return ('ConversionOptions[recursive={}, user_requested={}, '
            'internal_convert_user_code={}, optional_features={}]').format(
                self.recursive, self.user_requested,
                self.internal_convert_user_code,
                sorted(str(f) for f in self.optional_features))

  def uses(self, feature):
    """Returns True if `feature` is enabled, either directly or via ALL."""
    return (Feature.ALL in self.optional_features or
            feature in self.optional_features)

  def call_options(self):
    """Returns the corresponding options to be used for recursive conversion."""
    return ConversionOptions(
        recursive=self.recursive,
        user_requested=False,
        internal_convert_user_code=self.recursive,
        optional_features=self.optional_features)

  def to_ast(self):
    """Returns a representation of this object as an AST node.

    The AST node encodes a constructor that would create an object with the
    same contents. Options equal to STANDARD_OPTIONS are encoded as the short
    reference `ag__.STD`.

    Returns:
      ast.Node
    """
    if self == STANDARD_OPTIONS:
      return parser.parse_expression('ag__.STD')

    template = """
      ag__.ConversionOptions(
          recursive=recursive_val,
          user_requested=user_requested_val,
          optional_features=optional_features_val,
          internal_convert_user_code=internal_convert_user_code_val)
    """

    def list_of_features(values):
      # Sort so the generated code does not depend on frozenset iteration
      # order, which varies with the string hash seed. The trailing comma
      # after every element keeps a single feature a tuple, not a
      # parenthesized expression.
      elements = ''.join(
          'ag__.{}, '.format(str(v)) for v in sorted(values, key=str))
      return parser.parse_expression('({})'.format(elements))

    expr_ast = templates.replace(
        template,
        recursive_val=parser.parse_expression(str(self.recursive)),
        user_requested_val=parser.parse_expression(str(self.user_requested)),
        internal_convert_user_code_val=parser.parse_expression(
            str(self.internal_convert_user_code)),
        optional_features_val=list_of_features(self.optional_features))
    return expr_ast[0].value

222 

223 

# Default options: recursive conversion, not user-requested, user code
# converted, and no optional features (None normalizes to an empty set).
# ConversionOptions.to_ast encodes options equal to these as `ag__.STD`.
STANDARD_OPTIONS = ConversionOptions(
    optional_features=None,
    internal_convert_user_code=True,
    user_requested=False,
    recursive=True)

229 

230 

class ProgramContext(object):
  """Holds state shared while converting a hierarchy of related functions.

  Attributes:
    options: ConversionOptions
    autograph_module: Deprecated. Do not use.
  """

  def __init__(self, options, autograph_module=None):
    # Stored verbatim; no validation is performed here.
    self.options, self.autograph_module = options, autograph_module

242 

243 

class Base(transformer.Base):
  """Base class that every AutoGraph converter should inherit from.

  Converter instances are single-use: `visit` raises ValueError when a new
  top-level traversal is started on an instance that was already used.

  Attributes:
    ctx: EntityContext
  """

  def __init__(self, ctx):
    super(Base, self).__init__(ctx)

    # Becomes True when the first top-level `visit` begins.
    self._used = False
    # Number of `visit` calls currently on the stack; 0 means idle.
    self._ast_depth = 0

  def get_definition_directive(self, node, directive, arg, default):
    """Returns the unique directive argument for a symbol.

    See lang/directives.py for details on directives.

    Example:
      # Given a directive in the code:
      ag.foo_directive(bar, baz=1)

      # One can write for an AST node Name(id='bar'):
      get_definition_directive(node, ag.foo_directive, 'baz', default=None)

    Args:
      node: ast.AST, the node representing the symbol for which the directive
        argument is needed.
      directive: Callable[..., Any], the directive to search.
      arg: str, the directive argument to return.
      default: Any, returned when no reaching definition carries the argument.

    Returns:
      The directive argument found on the reaching definitions, or `default`.

    Raises:
      ValueError: if conflicting annotations have been found
    """
    definitions = anno.getanno(node, anno.Static.ORIG_DEFINITIONS, ())
    if not definitions:
      return default

    candidates = [
        d.directives[directive][arg]
        for d in definitions
        if directive in d.directives and arg in d.directives[directive]
    ]
    if not candidates:
      return default

    # Every annotation reaching the symbol must agree; any one may then be
    # returned. A single candidate trivially agrees with itself.
    reference = candidates[0]
    for candidate in candidates[1:]:
      if ast_util.matches(reference, candidate):
        continue
      qn = anno.getanno(node, anno.Basic.QN)
      raise ValueError(
          '%s has ambiguous annotations for %s(%s): %s, %s' %
          (qn, directive.__name__, arg, parser.unparse(candidate).strip(),
           parser.unparse(reference).strip()))
    return reference

  def visit(self, node):
    """Visits `node`, refusing to start a second top-level traversal."""
    starting_traversal = not self._ast_depth
    if starting_traversal:
      if self._used:
        raise ValueError('converter objects cannot be reused')
      self._used = True

    self._ast_depth += 1
    try:
      return super(Base, self).visit(node)
    finally:
      # Restore depth even if the subtree visit raises.
      self._ast_depth -= 1