Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/eager/polymorphic_function/tracing_compiler.py: 35%

177 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2022 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Tracing Compiler implementation.""" 

16 

17import collections 

18import contextlib 

19import threading 

20import types as types_lib 

21from typing import List 

22import weakref 

23 

24from tensorflow.core.function import trace_type 

25from tensorflow.core.function.capture import capture_container 

26from tensorflow.core.function.polymorphism import function_cache 

27from tensorflow.core.function.polymorphism import function_type as function_type_lib 

28from tensorflow.python.eager import monitoring 

29from tensorflow.python.eager.polymorphic_function import attributes as attributes_lib 

30from tensorflow.python.eager.polymorphic_function import function_context 

31from tensorflow.python.eager.polymorphic_function import function_spec 

32from tensorflow.python.eager.polymorphic_function import monomorphic_function 

33from tensorflow.python.eager.polymorphic_function import tf_method_target 

34from tensorflow.python.framework import func_graph as func_graph_module 

35from tensorflow.python.framework import ops 

36from tensorflow.python.platform import tf_logging as logging 

37from tensorflow.python.profiler import trace 

38from tensorflow.python.util import compat 

39from tensorflow.python.util import lazy_loader 

40from tensorflow.python.util import tf_decorator 

41 

# `ag_ctx` is loaded lazily to break a circular import (roughly
# tf.function -> autograph -> dataset -> tf.function).
# TODO(b/133251390): Use a regular import.
ag_ctx = lazy_loader.LazyLoader(
    "ag_ctx",
    globals(),
    "tensorflow.python.autograph.core.ag_ctx",
)

# Monitoring counter for graph-building time, in microseconds per its metric
# name and description. Only the top-level (non-nested) graph builds in
# `TracingCompiler._maybe_define_function` are timed into it.
_graph_building_time_counter = monitoring.Counter(
    "/tensorflow/core/tf_function/graph_building_time_usecs",
    "Time for tf.function to build a graph (us).",
)

52 

53 

54# TODO(fmuham): Revamp the API of this class to be 100% compiler-focused. 

class TracingCompiler:
  """Generates, caches and dispatches traced Monomorphic Concrete Functions.

  The tracing is done using the Python source function with respect to inputs
  and other options specified by constructor.

  See the documentation for `tf.function` for more information on the semantics
  of defined functions.

  `TracingCompiler` class is thread-compatible meaning that minimal usage of
  tf.function (defining and calling) is thread-safe, but if users call other
  methods or invoke the base `python_function` themselves, external
  synchronization is necessary.

  In addition, TracingCompiler is not reentrant, so recursive functions need
  to call the wrapped function, not the wrapper.
  """

  def __init__(self,
               python_function,
               name,
               input_signature=None,
               attributes=None,
               autograph=True,
               autograph_options=None,
               reduce_retracing=False,
               capture_by_value=None,
               jit_compile=None):
    """Initializes a `TracingCompiler`.

    Args:
      python_function: the function to be wrapped.
      name: the name given to it.
      input_signature: a possibly nested sequence of `TensorSpec` objects
        specifying the input signature of this function. If `None`, a separate
        function is instantiated for each inferred input signature.
      attributes: dict, extra keyword arguments that will be added as attribute
        of the function. Every key must be in
        `attributes_lib.TRACING_COMPILER_ALLOWLIST`.
      autograph: whether to use autograph to compile `python_function`. See
        https://www.tensorflow.org/guide/autograph for more information.
      autograph_options: Experimental knobs to control behavior when
        `autograph=True`. See https://www.tensorflow.org/guide/autograph for
        more information.
      reduce_retracing: When True, `tf.function` uses
        `tf.types.experimental.TraceType` to trace supertypes of arguments to
        reduce the number of traces.
      capture_by_value: Experimental. Whether to capture resource variables by
        value or reference. If None, will inherit from a parent context or
        default to False.
      jit_compile: Force-compile the function with XLA, cf. tf.function doc on
        jit_compile.

    Raises:
      ValueError: if `input_signature` is not None and the `python_function`'s
        argspec has keyword arguments.
      ValueError: if `attributes` contains a key that is not in
        `attributes_lib.TRACING_COMPILER_ALLOWLIST`.
    """
    self._python_function = python_function
    # Presence of the IMPLEMENTS attribute is forwarded to FunctionSpec as
    # `is_pure`.
    pure_function = attributes and attributes_lib.IMPLEMENTS in attributes
    self._function_spec = (
        function_spec.FunctionSpec.from_function_and_signature(
            python_function, input_signature, is_pure=pure_function
        )
    )
    self._name = name
    self._autograph = autograph
    self._autograph_options = autograph_options
    self._reduce_retracing = reduce_retracing
    self._function_cache = function_cache.FunctionCache()

    self._function_attributes = attributes or {}
    for attribute in self._function_attributes:
      if attribute not in attributes_lib.TRACING_COMPILER_ALLOWLIST:
        raise ValueError(
            f"TracingCompiler does not support `{attribute}` as an attribute."
        )

    self._capture_by_value = capture_by_value
    self.tracing_count = 0
    # Maintain a dict of all captures: identifier -> lambda function. It's used
    # to get runtime values for all captures during ConcreteFunction dispatch.
    self._func_captures = capture_container.FunctionCaptures()
    self._lock = threading.RLock()
    # _descriptor_cache maps an instance of a class to an instance-specific
    # `TracingCompiler`, used to make sure tf.function-decorated methods
    # create different functions for each instance. Keys are held weakly.
    self._descriptor_cache = weakref.WeakKeyDictionary()
    self._jit_compile = jit_compile

  def __call__(self, *args, **kwargs):
    """Calls a graph function specialized to the inputs."""
    # Only lookup/tracing is done under the lock; the call itself is not.
    with self._lock:
      (concrete_function,
       filtered_flat_args) = self._maybe_define_function(args, kwargs)
    return concrete_function._call_flat(  # pylint: disable=protected-access
        filtered_flat_args, captured_inputs=concrete_function.captured_inputs)

  @property
  def python_function(self):
    """Returns the wrapped Python function."""
    return self._python_function

  @property
  def function_spec(self):
    """Returns the `FunctionSpec` built from the wrapped function."""
    return self._function_spec

  @property
  def input_signature(self):
    """Returns the input signature."""
    return self._function_spec.input_signature

  def _maybe_define_concrete_function(self, args, kwargs):
    """Like `_maybe_define_function`, but defaults empty inputs.

    When an `input_signature` is set and no arguments are given, the
    signature itself is used as the positional arguments. Caller must hold
    self._lock.
    """
    if self.input_signature and not args and not kwargs:
      # TODO(b/215596825): Throw error here if multiple entries are defined.
      args = self.input_signature
      kwargs = {}

    return self._maybe_define_function(args, kwargs)

  def _get_concrete_function_internal_garbage_collected(self, *args, **kwargs):
    """Returns a concrete function which cleans up its graph function."""
    with self._lock:
      concrete_function, _ = self._maybe_define_concrete_function(args, kwargs)
    return concrete_function

  def _get_concrete_function_internal(self, *args, **kwargs):
    """Bypasses error checking when getting a graph function."""
    concrete_function = self._get_concrete_function_internal_garbage_collected(
        *args, **kwargs)
    # We're returning this concrete function to someone, and they may keep a
    # reference to the FuncGraph without keeping a reference to the
    # ConcreteFunction object. So we won't clean up the reference cycles
    # manually and instead will leave them to Python's garbage collector.
    concrete_function._garbage_collector.release()  # pylint: disable=protected-access
    return concrete_function

  def _get_concrete_function_garbage_collected(self, *args, **kwargs):
    """Returns a `ConcreteFunction` specialized to inputs and execution context.

    Unlike `get_concrete_function(...)`, the graph will be deleted when the
    returned function is deleted. It's useful to avoid creating a reference
    cycle when you know for sure that the graph will be no longer used without
    the returned function.

    Args:
      *args: inputs to specialize on.
      **kwargs: inputs to specialize on.

    Returns:
      The (possibly cached) `ConcreteFunction`, with `_arg_keywords` and
      `_num_positional_args` filled in from the traced graph's inputs.
    """
    if self.input_signature and (args or kwargs):
      # Check to see if a valid type can be generated from the args, kwargs
      self._function_spec.make_canonicalized_monomorphic_type(args, kwargs)

    with self._lock:
      concrete_function, _ = self._maybe_define_concrete_function(args, kwargs)
      seen_names = set()
      concrete_function._arg_keywords = []  # pylint: disable=protected-access
      prefix_counts = {}
      graph = concrete_function.graph
      # Graph inputs are the user arguments followed by the captures.
      num_captures = len(
          graph.internal_captures + graph.deferred_internal_captures)
      num_positional = len(graph.inputs) - num_captures
      for arg in graph.inputs[:num_positional]:
        user_arg_name = compat.as_str(arg.op.get_attr("_user_specified_name"))
        # Deduplicate repeated names by appending _1, _2, ... suffixes.
        proposal = user_arg_name
        while proposal in seen_names:
          index = prefix_counts.get(user_arg_name, 1)
          proposal = "{}_{}".format(user_arg_name, index)
          prefix_counts[user_arg_name] = index + 1
        seen_names.add(proposal)
        concrete_function._arg_keywords.append(proposal)  # pylint: disable=protected-access
      # Anything can be a positional argument, in the same order as .inputs
      concrete_function._num_positional_args = num_positional  # pylint: disable=protected-access
      return concrete_function

  def get_concrete_function(self, *args, **kwargs):
    """Returns a `ConcreteFunction` specialized to inputs and execution context.

    Args:
      *args: inputs to specialize on. Can be concrete values (e.g. 1) or
        `tf.Tensor` or `tf.TensorSpec`.
      **kwargs: keyword inputs to specialize on. Concrete values (e.g. 1) or
        `tf.Tensor` or `tf.TensorSpec`.
    """
    concrete_function = self._get_concrete_function_garbage_collected(
        *args, **kwargs)
    concrete_function._garbage_collector.release()  # pylint: disable=protected-access
    return concrete_function

  def _list_all_concrete_functions(
      self) -> List[monomorphic_function.ConcreteFunction]:
    """Returns every `ConcreteFunction` currently held in the function cache."""
    return self._function_cache.values()

  def __get__(self, instance, owner):
    """Makes it possible to decorate instance methods."""
    del owner
    # `instance` here is the instance that this `TracingCompiler` was
    # accessed through e.g., for
    #
    #   class Foo:
    #
    #     @tf.function
    #     def bar(self):
    #       ...
    #
    #   foo = Foo()
    #   foo.bar()  # `foo.bar` is a `tf.function` instance
    #
    # then `instance` will be `foo` (and `owner` will be `Foo`). We create a
    # new instance of `TracingCompiler` here to allow different instances
    # to create variables once, thereby allowing methods to be decorated with
    # tf.function. Keeps a cache to avoid retracing the function every time the
    # descriptor is accessed.
    if instance is None:
      # Accessed through the class itself: there is nothing to bind.
      return self

    if instance not in self._descriptor_cache:
      # If there is no instance-specific `TracingCompiler` in the cache, we
      # construct an instance-specific `TracingCompiler` that uses a weak
      # reference to the instance (so that the instance will be correctly
      # gc'd), and add it to the descriptor cache.
      self._descriptor_cache[instance] = class_method_to_instance_method(
          self, instance)

    # Return the cached `TracingCompiler` for the instance
    return self._descriptor_cache[instance]

  def _create_concrete_function(self, args, kwargs, func_graph):
    """Create a `ConcreteFunction` from `args`, `kwargs`, and `func_graph`."""
    self.tracing_count += 1

    arglen = len(args)
    base_arg_names = self._function_spec.arg_names[:arglen]
    num_missing_args = arglen - len(self._function_spec.arg_names)
    if num_missing_args > 0:
      # Must have variable positional args if there are missing args.
      var_arg_name = next(
          p.name
          for p in self._function_spec.function_type.parameters.values()
          if p.kind is function_type_lib.Parameter.VAR_POSITIONAL
      )
      # Name the extra positional args "<vararg>_0", "<vararg>_1", ...
      missing_arg_names = [
          f"{var_arg_name}_{i}" for i in range(num_missing_args)
      ]
      arg_names = base_arg_names + missing_arg_names
    else:
      arg_names = base_arg_names

    concrete_function = monomorphic_function.ConcreteFunction(
        func_graph_module.func_graph_from_py_func(
            self._name,
            self._python_function,
            args,
            kwargs,
            None,
            func_graph=func_graph,
            arg_names=arg_names,
            capture_by_value=self._capture_by_value,
            create_placeholders=False),
        self._function_attributes,
        spec=self.function_spec,
        # Tell the ConcreteFunction to clean up its graph once it goes out of
        # scope. This is not the default behavior since it gets used in some
        # places (like Keras) where the FuncGraph lives longer than the
        # ConcreteFunction.
        shared_func_graph=False)
    return concrete_function

  def _maybe_define_function(self, args, kwargs):
    """Gets a function for these inputs, defining it if necessary.

    Caller must hold self._lock.

    Args:
      args: The varargs for the Python function.
      kwargs: The keyword args for the Python function.

    Returns:
      A `ConcreteFunction` corresponding to the input signature implied by args
      and kwargs, as well as filtered flattened inputs (only Tensors and
      Variables) that the object should be called with.

    Raises:
      ValueError: If inputs are incompatible with the input signature.
      TypeError: If the function inputs include non-hashable objects
      RuntimeError: If there's an internal bug (inconsistency) in handling
        shape relaxation retracing.
    """
    args, kwargs, filtered_flat_args = (
        self._function_spec.canonicalize_function_inputs(args, kwargs))

    if self.input_signature is not None:
      args = (*self.input_signature, *args[len(self.input_signature):])

    # Get runtime values of captures
    captures = self._func_captures.get_by_ref_snapshot()

    current_func_context = function_context.make_function_context()

    # The lookup type is built from every capture seen so far. A separate type
    # is built after tracing (see `_insert_capture_type` below) from only the
    # captures the new graph actually uses, and that one is the cache key.
    lookup_func_type, lookup_func_context = (
        self._function_spec.make_canonicalized_monomorphic_type(
            args, kwargs, captures))
    concrete_function = self._function_cache.lookup(current_func_context,
                                                    lookup_func_type)
    if concrete_function is not None:
      return concrete_function, filtered_flat_args

    # Use a timer for graph building only if not already inside a function. This
    # avoids double counting graph building time for nested functions.
    with monitoring.MonitoredTimer(
        _graph_building_time_counter.get_cell()
    ) if not ops.inside_function() else contextlib.nullcontext():
      with trace.Trace("tf.function-graph_building"):
        logging.vlog(
            1, "Creating new FuncGraph for Python function %r (key: %r, %r)",
            self._python_function, current_func_context, lookup_func_type)
        logging.vlog(2, "Python function signature [args: %s] [kwargs: %s]",
                     args, kwargs)
        ag_status = (
            ag_ctx.Status.ENABLED
            if self._autograph else ag_ctx.Status.DISABLED)
        with ag_ctx.ControlStatusCtx(
            status=ag_status, options=self._autograph_options):
          func_graph = func_graph_module.FuncGraph(
              self._name, capture_by_value=self._capture_by_value)
          if self.input_signature is None and self._reduce_retracing:
            target_func_type = self._function_cache.generalize(
                current_func_context, lookup_func_type)
          else:
            target_func_type = lookup_func_type
          placeholder_mapping = lookup_func_context.get_placeholder_mapping()
          placeholder_context = trace_type.InternalPlaceholderContext(
              func_graph, placeholder_mapping)
          with func_graph.as_default():
            placeholder_bound_args = target_func_type.placeholder_arguments(
                placeholder_context)
          args = placeholder_bound_args.args
          kwargs = placeholder_bound_args.kwargs

          concrete_function = self._create_concrete_function(
              args, kwargs, func_graph)

          # TODO(b/263520817): Remove access to private attribute.
          graph_capture_container = concrete_function.graph.function_captures
          # Maintain the list of all captures
          self._func_captures.merge_by_ref_with(graph_capture_container)
          # Get current active captures snapshot
          captures = graph_capture_container.get_by_ref_snapshot()

          # Create a cache_key with args and captures
          traced_func_type = _insert_capture_type(
              target_func_type, captures, lookup_func_context)

          self._function_cache.add(current_func_context, traced_func_type,
                                   concrete_function)

          return concrete_function, filtered_flat_args

416 

417 

def class_method_to_instance_method(original_function, instance):
  """Constructs a new `TracingCompiler` with `self` bound.

  A fresh object of `type(original_function)` is built around a wrapper that
  reaches `instance` only through a `weakref.ref`, and the result is decorated
  with `bound_method` so that inspection sees a bound-method signature.

  Args:
    original_function: a `TracingCompiler` or `def_function.Function` being
      accessed as an attribute of `instance`.
    instance: the object the method is being accessed through.

  Returns:
    The decorated instance-specific function object.
  """
  instance_ref = weakref.ref(instance)

  # Note: while we could bind to a weakref proxy instead, that causes the
  # bound method to be unhashable.
  bound_method = types_lib.MethodType(
      original_function.python_function,
      tf_method_target.TfMethodTarget(instance_ref,
                                      original_function.python_function))

  # original_function is expected to be either `TracingCompiler` or
  # def_function.Function; both expose the attributes read below.
  for required_attr in ("_name", "_autograph", "_function_spec",
                        "python_function"):
    assert hasattr(original_function, required_attr)

  wrapper_ref = None

  def bound_method_wrapper(*args, **kwargs):
    """Wraps either a dummy MethodType or a converted AutoGraph function."""
    # __wrapped__ allows AutoGraph to swap in a converted function.
    wrapper = wrapper_ref()
    target = wrapper.__wrapped__

    if target is not wrapper.__original_wrapped__:
      # If __wrapped__ was replaced, then it is always an unbound function.
      # However, the replacer is still responsible for attaching self properly.
      # TODO(mdan): Is it possible to do it here instead?
      return target(*args, **kwargs)

    # If __wrapped__ was not replaced, then call original_function.
    # TODO(mdan): For better consistency, use the wrapper's call().
    return original_function.python_function(instance_ref(), *args, **kwargs)

  # The wrapper refers to itself only weakly to avoid a reference cycle.
  wrapper_ref = weakref.ref(bound_method_wrapper)

  # pylint: disable=protected-access
  # A dummy MethodType supplies the bound-method signature; the actual call
  # goes through `bound_method_wrapper`, which holds `instance` weakly.
  instance_func = type(original_function)(
      tf_decorator.make_decorator(bound_method, bound_method_wrapper),
      name=original_function._name,
      autograph=original_function._autograph,
      input_signature=original_function.input_signature,
      reduce_retracing=original_function._reduce_retracing,
      jit_compile=original_function._jit_compile)
  # pylint: enable=protected-access

  # Wrap the bound method with tf_decorator so inspection works correctly.
  return tf_decorator.make_decorator(bound_method, instance_func)

474 

475 

def _insert_capture_type(original_func_type, captures, type_context):
  """Builds a `FunctionType` from `original_func_type` plus capture types.

  Args:
    original_func_type: the `FunctionType` whose parameters are reused as-is.
    captures: mapping from capture key to captured value (presumably the
      by-ref snapshot taken from a FunctionCaptures container).
    type_context: context forwarded to `trace_type.from_value`.

  Returns:
    A new `function_type_lib.FunctionType` with the same parameters and an
    `OrderedDict` of capture trace types, in `captures` iteration order.
  """
  capture_types = collections.OrderedDict(
      (capture_key, trace_type.from_value(capture_value, type_context))
      for capture_key, capture_value in captures.items())
  return function_type_lib.FunctionType(
      original_func_type.parameters.values(), capture_types)