Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/util/dispatch.py: 71%

397 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Type-based dispatch for TensorFlow's Python APIs. 

16 

17"Python APIs" refers to Python functions that have been exported with 

18`tf_export`, such as `tf.add` and `tf.linalg.matmul`; they are sometimes also 

19referred to as "ops". 

20 

21There are currently two dispatch systems for TensorFlow: 

22 

23 * The "fallback dispatch" system calls an API's standard implementation first, 

24 and only tries to perform dispatch if that standard implementation raises a 

25 TypeError (or ValueError) exception. 

26 

27 * The "type-based dispatch" system checks the types of the parameters passed 

28 to an API, and performs dispatch if those types match any signatures that 

29 have been registered for dispatch. 

30 

31The fallback dispatch system was the original dispatch system, but it was 

32somewhat brittle and had limitations, such as an inability to support dispatch 

33for some operations (like convert_to_tensor). We plan to remove the fallback 

34dispatch system in favor of the type-based dispatch system, once all users have 

35been switched over to use it. 

36 

37### Fallback Dispatch 

38 

39The fallback dispatch system is based on "operation dispatchers", which can be 

40used to override the behavior for TensorFlow ops when they are called with 

41otherwise unsupported argument types. In particular, when an operation is 

42called with arguments that would cause it to raise a TypeError, it falls back on 

43its registered operation dispatchers. If any registered dispatchers can handle 

44the arguments, then its result is returned. Otherwise, the original TypeError is 

45raised. 

46 

47### Type-based Dispatch 

48 

49The main interface for the type-based dispatch system is the `dispatch_for_api` 

50decorator, which overrides the default implementation for a TensorFlow API. 

51The decorated function (known as the "dispatch target") will override the 

52default implementation for the API when the API is called with parameters that 

53match a specified type signature. 

54 

55### Dispatch Support 

56 

57By default, dispatch support is added to the generated op wrappers for any 

58visible ops. APIs/ops that are implemented in Python can opt in to

59dispatch support using the `add_dispatch_support` decorator. 

60""" 

61 

62import collections 

63import itertools 

64import typing # pylint: disable=unused-import (used in doctests) 

65 

66from tensorflow.python.framework import _pywrap_python_api_dispatcher as _api_dispatcher 

67from tensorflow.python.framework import ops 

68from tensorflow.python.util import tf_decorator 

69from tensorflow.python.util import tf_export as tf_export_lib 

70from tensorflow.python.util import tf_inspect 

71from tensorflow.python.util import traceback_utils 

72from tensorflow.python.util import type_annotations 

73from tensorflow.python.util.tf_export import tf_export 

74 

75 

# Private function attributes used to store dispatchers on TensorFlow APIs.
# `FALLBACK_DISPATCH_ATTR` names the list of `OpDispatcher`s that
# `OpDispatcher.register()` appends to and `dispatch()` iterates over.
# `TYPE_BASED_DISPATCH_ATTR` names the `PythonAPIDispatcher` that
# `add_type_based_api_dispatcher()` attaches to an API.
FALLBACK_DISPATCH_ATTR = "_tf_fallback_dispatchers"
TYPE_BASED_DISPATCH_ATTR = "_tf_type_based_dispatcher"

# OpDispatchers which should be used for all operations.
# Appended to by `GlobalOpDispatcher.register()`; `dispatch()` consults these
# only after the op-specific fallback dispatchers.
_GLOBAL_DISPATCHERS = []

82 

83 

84################################################################################ 

85# Fallback Dispatch 

86################################################################################ 

87 

88 

@tf_export("__internal__.dispatch.OpDispatcher", v1=[])
class OpDispatcher(object):
  """Abstract base class for TensorFlow operator dispatchers.

  An operation dispatcher overrides the behavior of a single TensorFlow
  operation. Its result is used whenever its `handle` method signals that it
  supports the operation's arguments, i.e. whenever `handle` returns anything
  other than `OpDispatcher.NOT_SUPPORTED`.
  """

  # Sentinel that `handle` returns to decline a particular set of arguments.
  NOT_SUPPORTED = object()

  def handle(self, args, kwargs):  # pylint: disable=unused-argument
    """Handles this dispatcher's operation for the given arguments.

    Subclasses override this method; when they can handle the arguments they
    return an appropriate value (or raise an appropriate exception). This
    default implementation declines every set of arguments.

    Args:
      args: The positional arguments to the operation.
      kwargs: The keyword arguments to the operation.

    Returns:
      The result of the operation, or `OpDispatcher.NOT_SUPPORTED` if this
      dispatcher cannot handle the given arguments.
    """
    return self.NOT_SUPPORTED

  def register(self, op):
    """Registers this dispatcher as a handler for `op`.

    Args:
      op: Python function: the TensorFlow operation that should be handled.
        It must already carry a fallback dispatch list (generated ops get one
        automatically; Python ops can get one via the `add_dispatch_support`
        decorator).

    Raises:
      AssertionError: If `op` has no fallback dispatch list.
    """
    try:
      op_dispatchers = getattr(op, FALLBACK_DISPATCH_ATTR)
    except AttributeError:
      raise AssertionError("Dispatching not enabled for %s" % op) from None
    op_dispatchers.append(self)

131 

132 

@tf_export("__internal__.dispatch.GlobalOpDispatcher", v1=[])
class GlobalOpDispatcher(object):
  """Abstract base class for TensorFlow global operator dispatchers.

  Registered global dispatchers are consulted by `dispatch()` for every op,
  after that op's own `OpDispatcher`s have declined the arguments.
  """

  NOT_SUPPORTED = OpDispatcher.NOT_SUPPORTED

  def handle(self, op, args, kwargs):  # pylint: disable=unused-argument
    """Handles the specified operation with the specified arguments.

    Subclasses override this method. The default implementation declines all
    arguments by returning `NOT_SUPPORTED`, mirroring `OpDispatcher.handle`.
    (Returning `None` here, as an empty body would, makes `dispatch()` treat
    the call as successfully handled and return `None` for every op.)

    Args:
      op: Python function: the operation being dispatched.
      args: The positional arguments to the operation.
      kwargs: The keyword arguments to the operation.

    Returns:
      The result of the operation, or `NOT_SUPPORTED` if this dispatcher
      cannot handle the given arguments.
    """
    return self.NOT_SUPPORTED

  def register(self):
    """Register this dispatcher as a handler for all ops."""
    _GLOBAL_DISPATCHERS.append(self)

145 

146 

def dispatch(op, args, kwargs):
  """Returns the result from the first successful dispatcher for a given op.

  The `OpDispatcher`s registered on `op` are tried first, in the order they
  were registered, followed by every registered `GlobalOpDispatcher`. The
  first result that is not `OpDispatcher.NOT_SUPPORTED` wins.

  Args:
    op: Python function: the operation to dispatch for. It must carry a
      fallback dispatch list.
    args: The positional arguments to the operation.
    kwargs: The keyword arguments to the operation.

  Returns:
    The result of the operation, or `NOT_SUPPORTED` if no registered
    dispatcher can handle the given arguments.
  """
  declined = OpDispatcher.NOT_SUPPORTED

  for op_dispatcher in getattr(op, FALLBACK_DISPATCH_ATTR):
    outcome = op_dispatcher.handle(args, kwargs)
    if outcome is not declined:
      return outcome

  # Global dispatchers also receive the op itself.
  for global_dispatcher in _GLOBAL_DISPATCHERS:
    outcome = global_dispatcher.handle(op, args, kwargs)
    if outcome is not declined:
      return outcome

  return declined

171 

172 

class _TypeBasedDispatcher(OpDispatcher):
  """OpDispatcher that fires when any argument has one of the given types.

  Every positional and keyword argument value is inspected, as is every
  element of a list or tuple argument (one level deep). If any of them is an
  instance of the configured type(s), the call is forwarded to the override
  function.
  """

  def __init__(self, override_func, types):
    self._types = types
    self._override_func = override_func

  def _handles(self, args, kwargs):
    """Returns True if any argument, or list/tuple element, matches."""
    wanted = self._types

    def matches(value):
      if isinstance(value, wanted):
        return True
      if isinstance(value, (list, tuple)):
        return any(isinstance(item, wanted) for item in value)
      return False

    return any(
        matches(value) for value in itertools.chain(args, kwargs.values()))

  def handle(self, args, kwargs):
    if not self._handles(args, kwargs):
      return self.NOT_SUPPORTED
    return self._override_func(*args, **kwargs)

197 return self.NOT_SUPPORTED 

198 

199 

# pylint: disable=g-doc-return-or-yield
def dispatch_for_types(op, *types):
  """Decorator declaring that a Python function overrides `op` for some types.

  The decorated function replaces `op` whenever any positional or keyword
  argument (or any element of a list/tuple argument) is an instance of one of
  `types`. The decorated function must have exactly the same argspec as `op`.

  Example:

  ```python
  @dispatch_for_types(math_ops.add, RaggedTensor, RaggedTensorValue)
  def ragged_add(x, y, name=None): ...
  ```

  Args:
    op: Python function: the operation that should be overridden.
    *types: The argument types for which this function should be used.
  """

  def decorator(func):
    same_signature = tf_inspect.getargspec(func) == tf_inspect.getargspec(op)
    if not same_signature:
      raise AssertionError("The decorated function's signature must exactly "
                           "match the signature of the overridden op.")
    override = _TypeBasedDispatcher(func, types)
    override.register(op)
    return func

  return decorator


# pylint: enable=g-doc-return-or-yield

231 

232 

def add_fallback_dispatch_list(target):
  """Decorator that attaches an empty fallback dispatch list to an op.

  Args:
    target: The op (Python function) to decorate.

  Returns:
    `target`, now carrying an empty fallback dispatch list.

  Raises:
    AssertionError: If `target` already has a dispatch list.
  """
  already_has_list = hasattr(target, FALLBACK_DISPATCH_ATTR)
  if already_has_list:
    raise AssertionError("%s already has a dispatch list" % target)
  empty_dispatch_list = []
  setattr(target, FALLBACK_DISPATCH_ATTR, empty_dispatch_list)
  return target


# Alias for backwards-compatibility.
add_dispatch_list = add_fallback_dispatch_list

243 

244 

245################################################################################ 

246# Type-based Dispatch 

247################################################################################ 

248 

249 

@tf_export("experimental.dispatch_for_api")
def dispatch_for_api(api, *signatures):
  """Decorator that overrides the default implementation for a TensorFlow API.

  The decorated function (known as the "dispatch target") will override the
  default implementation for the API when the API is called with parameters that
  match a specified type signature. Signatures are specified using dictionaries
  that map parameter names to type annotations. E.g., in the following example,
  `masked_add` will be called for `tf.add` if both `x` and `y` are
  `MaskedTensor`s:

  >>> class MaskedTensor(tf.experimental.ExtensionType):
  ...   values: tf.Tensor
  ...   mask: tf.Tensor

  >>> @dispatch_for_api(tf.math.add, {'x': MaskedTensor, 'y': MaskedTensor})
  ... def masked_add(x, y, name=None):
  ...   return MaskedTensor(x.values + y.values, x.mask & y.mask)

  >>> mt = tf.add(MaskedTensor([1, 2], [True, False]), MaskedTensor(10, True))
  >>> print(f"values={mt.values.numpy()}, mask={mt.mask.numpy()}")
  values=[11 12], mask=[ True False]

  If multiple type signatures are specified, then the dispatch target will be
  called if any of the signatures match. For example, the following code
  registers `masked_add` to be called if `x` is a `MaskedTensor` *or* `y` is
  a `MaskedTensor`.

  >>> @dispatch_for_api(tf.math.add, {'x': MaskedTensor}, {'y':MaskedTensor})
  ... def masked_add(x, y):
  ...   x_values = x.values if isinstance(x, MaskedTensor) else x
  ...   x_mask = x.mask if isinstance(x, MaskedTensor) else True
  ...   y_values = y.values if isinstance(y, MaskedTensor) else y
  ...   y_mask = y.mask if isinstance(y, MaskedTensor) else True
  ...   return MaskedTensor(x_values + y_values, x_mask & y_mask)

  The type annotations in type signatures may be type objects (e.g.,
  `MaskedTensor`), `typing.List` values, or `typing.Union` values. For
  example, the following will register `masked_concat` to be called if `values`
  is a list of `MaskedTensor` values:

  >>> @dispatch_for_api(tf.concat, {'values': typing.List[MaskedTensor]})
  ... def masked_concat(values, axis):
  ...   return MaskedTensor(tf.concat([v.values for v in values], axis),
  ...                       tf.concat([v.mask for v in values], axis))

  Each type signature must contain at least one subclass of `tf.CompositeTensor`
  (which includes subclasses of `tf.ExtensionType`), and dispatch will only be
  triggered if at least one type-annotated parameter contains a
  `CompositeTensor` value. This rule avoids invoking dispatch in degenerate
  cases, such as the following examples:

  * `@dispatch_for_api(tf.concat, {'values': List[MaskedTensor]})`: Will not
    dispatch to the decorated dispatch target when the user calls
    `tf.concat([])`.

  * `@dispatch_for_api(tf.add, {'x': Union[MaskedTensor, Tensor], 'y':
    Union[MaskedTensor, Tensor]})`: Will not dispatch to the decorated dispatch
    target when the user calls `tf.add(tf.constant(1), tf.constant(2))`.

  The dispatch target's signature must match the signature of the API that is
  being overridden. In particular, parameters must have the same names, and
  must occur in the same order. The dispatch target may optionally elide the
  "name" parameter, in which case it will be wrapped with a call to
  `tf.name_scope` when appropriate.

  Args:
    api: The TensorFlow API to override.
    *signatures: Dictionaries mapping parameter names or indices to type
      annotations, specifying when the dispatch target should be called. In
      particular, the dispatch target will be called if any signature matches;
      and a signature matches if all of the specified parameters have types that
      match with the indicated type annotations. If no signatures are
      specified, then a signature will be read from the dispatch target
      function's type annotations.

  Returns:
    A decorator that overrides the default implementation for `api`.

  #### Registered APIs

  The TensorFlow APIs that may be overridden by `@dispatch_for_api` are:

  <<API_LIST>>
  """
  dispatcher = getattr(api, TYPE_BASED_DISPATCH_ATTR, None)
  if dispatcher is None:
    raise ValueError(f"{api} does not support dispatch.")

  api_signature = tf_inspect.signature(api)
  # Checkers are built eagerly, so malformed signatures are reported when
  # `dispatch_for_api(...)` is called, before any target is decorated.
  signature_checkers = [
      _make_signature_checker(api_signature, signature)
      for signature in signatures
  ]

  def decorator(dispatch_target):
    """Decorator that registers the given dispatch target."""
    if not callable(dispatch_target):
      raise TypeError("Expected dispatch_target to be callable; "
                      f"got {dispatch_target!r}")
    # May replace `dispatch_target` with a wrapper that accepts "name".
    dispatch_target = _add_name_scope_wrapper(dispatch_target, api_signature)
    _check_signature(api_signature, dispatch_target)

    for signature_checker in signature_checkers:
      dispatcher.Register(signature_checker, dispatch_target)
    _TYPE_BASED_DISPATCH_SIGNATURES[api][dispatch_target].extend(signatures)

    if not signature_checkers:
      # No explicit signatures: derive one from the target's annotations.
      signature = _signature_from_annotations(dispatch_target)
      checker = _make_signature_checker(api_signature, signature)
      dispatcher.Register(checker, dispatch_target)
      _TYPE_BASED_DISPATCH_SIGNATURES[api][dispatch_target].append(signature)

    return dispatch_target

  return decorator

366 

367 

# Nested dict mapping `api_func` -> `dispatch_target` -> `List[signature]`,
# which can be used for documentation generation and for improved error messages
# when APIs are called with unsupported types.
# Outer keys are added by `add_type_based_api_dispatcher` (each with a
# `collections.defaultdict(list)` value); inner entries are filled in by
# `dispatch_for_api` and deleted by `unregister_dispatch_for`.
_TYPE_BASED_DISPATCH_SIGNATURES = {}

372 

373 

def apis_with_type_based_dispatch():
  """Returns a list of TensorFlow APIs that support type-based dispatch.

  The list is sorted by each API's qualified name (`module.name`).
  """

  def qualified_name(api):
    return f"{api.__module__}.{api.__name__}"

  return sorted(_TYPE_BASED_DISPATCH_SIGNATURES, key=qualified_name)

379 

380 

def type_based_dispatch_signatures_for(cls):
  """Returns dispatch signatures that have been registered for a given class.

  This function is intended for documentation-generation purposes.

  Args:
    cls: The class to search for. Signatures are searched recursively: for
      example, with `cls=RaggedTensor`, every dispatch target whose type
      annotations mention `RaggedTensor` anywhere (including inside a
      `typing.Union` or `typing.List`) is reported.

  Returns:
    A `dict` mapping `api` -> `signatures`, where `api` is a TensorFlow API
    function and `signatures` is the list of its dispatch signatures that
    mention `cls`. (Each signature is a dict mapping argument names to type
    annotations; see `dispatch_for_api` for more info.)
  """

  def contains_cls(x):
    """Returns true if `x` contains `cls`."""
    if isinstance(x, dict):
      return any(contains_cls(v) for v in x.values())
    if x is cls:
      return True
    if (type_annotations.is_generic_list(x) or
        type_annotations.is_generic_union(x)):
      return any(contains_cls(arg)
                 for arg in type_annotations.get_generic_type_args(x))
    return False

  result = {}
  for api, signatures_by_target in _TYPE_BASED_DISPATCH_SIGNATURES.items():
    for signatures in signatures_by_target.values():
      matching = [sig for sig in signatures if contains_cls(sig)]
      if matching:
        result.setdefault(api, []).extend(matching)
  return result

419 

420 

# TODO(edloper): Consider using a mechanism like this to automatically add
# the `name` argument to all TensorFlow APIs that are implemented in Python
# (so each Python function doesn't need to do it manually).
def _add_name_scope_wrapper(func, api_signature):
  """Wraps `func` to expect a "name" arg, and use it to call `ops.name_scope`.

  If `func` already expects a "name" arg (or accepts `**kwargs`), or if
  `api_signature` does not expect a "name" arg, then returns `func` as-is.

  Args:
    func: The function to wrap. Signature must match `api_signature` (except
      the "name" parameter may be missing).
    api_signature: The signature of the original API (used to find the index for
      the "name" parameter).

  Returns:
    The wrapped function (or the original function if no wrapping is needed).
    The wrapper's `__signature__` lists "name" at the same index it has in
    `api_signature`.
  """
  if "name" not in api_signature.parameters:
    return func  # No wrapping needed (API has no name parameter).

  func_signature = tf_inspect.signature(func)
  func_argspec = tf_inspect.getargspec(func)
  if "name" in func_signature.parameters or func_argspec.keywords is not None:
    return func  # No wrapping needed (already has name parameter).

  name_index = list(api_signature.parameters).index("name")

  def wrapped_func(*args, **kwargs):
    # A positionally-passed name sits at the index it has in the API.
    if name_index < len(args):
      name = args[name_index]
      args = args[:name_index] + args[name_index + 1:]
    else:
      name = kwargs.pop("name", None)
    if name is None:
      return func(*args, **kwargs)
    else:
      with ops.name_scope(name):
        return func(*args, **kwargs)

  wrapped_func = tf_decorator.make_decorator(func, wrapped_func)
  # Advertise "name" at the position the wrapper actually reads it from,
  # rather than always appending it. This keeps the reported signature in
  # agreement with the runtime behavior above, and with `_check_signature`'s
  # positional comparison, when "name" is not the API's last parameter.
  # (`list.insert` appends when `name_index` is past the end.)
  wrapped_params = list(func_signature.parameters.values())
  wrapped_params.insert(name_index, api_signature.parameters["name"])
  wrapped_func.__signature__ = func_signature.replace(parameters=wrapped_params)
  # Remove the `_tf_decorator` attribute that `make_decorator` attached.
  del wrapped_func._tf_decorator
  return wrapped_func

467 

468 

@tf_export("experimental.unregister_dispatch_for")
def unregister_dispatch_for(dispatch_target):
  """Unregisters a function that was registered with `@dispatch_for_*`.

  This is primarily intended for testing purposes.

  Example:

  >>> # Define a type and register a dispatcher to override `tf.abs`:
  >>> class MyTensor(tf.experimental.ExtensionType):
  ...   value: tf.Tensor
  >>> @tf.experimental.dispatch_for_api(tf.abs)
  ... def my_abs(x: MyTensor):
  ...   return MyTensor(tf.abs(x.value))
  >>> tf.abs(MyTensor(5))
  MyTensor(value=<tf.Tensor: shape=(), dtype=int32, numpy=5>)

  >>> # Unregister the dispatcher, so `tf.abs` no longer calls `my_abs`.
  >>> unregister_dispatch_for(my_abs)
  >>> tf.abs(MyTensor(5))
  Traceback (most recent call last):
  ...
  ValueError: Attempt to convert a value ... to a Tensor.

  Args:
    dispatch_target: The function to unregister.

  Raises:
    ValueError: If `dispatch_target` was not registered using
      `@dispatch_for_api`, `@dispatch_for_unary_elementwise_apis`, or
      `@dispatch_for_binary_elementwise_apis`.
  """
  found = False

  # Check if dispatch_target registered by `@dispatch_for_api`
  for api, signatures in _TYPE_BASED_DISPATCH_SIGNATURES.items():
    if dispatch_target in signatures:
      dispatcher = getattr(api, TYPE_BASED_DISPATCH_ATTR)
      dispatcher.Unregister(dispatch_target)
      del signatures[dispatch_target]
      found = True

  # Check if dispatch_target registered by `@dispatch_for_*_elementwise_apis`.
  # Keys are collected first because the dicts are mutated below. They are
  # unique already (they are dict keys), so no de-duplication is needed.
  elementwise_keys_to_delete = [
      key for (key, handler) in _ELEMENTWISE_API_HANDLERS.items()
      if handler is dispatch_target
  ]
  for key in elementwise_keys_to_delete:
    # Each per-API target was registered via `dispatch_for_api`; remove them.
    for _, target in _ELEMENTWISE_API_TARGETS[key]:
      unregister_dispatch_for(target)
    del _ELEMENTWISE_API_HANDLERS[key]
    del _ELEMENTWISE_API_TARGETS[key]
    found = True

  if not found:
    raise ValueError(f"Function {dispatch_target} was not registered using "
                     "a `@dispatch_for_*` decorator.")

526 

527 

def register_dispatchable_type(cls):
  """Class decorator marking `cls` as usable with type-based dispatch.

  Subclasses of `CompositeTensor` or `ExtensionType` are registered
  automatically and should *not* be passed to this decorator.

  Note: this function exists to support internal legacy use cases (such as
  RaggedTensorValue), and will probably not be exposed as a public API.

  Args:
    cls: The class to register.

  Returns:
    `cls`, unchanged.
  """
  register = _api_dispatcher.register_dispatchable_type
  register(cls)
  return cls

545 

546 

def add_type_based_api_dispatcher(target):
  """Adds a PythonAPIDispatcher to the given TensorFlow API function.

  If the unwrapped implementation of `target` takes `*args` or `**kwargs`,
  no dispatcher is attached and `target` is returned unchanged.

  Args:
    target: The TensorFlow API function.

  Returns:
    `target`.

  Raises:
    ValueError: If `target` already has a type-based API dispatcher.
  """
  if hasattr(target, TYPE_BASED_DISPATCH_ATTR):
    raise ValueError(f"{target} already has a type-based API dispatcher.")

  _, unwrapped = tf_decorator.unwrap(target)
  argspec = tf_inspect.getargspec(unwrapped)

  # @TODO(b/194903203) Add v2 dispatch support for APIs that take varargs
  # and keywords. Examples of APIs that take varargs and kwargs: meshgrid,
  # einsum, map_values, map_flat_values.
  takes_var_args = argspec.varargs or argspec.keywords
  if not takes_var_args:
    api_dispatcher = _api_dispatcher.PythonAPIDispatcher(
        unwrapped.__name__, argspec.args, argspec.defaults)
    setattr(target, TYPE_BASED_DISPATCH_ATTR, api_dispatcher)
    _TYPE_BASED_DISPATCH_SIGNATURES[target] = collections.defaultdict(list)
  return target

567 

568 

def _check_signature(api_signature, func):
  """Checks that a dispatch target's signature is compatible with an API.

  A target whose signature is exactly `(*args, **kwargs)` (no named positional
  parameters) is always accepted.

  Args:
    api_signature: The signature of the TensorFlow API.
    func: The dispatch target.

  Raises:
    ValueError: if the signatures are incompatible. Two signatures are
      compatible when they have the same number of parameters and every pair
      of corresponding parameters agrees on `name` and `kind`. Default values
      and annotations are not compared.
  """
  func_argspec = tf_inspect.getargspec(func)
  accepts_anything = (func_argspec.varargs is not None and
                      func_argspec.keywords is not None and
                      not func_argspec.args)
  if accepts_anything:
    return

  func_signature = tf_inspect.signature(func)
  api_params = list(api_signature.parameters.values())
  func_params = list(func_signature.parameters.values())
  compatible = len(api_params) == len(func_params) and all(
      api_param.name == func_param.name and api_param.kind == func_param.kind
      for api_param, func_param in zip(api_params, func_params))
  if not compatible:
    raise ValueError(f"Dispatch function's signature {func_signature} does "
                     f"not match API's signature {api_signature}.")

598 

599 

def _make_signature_checker(api_signature, signature):
  """Builds a PySignatureChecker for the given type signature.

  Args:
    api_signature: The `inspect.Signature` of the API whose signature is
      being checked.
    signature: Dictionary mapping parameter names (or positional indices into
      `api_signature`) to type annotations.

  Returns:
    A `PySignatureChecker`.

  Raises:
    TypeError: if `signature` is not a dict whose keys are `str` or `int`.
    ValueError: if `signature` names an unknown parameter or a parameter that
      is not positional, or uses an unsupported annotation.
  """
  well_formed = isinstance(signature, dict) and all(
      isinstance(key, (str, int)) for key in signature)
  if not well_formed:
    raise TypeError("signatures must be dictionaries mapping parameter names "
                    "to type annotations.")

  param_names = list(api_signature.parameters)
  positional_kinds = (tf_inspect.Parameter.POSITIONAL_ONLY,
                      tf_inspect.Parameter.POSITIONAL_OR_KEYWORD)
  index_checkers = []

  for param_name, param_type in signature.items():
    # In-range integer keys are positional indices; translate them to names.
    if isinstance(param_name, int) and param_name < len(param_names):
      param_name = param_names[param_name]

    # The parameter must exist, and must be positional.
    param = api_signature.parameters.get(param_name, None)
    if param is None:
      raise ValueError("signature includes annotation for unknown "
                       f"parameter {param_name!r}.")
    if param.kind not in positional_kinds:
      raise ValueError("Dispatch currently only supports type annotations "
                       "for positional parameters; can't handle annotation "
                       f"for {param.kind!r} parameter {param_name}.")

    type_checker = make_type_checker(param_type)
    index_checkers.append((param_names.index(param_name), type_checker))

  return _api_dispatcher.PySignatureChecker(index_checkers)

640 

641 

# Cache for InstanceTypeChecker objects (we only want to create one
# InstanceTypeChecker for each type, since each one uses an internal cache
# to avoid repeated calls back into Python's isinstance).
# Keys are either a single type object, or (for the simple members of a
# `Union`) a tuple of types sorted by `id`; see `make_type_checker`.
_is_instance_checker_cache = {}

646 

647 

def make_type_checker(annotation):
  """Builds a PyTypeChecker for the given type annotation.

  Supported annotations are: type objects, `None` (treated as `type(None)`),
  `typing.List[T]`, and `typing.Union[...]`, where `T` and the union members
  are themselves supported annotations.

  Args:
    annotation: The type annotation to build a checker for.

  Returns:
    A type checker object built by `_api_dispatcher`.

  Raises:
    AssertionError: if a `List[...]` annotation does not have exactly one type
      argument.
    ValueError: if `annotation` (or a nested annotation) is not supported.
  """
  if type_annotations.is_generic_union(annotation):
    type_args = type_annotations.get_generic_type_args(annotation)

    # If the union contains two or more simple types, then use a single
    # InstanceChecker to check them.
    simple_types = [t for t in type_args if isinstance(t, type)]
    # Sorting by `id` makes the cache key independent of member order.
    simple_types = tuple(sorted(simple_types, key=id))
    if len(simple_types) > 1:
      if simple_types not in _is_instance_checker_cache:
        checker = _api_dispatcher.MakeInstanceChecker(*simple_types)
        _is_instance_checker_cache[simple_types] = checker
      options = ([_is_instance_checker_cache[simple_types]] +
                 [make_type_checker(t) for t in type_args
                  if not isinstance(t, type)])
      return _api_dispatcher.MakeUnionChecker(options)

    options = [make_type_checker(t) for t in type_args]
    return _api_dispatcher.MakeUnionChecker(options)

  elif type_annotations.is_generic_list(annotation):
    type_args = type_annotations.get_generic_type_args(annotation)
    if len(type_args) != 1:
      raise AssertionError("Expected List[...] to have a single type parameter")
    elt_type = make_type_checker(type_args[0])
    return _api_dispatcher.MakeListChecker(elt_type)

  elif isinstance(annotation, type):
    if annotation not in _is_instance_checker_cache:
      checker = _api_dispatcher.MakeInstanceChecker(annotation)
      _is_instance_checker_cache[annotation] = checker
    return _is_instance_checker_cache[annotation]

  elif annotation is None:
    return make_type_checker(type(None))

  else:
    raise ValueError(f"Type annotation {annotation} is not currently supported"
                     " by dispatch. Supported annotations: type objects, "
                     "List[...], and Union[...]")

689 

690 

def _signature_from_annotations(func):
  """Builds a dict mapping from parameter names to type annotations.

  Only parameters that carry an annotation are included.

  Args:
    func: The function whose parameter annotations should be read.

  Returns:
    A `dict` mapping parameter names to their annotations.

  Raises:
    ValueError: if none of `func`'s parameters is annotated.
  """
  func_signature = tf_inspect.signature(func)

  # `Parameter.empty` is a sentinel, so compare by identity; `!=` would call
  # the (arbitrary) annotation object's own comparison method.
  signature = {
      name: param.annotation
      for (name, param) in func_signature.parameters.items()
      if param.annotation is not tf_inspect.Parameter.empty
  }
  if not signature:
    raise ValueError("The dispatch_for_api decorator must be called with at "
                     "least one signature, or applied to a function that "
                     "has type annotations on its parameters.")
  return signature

703 

704 

# Registries for elementwise APIs and API handlers.
#
# _*_ELEMENTWISE_APIS: A list of TensorFlow APIs that have been registered
# as elementwise operations using the `register_*_elementwise_api`
# decorators.
#
# _ELEMENTWISE_API_HANDLERS: Dict mapping from argument type(s) (a tuple, e.g.
# `(x_type,)` for unary handlers) to API handlers that have been registered
# with the `dispatch_for_*_elementwise_apis` decorators.
#
# _ELEMENTWISE_API_TARGETS: Dict mapping from argument type(s) to lists of
# `(api, dispatch_target)` pairs. Used to implement
# `unregister_dispatch_for`.
_UNARY_ELEMENTWISE_APIS = []
_BINARY_ELEMENTWISE_APIS = []
_BINARY_ELEMENTWISE_ASSERT_APIS = []
_ELEMENTWISE_API_HANDLERS = {}
_ELEMENTWISE_API_TARGETS = {}

# NOTE(review): its uses are outside this view; presumably a marker for the
# "assert" variants of binary elementwise APIs -- confirm.
_ASSERT_API_TAG = "ASSERT_API_TAG"

725 

726 

@tf_export("experimental.dispatch_for_unary_elementwise_apis")
def dispatch_for_unary_elementwise_apis(x_type):
  """Decorator to override default implementation for unary elementwise APIs.

  The decorated function (the "elementwise api handler") replaces the default
  implementation of every unary elementwise API whenever that API's first
  argument (typically named `x`) matches the type annotation `x_type`. The
  handler is invoked as:

  `elementwise_api_handler(api_func, x)`

  where `x` is the first argument passed to the elementwise api, and
  `api_func` is a function of a single parameter that performs the elementwise
  operation (e.g., `tf.abs`).

  The following example shows how this decorator can be used to update all
  unary elementwise operations to handle a `MaskedTensor` type:

  >>> class MaskedTensor(tf.experimental.ExtensionType):
  ...   values: tf.Tensor
  ...   mask: tf.Tensor
  >>> @dispatch_for_unary_elementwise_apis(MaskedTensor)
  ... def unary_elementwise_api_handler(api_func, x):
  ...   return MaskedTensor(api_func(x.values), x.mask)
  >>> mt = MaskedTensor([1, -2, -3], [True, False, True])
  >>> abs_mt = tf.abs(mt)
  >>> print(f"values={abs_mt.values.numpy()}, mask={abs_mt.mask.numpy()}")
  values=[1 2 3], mask=[ True False True]

  Some unary elementwise operations take extra arguments beyond `x`. Those
  arguments are *not* passed to the elementwise api handler; instead,
  `api_func` supplies them automatically when it is called. In the following
  example, `dtype` never reaches `unary_elementwise_api_handler`, but is
  applied by `api_func`.

  >>> ones_mt = tf.ones_like(mt, dtype=tf.float32)
  >>> print(f"values={ones_mt.values.numpy()}, mask={ones_mt.mask.numpy()}")
  values=[1.0 1.0 1.0], mask=[ True False True]

  Args:
    x_type: A type annotation indicating when the api handler should be called.
      See `dispatch_for_api` for a list of supported annotation types.

  Returns:
    A decorator.

  #### Registered APIs

  The unary elementwise APIs are:

  <<API_LIST>>
  """

  def decorator(handler):
    handler_key = (x_type,)
    if handler_key in _ELEMENTWISE_API_HANDLERS:
      raise ValueError("A unary elementwise dispatch handler "
                       f"({_ELEMENTWISE_API_HANDLERS[handler_key]}) "
                       f"has already been registered for {x_type}.")
    _ELEMENTWISE_API_HANDLERS[handler_key] = handler
    # Hook the handler into each API currently in `_UNARY_ELEMENTWISE_APIS`.
    for api in _UNARY_ELEMENTWISE_APIS:
      _add_dispatch_for_unary_elementwise_api(api, x_type, handler)
    return handler

  return decorator

792 

793 

@tf_export("experimental.dispatch_for_binary_elementwise_apis")
def dispatch_for_binary_elementwise_apis(x_type, y_type):
  """Decorator to override default implementation for binary elementwise APIs.

  The decorated function (known as the "elementwise api handler") overrides
  the default implementation for any binary elementwise API whenever the value
  for the first two arguments (typically named `x` and `y`) match the specified
  type annotations.  The elementwise api handler is called with three
  arguments:

    `elementwise_api_handler(api_func, x, y)`

  Where `x` and `y` are the first two arguments to the elementwise api, and
  `api_func` is a TensorFlow function that takes two parameters and performs the
  elementwise operation (e.g., `tf.add`).

  The following example shows how this decorator can be used to update all
  binary elementwise operations to handle a `MaskedTensor` type:

  >>> class MaskedTensor(tf.experimental.ExtensionType):
  ...   values: tf.Tensor
  ...   mask: tf.Tensor
  >>> @dispatch_for_binary_elementwise_apis(MaskedTensor, MaskedTensor)
  ... def binary_elementwise_api_handler(api_func, x, y):
  ...   return MaskedTensor(api_func(x.values, y.values), x.mask & y.mask)
  >>> a = MaskedTensor([1, 2, 3, 4, 5], [True, True, True, True, False])
  >>> b = MaskedTensor([2, 4, 6, 8, 0], [True, True, True, False, True])
  >>> c = tf.add(a, b)
  >>> print(f"values={c.values.numpy()}, mask={c.mask.numpy()}")
  values=[ 3  6  9 12  5], mask=[ True  True  True False False]

  Args:
    x_type: A type annotation for the first argument (`x`), indicating when
      the api handler should be called.
    y_type: A type annotation for the second argument (`y`), indicating when
      the api handler should be called.

  Returns:
    A decorator.  Applying it registers the handler for `(x_type, y_type)` and
    returns the handler unchanged; it raises `ValueError` if a binary handler
    is already registered for that pair of types.

  #### Registered APIs

  The binary elementwise APIs are:

  <<API_LIST>>
  """

  def decorator(handler):
    if (x_type, y_type) in _ELEMENTWISE_API_HANDLERS:
      raise ValueError("A binary elementwise dispatch handler "
                       f"({_ELEMENTWISE_API_HANDLERS[x_type, y_type]}) "
                       f"has already been registered for ({x_type}, {y_type}).")
    _ELEMENTWISE_API_HANDLERS[x_type, y_type] = handler
    # Attach the new handler to every binary API registered so far.
    for api in _BINARY_ELEMENTWISE_APIS:
      _add_dispatch_for_binary_elementwise_api(api, x_type, y_type, handler)

    return handler

  return decorator

850 

851 

@tf_export("experimental.dispatch_for_binary_elementwise_assert_apis")
def dispatch_for_binary_elementwise_assert_apis(x_type, y_type):
  """Decorator to override default implementation for binary elementwise assert APIs.

  The decorated function (known as the "elementwise assert handler")
  overrides the default implementation for any binary elementwise assert API
  whenever the value for the first two arguments (typically named `x` and `y`)
  match the specified type annotations.  The handler is called with three
  arguments:

    `elementwise_assert_handler(assert_func, x, y)`

  Where `x` and `y` are the first two arguments to the binary elementwise assert
  operation, and `assert_func` is a TensorFlow function that takes two
  parameters and performs the elementwise assert operation (e.g.,
  `tf.debugging.assert_equal`).

  The following example shows how this decorator can be used to update all
  binary elementwise assert operations to handle a `MaskedTensor` type:

  >>> class MaskedTensor(tf.experimental.ExtensionType):
  ...   values: tf.Tensor
  ...   mask: tf.Tensor
  >>> @dispatch_for_binary_elementwise_assert_apis(MaskedTensor, MaskedTensor)
  ... def binary_elementwise_assert_api_handler(assert_func, x, y):
  ...   merged_mask = tf.logical_and(x.mask, y.mask)
  ...   selected_x_values = tf.boolean_mask(x.values, merged_mask)
  ...   selected_y_values = tf.boolean_mask(y.values, merged_mask)
  ...   assert_func(selected_x_values, selected_y_values)
  >>> a = MaskedTensor([1, 1, 0, 1, 1], [False, False, True, True, True])
  >>> b = MaskedTensor([2, 2, 0, 2, 2], [True, True, True, False, False])
  >>> tf.debugging.assert_equal(a, b) # assert passed; no exception was thrown

  >>> a = MaskedTensor([1, 1, 1, 1, 1], [True, True, True, True, True])
  >>> b = MaskedTensor([0, 0, 0, 0, 2], [True, True, True, True, True])
  >>> tf.debugging.assert_greater(a, b)
  Traceback (most recent call last):
  ...
  InvalidArgumentError: Condition x > y did not hold.

  Args:
    x_type: A type annotation for the first argument (`x`), indicating when
      the api handler should be called.
    y_type: A type annotation for the second argument (`y`), indicating when
      the api handler should be called.

  Returns:
    A decorator.  Applying it registers the handler for `(x_type, y_type)` and
    returns the handler unchanged; it raises `ValueError` if an assert handler
    is already registered for that pair of types.

  #### Registered APIs

  The binary elementwise assert APIs are:

  <<API_LIST>>
  """

  def decorator(handler):
    # The tag keeps assert handlers separate from ordinary binary handlers
    # registered for the same (x_type, y_type) pair.
    api_handler_key = (x_type, y_type, _ASSERT_API_TAG)
    if api_handler_key in _ELEMENTWISE_API_HANDLERS:
      raise ValueError("A binary elementwise assert dispatch handler "
                       f"({_ELEMENTWISE_API_HANDLERS[api_handler_key]}) "
                       f"has already been registered for ({x_type}, {y_type}).")
    _ELEMENTWISE_API_HANDLERS[api_handler_key] = handler
    # NOTE(review): `_add_dispatch_for_binary_elementwise_api` records targets
    # under `(x_type, y_type)`, the same target key ordinary binary handlers
    # use -- confirm unregistration keeps the two kinds apart.
    for api in _BINARY_ELEMENTWISE_ASSERT_APIS:
      _add_dispatch_for_binary_elementwise_api(api, x_type, y_type, handler)

    return handler

  return decorator

919 

920 

def register_unary_elementwise_api(func):
  """Decorator that registers a TensorFlow op as a unary elementwise API.

  Every unary handler already registered (via
  `dispatch_for_unary_elementwise_apis`) is attached to `func` immediately.
  `func` itself is returned unchanged.
  """
  _UNARY_ELEMENTWISE_APIS.append(func)
  # Unary handlers are the registry entries keyed by a one-element tuple.
  unary_handlers = [(key[0], handler)
                    for key, handler in _ELEMENTWISE_API_HANDLERS.items()
                    if len(key) == 1]
  for x_type, handler in unary_handlers:
    _add_dispatch_for_unary_elementwise_api(func, x_type, handler)
  return func

928 

929 

def register_binary_elementwise_api(func):
  """Decorator that registers a TensorFlow op as a binary elementwise API.

  Every binary handler already registered (via
  `dispatch_for_binary_elementwise_apis`) is attached to `func` immediately.
  `func` itself is returned unchanged.
  """
  _BINARY_ELEMENTWISE_APIS.append(func)
  # Binary handlers are the registry entries keyed by `(x_type, y_type)`.
  binary_handlers = [(key, handler)
                     for key, handler in _ELEMENTWISE_API_HANDLERS.items()
                     if len(key) == 2]
  for (x_type, y_type), handler in binary_handlers:
    _add_dispatch_for_binary_elementwise_api(func, x_type, y_type, handler)
  return func

937 

938 

def register_binary_elementwise_assert_api(func):
  """Decorator that registers a TensorFlow op as a binary elementwise assert API.

  Unlike `register_binary_elementwise_api`, this decorator is used for assert
  apis, such as assert_equal, assert_none_equal, etc, which return None in
  eager mode and an op in graph mode.

  Any handlers already registered with
  `dispatch_for_binary_elementwise_assert_apis` are attached to `func`
  immediately.

  Args:
    func: The function that implements the binary elementwise assert API.

  Returns:
    `func`
  """
  _BINARY_ELEMENTWISE_ASSERT_APIS.append(func)
  for args, handler in _ELEMENTWISE_API_HANDLERS.items():
    # Assert handlers are keyed as `(x_type, y_type, _ASSERT_API_TAG)`.  Compare
    # the tag by value: `is` on strings relies on object identity, which is
    # not guaranteed for equal strings.
    if len(args) == 3 and args[2] == _ASSERT_API_TAG:
      _add_dispatch_for_binary_elementwise_api(func, args[0], args[1], handler)
  return func

957 

958 

def unary_elementwise_apis():
  """Returns a tuple of the APIs that have been registered as unary elementwise.

  The tuple is a snapshot: APIs registered afterwards are not reflected in it.
  """
  return tuple(_UNARY_ELEMENTWISE_APIS)

962 

963 

def binary_elementwise_apis():
  """Returns a tuple of the APIs that have been registered as binary elementwise.

  The tuple is a snapshot: APIs registered afterwards are not reflected in it.
  Binary elementwise *assert* APIs are tracked separately and are not included.
  """
  return tuple(_BINARY_ELEMENTWISE_APIS)

967 

968 

def _add_dispatch_for_unary_elementwise_api(api, x_type,
                                            elementwise_api_handler):
  """Registers a unary elementwise handler as a dispatcher for a given API.

  Builds a `dispatch_target` and registers it with `dispatch_for_api` for calls
  to `api` whose first parameter matches `x_type`.  The target separates `x`
  and `name` from the call's arguments, wraps any remaining arguments into a
  one-argument `tensor_api` callable, and returns
  `elementwise_api_handler(tensor_api, x)` -- inside
  `ops.name_scope(name, None, [x])` when a `name` value was supplied.

  Args:
    api: The unary elementwise API; its first parameter is treated as `x`.
    x_type: Type annotation that the first argument must match.
    elementwise_api_handler: Callable `(api_func, x)` that implements the
      operation for values matching `x_type`.
  """
  api_signature = tf_inspect.signature(api)
  x_name = list(api_signature.parameters)[0]
  name_index = _find_name_index(api_signature)

  # Unless `api`'s parameters are just `x` plus `name`, close over the
  # remaining call arguments so the handler can call `api_func(v)` with only
  # the elementwise value.
  need_to_bind_api_args = (
      len(api_signature.parameters) > 2 or
      "name" not in api_signature.parameters)

  # Generic (*args, **kwargs) signature: the target receives whatever
  # arguments the caller passed to `api`.
  @dispatch_for_api(api, {x_name: x_type})
  def dispatch_target(*args, **kwargs):
    args, kwargs, name = _extract_name_arg(args, kwargs, name_index)
    # `x` may be passed positionally or by keyword.
    if args:
      x, args = args[0], args[1:]
    else:
      x = kwargs.pop(x_name)

    if need_to_bind_api_args:
      tensor_api = lambda v: api(v, *args, **kwargs)
    else:
      tensor_api = api

    if name is None:
      return elementwise_api_handler(tensor_api, x)
    else:
      with ops.name_scope(name, None, [x]):
        return elementwise_api_handler(tensor_api, x)

  # Give the generated target a descriptive name derived from `api`.
  dispatch_target.__name__ = "elementwise_dispatch_target_for_" + api.__name__
  dispatch_target.__qualname__ = dispatch_target.__name__
  # Keep track of what targets we've registered (so we can unregister them).
  target_list = _ELEMENTWISE_API_TARGETS.setdefault((x_type,), [])
  target_list.append((api, dispatch_target))

1004 

1005 

def _add_dispatch_for_binary_elementwise_api(api, x_type, y_type,
                                             elementwise_api_handler):
  """Registers a binary elementwise handler as a dispatcher for a given API.

  Builds a `dispatch_target` and registers it with `dispatch_for_api` for calls
  to `api` whose first two parameters match `x_type` and `y_type`.  The target
  separates `x`, `y` and `name` from the call's arguments, wraps any remaining
  arguments into a two-argument `tensor_api` callable, and returns
  `elementwise_api_handler(tensor_api, x, y)` -- inside
  `ops.name_scope(name, None, [x, y])` when a `name` value was supplied.

  This function is used both for ordinary binary elementwise handlers and for
  binary elementwise *assert* handlers.

  Args:
    api: The binary elementwise API; its first two parameters are treated as
      `x` and `y`.
    x_type: Type annotation that the first argument must match.
    y_type: Type annotation that the second argument must match.
    elementwise_api_handler: Callable `(api_func, x, y)` that implements the
      operation for values matching these types.
  """
  api_signature = tf_inspect.signature(api)
  x_name, y_name = list(api_signature.parameters)[:2]
  name_index = _find_name_index(api_signature)

  # Unless `api`'s parameters are just `x`, `y` plus `name`, close over the
  # remaining call arguments so the handler can call `api_func(v1, v2)` with
  # only the two elementwise values.
  need_to_bind_api_args = (len(api_signature.parameters) > 3 or
                           "name" not in api_signature.parameters)

  # Generic (*args, **kwargs) signature: the target receives whatever
  # arguments the caller passed to `api`.
  @dispatch_for_api(api, {x_name: x_type, y_name: y_type})
  def dispatch_target(*args, **kwargs):
    args, kwargs, name = _extract_name_arg(args, kwargs, name_index)
    # `x` and `y` may each be passed positionally or by keyword; a value that
    # is missing from kwargs falls back to None.
    if len(args) > 1:
      x, y, args = args[0], args[1], args[2:]
    elif args:
      x, args = args[0], args[1:]
      y = kwargs.pop(y_name, None)
    else:
      x = kwargs.pop(x_name, None)
      y = kwargs.pop(y_name, None)

    if need_to_bind_api_args:
      tensor_api = lambda v1, v2: api(v1, v2, *args, **kwargs)
    else:
      tensor_api = api

    if name is None:
      return elementwise_api_handler(tensor_api, x, y)
    else:
      with ops.name_scope(name, None, [x, y]):
        return elementwise_api_handler(tensor_api, x, y)

  # Give the generated target a descriptive name derived from `api`.
  dispatch_target.__name__ = "elementwise_dispatch_target_for_" + api.__name__
  dispatch_target.__qualname__ = dispatch_target.__name__
  # Keep track of what targets we've registered (so we can unregister them).
  # NOTE(review): assert-handler targets are also recorded under
  # `(x_type, y_type)` here, sharing the key with ordinary binary handlers --
  # confirm unregistration distinguishes them.
  target_list = _ELEMENTWISE_API_TARGETS.setdefault((x_type, y_type), [])
  target_list.append((api, dispatch_target))

1044 

1045 

def _find_name_index(signature):
  """Returns the position of the `name` parameter in `signature`, or -1.

  Args:
    signature: An `inspect.Signature`-like object exposing `parameters`.

  Returns:
    The zero-based index of the parameter called `name`, or -1 if absent.
  """
  for position, param_name in enumerate(signature.parameters):
    if param_name == "name":
      return position
  return -1

1052 

1053 

def _extract_name_arg(args, kwargs, name_index):
  """Separates the `name` argument from a call's arguments.

  Args:
    args: Tuple of positional arguments.
    kwargs: Dict of keyword arguments.  When `name` was not passed
      positionally, the `"name"` entry (if any) is popped from this dict in
      place.
    name_index: Position of the `name` parameter, or a negative value if the
      API has no `name` parameter.

  Returns:
    A tuple `(args, kwargs, name_value)`.  `args` no longer contains a
    positional `name` value, and `name_value` is `None` if no name was given
    (or if `name_index` is negative).
  """
  if name_index >= 0:
    if name_index >= len(args):
      # `name` was not passed positionally; look for it among the keywords.
      return args, kwargs, kwargs.pop("name", None)
    name_value = args[name_index]
    return args[:name_index] + args[name_index + 1:], kwargs, name_value
  return args, kwargs, None

1064 

1065 

def update_docstrings_with_api_lists():
  """Updates the docstrings of dispatch decorators with API lists.

  Updates docstrings for `dispatch_for_api`,
  `dispatch_for_unary_elementwise_apis`,
  `dispatch_for_binary_elementwise_apis`, and
  `dispatch_for_binary_elementwise_assert_apis`, by replacing the string
  '<<API_LIST>>' with a list of APIs that have been registered for that
  decorator.
  """
  _update_docstring_with_api_list(dispatch_for_unary_elementwise_apis,
                                  _UNARY_ELEMENTWISE_APIS)
  _update_docstring_with_api_list(dispatch_for_binary_elementwise_apis,
                                  _BINARY_ELEMENTWISE_APIS)
  _update_docstring_with_api_list(dispatch_for_binary_elementwise_assert_apis,
                                  _BINARY_ELEMENTWISE_ASSERT_APIS)
  _update_docstring_with_api_list(dispatch_for_api,
                                  _TYPE_BASED_DISPATCH_SIGNATURES)

1082 

1083 

def _update_docstring_with_api_list(target, api_list):
  """Replaces `<<API_LIST>>` in target.__doc__ with the given list of APIs.

  Each API with a canonical exported name becomes one bullet showing its
  `tf.`-prefixed name and parameter names; APIs without a canonical name are
  skipped.  The bullets are sorted before being substituted.

  If `target` has no docstring (e.g. when Python runs with `-OO`, which strips
  docstrings), this is a no-op instead of raising `AttributeError`.

  Args:
    target: The object whose `__doc__` is rewritten in place.
    api_list: Iterable of API functions to list.
  """
  if target.__doc__ is None:
    # Docstrings were stripped; there is no placeholder to replace.
    return
  lines = []
  for func in api_list:
    name = tf_export_lib.get_canonical_name_for_symbol(
        func, add_prefix_to_v1_names=True)
    if name is not None:
      params = tf_inspect.signature(func).parameters.keys()
      lines.append(f" * `tf.{name}({', '.join(params)})`")
  lines.sort()
  target.__doc__ = target.__doc__.replace(" <<API_LIST>>", "\n".join(lines))

1095 

1096 

1097################################################################################ 

1098# Dispatch Support 

1099################################################################################ 

@tf_export("__internal__.dispatch.add_dispatch_support", v1=[])
def add_dispatch_support(target=None, iterable_parameters=None):
  """Decorator that adds a dispatch handling wrapper to a TensorFlow Python API.

  This wrapper adds the decorated function as an API that can be overridden
  using the `@dispatch_for_api` decorator.  In the following example, we first
  define a new API (`double`) that supports dispatch, then define a custom type
  (`MaskedTensor`) and finally use `dispatch_for_api` to override the default
  implementation of `double` when called with `MaskedTensor` values:

  >>> @add_dispatch_support
  ... def double(x):
  ...   return x * 2
  >>> class MaskedTensor(tf.experimental.ExtensionType):
  ...   values: tf.Tensor
  ...   mask: tf.Tensor
  >>> @dispatch_for_api(double, {'x': MaskedTensor})
  ... def masked_double(x):
  ...   return MaskedTensor(x.values * 2, x.mask)

  The optional `iterable_parameters` argument can be used to mark parameters
  that can take arbitrary iterable values (such as generator expressions).
  These need to be handled specially during dispatch, since just iterating
  over an iterable uses up its values.  In the following example, we define a
  new API whose second argument can be an iterable value; and then override
  the default implementation of that API when the iterable contains
  MaskedTensors:

  >>> @add_dispatch_support(iterable_parameters=['ys'])
  ... def add_tensor_to_list_of_tensors(x, ys):
  ...   return [x + y for y in ys]
  >>> @dispatch_for_api(add_tensor_to_list_of_tensors,
  ...               {'ys': typing.List[MaskedTensor]})
  ... def masked_add_tensor_to_list_of_tensors(x, ys):
  ...   return [MaskedTensor(x+y.values, y.mask) for y in ys]

  (Note: the only TensorFlow API that currently supports iterables is `add_n`.)

  Args:
    target: The TensorFlow API that should support dispatch.  If `None`, a
      decorator is returned, so `@add_dispatch_support(...)` can be used.
    iterable_parameters: Optional list of parameter names that may be called
      with iterables (such as the `inputs` parameter for `tf.add_n`).

  Returns:
    If `target` is `None`, a decorator; otherwise the dispatch-enabled
    wrapper around `target`.

  Raises:
    TypeError: If `iterable_parameters` is neither `None` nor a list or tuple
      of strings.
    ValueError: If a name in `iterable_parameters` is not a positional
      parameter of the decorated function.
  """

  if not (iterable_parameters is None or
          (isinstance(iterable_parameters, (list, tuple)) and
           all(isinstance(p, str) for p in iterable_parameters))):
    raise TypeError("iterable_parameters should be a list or tuple of string.")

  def decorator(dispatch_target):

    # Get the name & index for each iterable parameter.
    if iterable_parameters is None:
      iterable_params = None
    else:
      arg_names = tf_inspect.getargspec(dispatch_target).args
      unknown_params = [p for p in iterable_parameters if p not in arg_names]
      if unknown_params:
        target_name = getattr(dispatch_target, "__name__", dispatch_target)
        raise ValueError(
            f"iterable_parameters {unknown_params} are not positional "
            f"parameters of {target_name}; expected names from {arg_names}.")
      iterable_params = [
          (name, arg_names.index(name)) for name in iterable_parameters
      ]

    @traceback_utils.filter_traceback
    def op_dispatch_handler(*args, **kwargs):
      """Call `dispatch_target`, performing dispatch when appropriate."""

      # Type-based dispatch system (dispatch v2):
      # `api_dispatcher` is a closure variable assigned below, before the
      # wrapper is returned to the caller.
      if api_dispatcher is not None:
        if iterable_params is not None:
          # Materialize iterables so dispatch does not consume them.
          args, kwargs = replace_iterable_params(args, kwargs, iterable_params)
        result = api_dispatcher.Dispatch(args, kwargs)
        if result is not NotImplemented:
          return result

      # Fallback dispatch system (dispatch v1):
      try:
        return dispatch_target(*args, **kwargs)
      except (TypeError, ValueError):
        # Note: convert_to_eager_tensor currently raises a ValueError, not a
        # TypeError, when given unexpected types.  So we need to catch both.
        result = dispatch(op_dispatch_handler, args, kwargs)
        if result is not OpDispatcher.NOT_SUPPORTED:
          return result
        else:
          raise

    add_fallback_dispatch_list(op_dispatch_handler)
    op_dispatch_handler = tf_decorator.make_decorator(dispatch_target,
                                                      op_dispatch_handler)
    add_type_based_api_dispatcher(op_dispatch_handler)
    api_dispatcher = getattr(op_dispatch_handler, TYPE_BASED_DISPATCH_ATTR,
                             None)
    return op_dispatch_handler

  if target is None:
    return decorator
  else:
    return decorator(target)

1197 return decorator(target) 

1198 

1199 

def replace_iterable_params(args, kwargs, iterable_params):
  """Returns (args, kwargs) with any iterable parameters converted to lists.

  Iterable arguments such as generator expressions can only be consumed once,
  so they are materialized as lists before dispatch inspects them.  The
  caller's `kwargs` dict is left untouched; a modified copy is returned.

  Args:
    args: Positional arguments to a function.
    kwargs: Keyword arguments to a function.
    iterable_params: A list of (name, index) tuples for iterable parameters.

  Returns:
    A tuple (args, kwargs), where any positional or keyword parameters in
    `iterable_params` have their value converted to a `list`.
  """
  args = list(args)
  # Copy so the caller's dict is not mutated as a side effect.
  kwargs = dict(kwargs)
  for name, index in iterable_params:
    if index < len(args):
      args[index] = list(args[index])
    elif name in kwargs:
      kwargs[name] = list(kwargs[name])
  return tuple(args), kwargs