Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/framework/function.py: 16%

523 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================= 

15"""Python front-end supports for functions. 

16 

17NOTE: At this time, functions are experimental and subject to change. Proceed 

18with caution. 

19""" 

20 

21import collections 

22import hashlib 

23 

24from tensorflow.core.framework import attr_value_pb2 

25from tensorflow.core.framework import function_pb2 

26from tensorflow.python.client import pywrap_tf_session as c_api 

27from tensorflow.python.eager import context 

28from tensorflow.python.framework import c_api_util 

29from tensorflow.python.framework import dtypes 

30from tensorflow.python.framework import graph_to_function_def 

31from tensorflow.python.framework import ops 

32from tensorflow.python.ops import array_ops 

33from tensorflow.python.ops import resource_variable_ops 

34from tensorflow.python.ops import variable_scope as vs 

35from tensorflow.python.util import compat 

36from tensorflow.python.util import function_utils 

37from tensorflow.python.util import tf_contextlib 

38from tensorflow.python.util import tf_inspect 

39 

40 

41# TODO(b/136040013): Drop support for Defun. 

class Defun(object):
  """Obsolete. Slated for deletion. Please use tf.function instead.

  Known feature gaps while migrating to tf.function (could be outdated):
  - tf.function doesn't support Send/Recv capability since it doesn't share
    rendezvous with the main graph but always creates a new one.
  - tf.function doesn't support custom gradient function directly, instead you
    need to define the function inside a tf.custom_gradient wrapper together
    with the gradient function.
  - Unlike Defun, Keras layers used inside a tf.function need to be created only
    once to avoid variable recreation.
  - Defun respects the device assignments and applies them to the function body
    but tf.function needs it to be done manually.
  - Defun might prune out unused ops automatically but tf.function doesn't.

  Limitations of Defun:
  - Original source locations are not preserved so errors do not include
    full/valid stack traces.
  - Only supports linear sequence of arguments and return values, putting the
    burden on the caller to pack/unpack everything across a Defun boundary into
    tuples (as opposed to passing list and dict-like structures directly).
  - Does not support overloading or late-bound specializations.
  - Has its own way for defining gradient overrides which does not follow
    current conventions.
  - Cannot support imperative control flow or automatic control dependencies.
  - Does not reflect statefulness in the graph and has a calling convention that
    differs from how more modern tools interact.
  - Is only compatible with graph building mode.

  Decorator used to define TensorFlow functions.

  Use this decorator to make a Python function usable directly as a TensorFlow
  function.

  The decorated function must add ops to the default graph and return zero or
  more `Tensor` objects.  Call the decorator with named arguments, one for each
  argument of the function to decorate, with the expected type of the argument
  as value.

  For example if the function to decorate accepts two `tf.float32` arguments
  named `x` and `y`, call the decorator with:

      @Defun(tf.float32, tf.float32)
      def foo(x, y):
        ...

  When you call the decorated function, it adds the `call` ops to the
  default graph. In addition, it adds the definition of the function into the
  default graph. Because the addition of the function into the graph
  is deferred, the decorator can be used anywhere in the program.

  Any variables created inside of the function are hoisted into the outer graph.
  Note that the variables are created in the variable scope that was active
  during the first call to the function. Subsequent function calls will refer to
  the same set of variables.

  Definitions of functions in a graph are frozen as soon as the graph is used to
  create a session. However, new functions and new calls to existing functions
  may be added to the graph, with the new functions themselves becoming
  immediately frozen.

  Example, but also see the [How To on functions](link_needed).

  ```python
  # Defining the function.
  @tf.Defun(tf.float32, tf.float32)
  def MyFunc(x, y):
    return x + y, x - y

  # Building the graph.
  a = tf.constant([1.0])
  b = tf.constant([2.0])
  c, d = MyFunc(a, b, name='mycall')
  ```
  """

  def __init__(self, *input_types, **kwargs):
    """Create a `Defun` decorator.

    Args:
      *input_types: A list of `tf.DType`
      **kwargs: Optional keyword arguments, including
         func_name - (optional).  A python string, the name to use to
           declare this `Function` in the graph.

         grad_func - (optional).  A function implementing the gradient
           of the function-to-register.  This must be a
           `_DefinedFunction` object. The gradient
           function must satisfy the criterion defined in
           function.proto:GradientDef.

         python_grad_func - (optional).  A function implementing the
           gradient of the function python-side. This function must
           take the current op and the gradients w.r.t. its outputs,
           and return the gradients w.r.t. the inputs. That is, it must
           implement the interface expected by `tf.RegisterGradient`.
           This will be called by tf.gradients to add the gradient ops
           to the graph. At most one of grad_func and python_grad_func
           can be specified.

         out_names - (optional). A list of strings, one per output
           tensor.

         shape_func - (optional). A function taking the op and returning a list
           of static shapes to set for the function's outputs.

        Any keyword not popped here (including shape_func) is kept and
        forwarded unchanged to the function object built by `__call__`.
    """
    self._input_types = input_types
    self._func_name = kwargs.pop("func_name", None)
    self._grad_func = kwargs.pop("grad_func", None)
    self._python_grad_func = kwargs.pop("python_grad_func", None)
    self._out_names = kwargs.pop("out_names", None)
    # Whatever is left is forwarded verbatim to the function object.
    self._extra_kwargs = kwargs

  def __call__(self, func):
    """Wraps `func` into a TensorFlow function object.

    Args:
      func: The python callable to decorate. It must not declare argument
        defaults or `**kwargs`.

    Returns:
      A `_DefinedFunction` when the input types are known (either given to the
      decorator, or `func` takes no arguments); otherwise an
      `_OverloadedFunction` whose definition is deferred until it is called.

    Raises:
      ValueError: If `func` is not callable, uses defaults/keyword arguments,
        or the number of declared input types does not fit its signature.
    """
    # Various sanity checks on the callable func.
    if not callable(func):
      raise ValueError(f"Function {func} must be a callable.")

    # Func should not use kwargs and defaults.
    argspec = tf_inspect.getargspec(func)
    if argspec.keywords or argspec.defaults:
      raise ValueError(
          "Functions with argument defaults or keywords arguments are not "
          f"supported. {func} has defaults {argspec.defaults} and keywords "
          f"{argspec.keywords}.")

    # Computes how many arguments 'func' has.
    argnames = argspec.args
    num_positional = len(argnames)
    if tf_inspect.ismethod(func):
      # 1st argument is the "class" type; it is already bound, so it must not
      # count towards either the lower or the upper bound.
      num_positional -= 1
      argnames = argnames[1:]
    min_args = num_positional
    # Both bounds are derived *after* the bound-method adjustment. Previously
    # max_args was copied from the unadjusted count, so a bound method accepted
    # one input type too many and the failure surfaced only at trace time.
    max_args = 1000000 if argspec.varargs else num_positional

    if self._input_types:
      # If Defun is given a list of types for the inputs, the number
      # of input types should be compatible with 'func'.
      num = len(self._input_types)
      if num < min_args or num > max_args:
        raise ValueError(
            "The number of tf.function input types is not compatible with the "
            f"allowed arguments of {func}. The tf.function have {num} input "
            f"types, while the python function allows minimum {min_args} and "
            f"maximum {max_args} arguments.")
      return _DefinedFunction(
          func,
          argnames,
          self._input_types,
          self._func_name,
          self._grad_func,
          self._python_grad_func,
          out_names=self._out_names,
          **self._extra_kwargs)

    # 'func' expects no arguments and input types is an empty list.
    if min_args == 0 and max_args == 0:
      return _DefinedFunction(
          func, [], [],
          self._func_name,
          self._grad_func,
          self._python_grad_func,
          out_names=self._out_names,
          **self._extra_kwargs)

    # Input types are unknown. It's an overloaded function and hence
    # its definition needs to be deferred until it's called.
    return _OverloadedFunction(
        func,
        argnames,
        self._func_name,
        self._grad_func,
        self._python_grad_func,
        out_names=self._out_names,
        **self._extra_kwargs)

219 

220 

221class _DefinedFunctionDeleter(object): 

222 """Unregister function from eager context.""" 

223 

224 __slots__ = ["name"] 

225 

226 def __init__(self, name): 

227 self.name = name 

228 

229 def __del__(self): 

230 try: 

231 context.remove_function(self.name) 

232 except TypeError: 

233 # Suppress some exceptions, mainly for the case when we're running on 

234 # module deletion. Things that can go wrong include the context module 

235 # already being unloaded, self._handle._handle_data no longer being 

236 # valid, and so on. Printing warnings in these cases is silly 

237 # (exceptions raised from __del__ are printed as warnings to stderr). 

238 pass # 'NoneType' object is not callable when the handle has been 

239 # partially unloaded. 

240 except AttributeError: 

241 pass # 'NoneType' object has no attribute 'eager_mode' when context has 

242 # been unloaded. Will catch other module unloads as well. 

243 

244 

class _DefinedFunction(object):
  """_DefinedFunction encapsulates a function definition and its properties.

  The definition itself is built lazily: most public accessors first call
  `_create_definition_if_needed()`, which traces the python callable once.

  Attributes:
    name: The function name.
    definition: The definition of this function. A FunctionDef proto.
    cached_definition: Same as definition. Needed to match AtomicFunction API.
    grad_func_name: If not None, the name of this function's gradient function.
    python_grad_func: A python callable implementing the gradient of
      the function python-side.
  """

  def __init__(self,
               func,
               argnames,
               input_types,
               func_name=None,
               grad_func=None,
               python_grad_func=None,
               out_names=None,
               shape_func=None,
               capture_by_value=False,
               allowlisted_stateful_ops=None,
               capture_resource_var_by_value=True,
               **kwargs):
    """Creates _DefinedFunction.

    Args:
      func:  A python callable which constructs a tf function body.
      argnames: A list of strings for function argument names.
      input_types: The function's argument types. Can be a tuple, list of
        tf data types.
      func_name: The function name. Defaults to None, in which derives from
        'func'.
      grad_func: This function's gradient function, if not None. Defaults
        to None.
      python_grad_func: A python callable implementing the gradient of
        the function python-side.
      out_names: An optional list of strings for the function return value
        names.
      shape_func: An optional function mapping an op to a list of static
        output shapes.
      capture_by_value: Boolean (defaults to False). If True, captured values
        will be copied into the function body.
      allowlisted_stateful_ops: A set of ops that if stateful we ignore and
        copy into the function body, when `capture_by_value` is True.
      capture_resource_var_by_value: Boolean (defaults to True). If False,
        captured resource variable returns the handle instead of value.
      **kwargs: Extra keyword arguments. They are stored and, when the
        definition is built, parsed into attrs set on the function definition.

    Raises:
      ValueError: The function definition is invalid.

    """
    self._func = func
    self._input_types = input_types
    self._func_name = func_name
    self._grad_func = grad_func
    self._python_grad_func = python_grad_func
    self._out_names = out_names
    self._shape_func = shape_func
    self._capture_by_value = capture_by_value
    self._allowlisted_stateful_ops = allowlisted_stateful_ops
    if self._allowlisted_stateful_ops is None:
      self._allowlisted_stateful_ops = set()
    self._capture_resource_var_by_value = capture_resource_var_by_value
    self._extra_kwargs = kwargs
    # Constructed only when C API is disabled, lazily
    self._definition = None
    # Constructed only when C API is enabled, lazily
    self._c_func = None
    self._function_deleter = None
    self._sub_functions = {}  # Constructed with _definition or _c_func
    # pylint: disable=protected-access
    device_funcs = ops.get_default_graph()._device_functions_outer_to_inner
    # pylint: enable=protected-access

    # Get the innermost device if possible.
    self._caller_device = device_funcs[-1] if device_funcs else None

    # Cached OpDef for this function. When C API is enabled, this is
    # the only part of FunctionDef that we cache in Python. When C API
    # is disabled the whole _definition is available and this is simply
    # another reference to _definition.signature
    self._op_def = None

    assert isinstance(input_types, (list, tuple))
    self._arg_types = input_types
    # One name per declared input type; positions beyond the end of `argnames`
    # get a synthetic "arg<i>" name.
    self._arg_names = [argnames[i] if i < len(argnames) else ("arg%d" % i)
                       for i in range(len(input_types))]

  @property
  def name(self):
    """Function name."""
    # Building the definition may assign the name when none was supplied.
    self._create_definition_if_needed()
    return self._func_name

  @property
  def cached_definition(self):
    """Alias of `definition`."""
    return self.definition

  @property
  def definition(self):
    """Function definition proto."""
    self._create_definition_if_needed()
    if self._c_func:
      # NOTE(review): this branch runs on every access; it re-serializes the
      # C function into a fresh FunctionDef and, when executing eagerly under
      # init_scope, re-adds it to the context and replaces the deleter.
      with c_api_util.tf_buffer() as buf:
        with self._c_func.get() as func:
          c_api.TF_FunctionToFunctionDef(func, buf)
          fdef = function_pb2.FunctionDef()
          proto_data = c_api.TF_GetBuffer(buf)
          fdef.ParseFromString(compat.as_bytes(proto_data))
          with ops.init_scope():
            if context.executing_eagerly():
              context.add_c_function(func)
              self._function_deleter = _DefinedFunctionDeleter(
                  fdef.signature.name)
      return fdef
    return self._definition

  @property
  def _signature(self):
    """The cached OpDef (signature) of this function."""
    self._create_definition_if_needed()
    return self._op_def

  def set_grad_func(self, grad_func):
    """Specifies the gradient function of this function."""
    # A gradient function may only be set once, and must itself be defined.
    assert not self._grad_func
    assert isinstance(grad_func, _DefinedFunction)
    self._grad_func = grad_func

  @property
  def grad_func_name(self):
    """Returns the name of the gradient function."""
    return self._grad_func.name if self._grad_func else None

  @property
  def python_grad_func(self):
    """Python gradient function callable."""
    return self._python_grad_func

  @property
  def declared_input_types(self):
    """Returns the list of data types of explicit declared inputs."""
    return self._input_types

  @property
  def captured_inputs(self):
    """Returns the list of implicitly captured inputs."""
    # _extra_inputs is only assigned while the definition is being built.
    self._create_definition_if_needed()
    return self._extra_inputs

  @property
  def stateful_ops(self):
    """Returns the list of stateful ops in function definition.

    Returns:
      A list of (op.name, op.type) pairs.
    """
    # _stateful_ops is only assigned while the definition is being built.
    self._create_definition_if_needed()
    return self._stateful_ops

  def _create_definition_if_needed(self):
    """Creates the function definition if it's not created yet."""
    # The definition is always built in graph mode.
    with context.graph_mode():
      self._create_definition_if_needed_impl()

  def _create_definition_if_needed_impl(self):
    """This is not what you want, see _create_definition_if_needed."""
    # Already built (through either code path): nothing to do.
    if self._definition is not None or self._c_func is not None:
      return

    # Copy variable collections (by reference) from the parent graph such that
    # name based variable sharing (e.g. via tf.make_template) works between the
    # func graph and parent graph.
    variable_keys = []
    variable_keys.extend(ops.GraphKeys._VARIABLE_COLLECTIONS)  # pylint: disable=protected-access
    variable_keys.append(vs._VARSTORE_KEY)  # pylint: disable=protected-access

    parent_graph = ops.get_default_graph()
    collections_ref = {
        key: parent_graph.get_collection_ref(key) for key in variable_keys}

    # Trace the python callable into a temporary graph.
    temp_graph = func_graph_from_py_func(
        self._func,
        self._arg_names,
        self._arg_types,
        self._func_name,
        self._capture_by_value,
        self._caller_device,
        collections_ref=collections_ref,
        allowlisted_stateful_ops=self._allowlisted_stateful_ops,
        capture_resource_var_by_value=self._capture_resource_var_by_value)

    self._extra_inputs = temp_graph.extra_inputs
    # pylint: disable=protected-access
    self._sub_functions = temp_graph._functions
    # pylint: enable=protected-access

    # Extra kwargs are treated as attrs on the function def.
    if self._func_name:
      base_func_name = self._func_name
    else:
      base_func_name = function_utils.get_func_name(self._func)
      if self._grad_func:
        base_func_name += ("_%s" % self._grad_func.name)
    kwargs_attr = _parse_kwargs_as_attrs(base_func_name, **self._extra_kwargs)

    # FIXME(feyu): C API is always enabled now. The if-true branch never runs.
    if not temp_graph._c_graph:  # pylint: disable=protected-access
      # Build the FunctionDef
      self._definition = graph_to_function_def.graph_to_function_def(
          temp_graph,
          temp_graph.get_operations(),
          temp_graph.inputs,
          temp_graph.outputs,
          out_names=self._out_names)

      for k in kwargs_attr:
        self._definition.attr[k].CopyFrom(kwargs_attr[k])

      # Hash the definition and its dependencies.
      self._hash_str = self._create_hash_str(
          self._definition.signature.input_arg,
          self._definition.signature.output_arg, self._definition.node_def)

      # Finally, we decide the function name to use.  If not specified,
      # make up something which is almost certainly unique (but deterministic).
      if not self._func_name:
        self._func_name = "_".join([base_func_name, self._hash_str])
      self._definition.signature.name = self._func_name
      if self._func.__doc__:
        self._definition.signature.description = self._func.__doc__

      self._op_def = self._definition.signature
    else:  # C API is enabled
      output_names = ([compat.as_bytes(x) for x in self._out_names]
                      if self._out_names else [])
      description = self._func.__doc__ or None
      # pylint: disable=protected-access
      with temp_graph._c_graph.get() as c_graph:
        c_func = c_api.TF_GraphToFunction_wrapper(
            c_graph,
            base_func_name,
            self._func_name is None,  # append_hash_to_fn_name
            None,  # opers
            [t._as_tf_output() for t in temp_graph.inputs],
            [t._as_tf_output() for t in temp_graph.outputs],
            output_names,
            [],  # control_outputs
            [],  # control_output_names
            None,  # opts
            description)
      self._c_func = c_api_util.ScopedTFFunction(c_func, base_func_name)
      # pylint: enable=protected-access
      self._set_c_attrs(kwargs_attr)

      # Set cached fields: _op_def and _func_name (if not already set)
      self._op_def = self.definition.signature
      if self._func_name:
        assert self._func_name == self._op_def.name
      else:
        self._func_name = compat.as_str(self._op_def.name)

    # Record (name, type) for every stateful op traced into the body.
    self._stateful_ops = [(op.name, op.type)
                          for op in temp_graph.get_operations()
                          if op._is_stateful]  # pylint: disable=protected-access

  def _set_c_attrs(self, attrs):
    """Sets `attrs` as attributes of self._c_func.

    Requires that self._c_func is not None.

    Args:
      attrs: a dictionary from attribute name to attribute proto value
    """
    for name, attr_value in attrs.items():
      serialized = attr_value.SerializeToString()
      # TODO(skyewm): this creates and deletes a new TF_Status for every attr.
      # It might be worth creating a convenient way to re-use the same status.
      with self._c_func.get() as func:
        c_api.TF_FunctionSetAttrValueProto(func, compat.as_str(name),
                                           serialized)

  def _create_hash_str(self, input_arg, output_arg, node_def):
    """Creates an 8-character string unique to this input.

    The string is the first 8 hex digits of a SHA-1 digest fed with the
    serialized input/output args and, for each node in name order, its name,
    op, inputs and attrs (attr keys sorted).

    Args:
      input_arg: the input_arg field of an OpDef
                 (e.g. self._definition.signature.input_arg)
      output_arg: the output_arg field of an OpDef
                 (e.g. self._definition.signature.output_arg)
      node_def: the node_def field of a FunctionDef
                (e.g. self._definition.node_def)

    Returns:
      The unique string for this input
    """
    hasher = hashlib.sha1()

    def update_num(n):
      # Numbers are hashed as their lowercase hex text.
      hasher.update(compat.as_bytes("%x" % n))

    def update_str(s):
      # Length-prefix every string so adjacent fields cannot run together.
      update_num(len(s))
      hasher.update(compat.as_bytes(s))

    def update_strs(slist):
      update_num(len(slist))
      for s in slist:
        update_str(s)

    for adef in input_arg:
      update_str(adef.SerializeToString())

    for adef in output_arg:
      update_str(adef.SerializeToString())

    # Nodes are visited in name order so the digest is deterministic.
    for n in sorted(node_def, key=lambda n: n.name):
      update_str(n.name)
      update_str(n.op)
      update_strs(n.input)
      update_num(len(n.attr))
      # NOTE: protobuf map serialization does not guarantee ordering.
      for k in sorted(n.attr):
        update_str(k)
        update_str(n.attr[k].SerializeToString())

    return hasher.hexdigest()[:8]

  def add_to_graph(self, g):
    """Adds this function into the graph g."""
    self._create_definition_if_needed()

    # Adds this function into 'g'.
    # pylint: disable=protected-access
    if context.executing_eagerly():
      context.context().add_function_def(self.definition)
    else:
      g._add_function(self)
    # pylint: enable=protected-access

    # Ensures related sub-routines are defined in 'g', too.
    for f in self._sub_functions.values():
      g._add_function_recursive(f)  # pylint: disable=protected-access

    # Adds its gradient function, too.
    if self._grad_func:
      self._grad_func.add_to_graph(g)

  def __call__(self, *args, **kwargs):
    """Adds a call to this function to the default graph.

    The positional args are converted to tensors and followed by the captured
    inputs; `**kwargs` is forwarded to `_call`.

    Returns:
      The first element of the pair returned by `_call`.

    Raises:
      ValueError: If `shape_func` yields a different number of shapes than the
        call op has outputs.
    """
    self.add_to_graph(ops.get_default_graph())
    args = [ops.convert_to_tensor(_) for _ in args] + self._extra_inputs
    ret, op = _call(self._signature, *args, **kwargs)

    # Set a hidden attr in 'op' so that gradients_impl can refer back
    # to this _DefinedFunction instance to access python_grad_func.
    assert isinstance(op, ops.Operation)
    setattr(op, "__defun", self)

    if self._shape_func is not None:
      shapes = self._shape_func(op)
      if len(shapes) != len(op.outputs):
        raise ValueError(f"shape_func {self._shape_func} produced "
                         f"{len(shapes):d} shapes, which does not match "
                         f"{len(op.outputs)} outputs.")
      for (t, shape) in zip(op.outputs, shapes):
        t.set_shape(shape)
    return ret

615 

616 

617class _OverloadedFunction(object): 

618 """_OverloadedFunction encapsulates an overloaded function. 

619 

620 _OverloadedFunction maintains a mapping from input types to 

621 instantiated _DefinedFunction in self._overload. 

622 

623 """ 

624 

625 def __init__(self, 

626 func, 

627 argnames, 

628 func_name=None, 

629 grad_func=None, 

630 python_grad_func=None, 

631 out_names=None, 

632 **kwargs): 

633 """Creates _DefinedFunction. 

634 

635 Args: 

636 func: A python callable which constructs a tf function body. 

637 argnames: A list of strings for function argument names. 

638 func_name: The function name. Defaults to None, in which derives from 

639 'func'. 

640 grad_func: This function's gradient function, if not None. Defaults 

641 to None. 

642 python_grad_func: A python callable implementing the gradient of 

643 the function python-side. 

644 out_names: A list of strings for the function return value names. 

645 **kwargs: The keyword arguments. **kwargs is passed to every call 

646 site of this function. 

647 

648 Raises: 

649 ValueError: The function definition is invalid. 

650 

651 """ 

652 self._func = func 

653 self._argnames = argnames 

654 self._func_name = func_name 

655 assert grad_func is None or isinstance(grad_func, _OverloadedFunction) 

656 self._grad_func = grad_func 

657 self._python_grad_func = python_grad_func 

658 self._out_names = out_names 

659 self._extra_kwargs = kwargs 

660 self._overload = {} 

661 

662 def instantiate(self, input_types): 

663 """Instantiate this function given input argument types. 

664 

665 Args: 

666 input_types: A list of data types for the inputs. 

667 

668 Returns: 

669 _DefinedFunction for the given input types. 

670 

671 """ 

672 # Stringify the type list. 

673 key = _type_list_to_str(input_types) 

674 defined = self._overload.get(key) 

675 if not defined: 

676 # If not defined yet, define the function given the input types. 

677 name = self._func_name 

678 if name is not None: 

679 name = "_".join([name, key]) 

680 defined = _DefinedFunction( 

681 self._func, 

682 self._argnames, 

683 input_types, 

684 name, 

685 None, 

686 self._python_grad_func, 

687 out_names=self._out_names, 

688 **self._extra_kwargs) 

689 _ = defined.name # Fully instantiate the function definition. 

690 if self._grad_func: 

691 # If _grad_func is given, it is another 

692 # _OverloadedFunction. We need to instantiate it with the 

693 # right input types. 

694 output_types = [ 

695 dtypes.DType(_.type) for _ in defined._signature.output_arg # pylint: disable=protected-access 

696 ] 

697 # pylint: disable=protected-access 

698 defined._grad_func = self._grad_func.instantiate(input_types + 

699 output_types) 

700 # pylint: enable=protected-access 

701 self._overload[key] = defined 

702 return defined 

703 

704 def __call__(self, *args, **kwargs): 

705 input_types = [] 

706 args = list(args) 

707 for (i, x) in enumerate(args): 

708 x = ops.convert_to_tensor(x) 

709 if not isinstance(x, ops.Tensor): 

710 raise ValueError(f"Expected a Tensor but got {x} with type {type(x)}.") 

711 input_types.append(x.dtype) 

712 args[i] = x 

713 return self.instantiate(input_types)(*args, **kwargs) 

714 

715 

716class _FuncGraph(ops.Graph): 

717 """A helper for constructing a function. 

718 

719 _FuncGraph overrides ops.Graph's create_op() so that we can keep 

720 track of all inputs into every op created inside the function. If 

721 any input is from other graphs, we keep track of it in self.capture 

722 and substitute the input with a place holder. 

723 

724 Each captured input's corresponding place holder is converted into a 

725 function argument and the caller passes in the captured tensor. 

726 """ 

727 

def __init__(self, name, capture_by_value, allowlisted_stateful_ops,
             capture_resource_var_by_value, *args, **kwargs):
  """Creates a function-building graph.

  Args:
    name: The name of the function.
    capture_by_value: Stored as-is; consulted by `capture` to decide whether
      an external tensor is copied in or turned into an extra input.
    allowlisted_stateful_ops: Stored as-is on the graph.
    capture_resource_var_by_value: Stored as-is; consulted by `getvar`.
    *args: Forwarded to the base graph constructor.
    **kwargs: Forwarded to the base graph constructor.
  """
  super(_FuncGraph, self).__init__(*args, **kwargs)
  self._capture_by_value = capture_by_value
  self._allowlisted_stateful_ops = allowlisted_stateful_ops
  self._capture_resource_var_by_value = capture_resource_var_by_value
  self._building_function = True
  # Remember the graph and variable scope that were active at construction.
  self._outer_graph = ops.get_default_graph()
  self._vscope = vs.get_variable_scope()
  self._old_custom_getter = self._vscope.custom_getter

  # The name of the function.
  self.name = name
  # Placeholder tensors representing the inputs to this function. The tensors
  # are in this _FuncGraph.
  self.inputs = []
  # Tensors that will be returned by this function. The tensors are in this
  # _FuncGraph.
  self.outputs = []
  # Maps external tensor -> internal tensor (e.g. input placeholder).
  self._captured = {}
  # The external tensors that have been captured as inputs and must be passed
  # to this function (empty if capturing by value, otherwise these are the
  # keys of _captured).
  self.extra_inputs = []
  # Input placeholders that have been added for captured values (empty if
  # capturing by value).
  self.extra_args = []
  # Captured variables.
  # TODO(skyewm): is this needed?
  self.extra_vars = []

759 

760 # pylint: disable=g-doc-return-or-yield 

761 

@property
def outer_graph(self):
  """The graph active when this _FuncGraph was created.

  This is the default graph that was recorded by the constructor.
  """
  return self._outer_graph

766 

@tf_contextlib.contextmanager
def container(self, container_name):
  """Returns a context manager that specifies the resource container to use.

  Overridden from `tf.Graph` to update both the init_scope container
  and the present inner container. This is necessary to make sure setting
  containers applies correctly both to created variables and to stateful
  ops.

  Both previous container values are restored on exit, including when the
  managed block raises.

  Args:
    container_name: container name string.

  Returns:
    A context manager for defining resource containers for stateful ops,
    yields the container name.
  """
  # Save both current values so the `finally` clause can restore them.
  original_container = self._container
  # pylint: disable=protected-access
  with ops.init_scope():
    original_init_container = ops.get_default_graph()._container
  try:
    self._container = container_name
    with ops.init_scope():
      ops.get_default_graph()._container = container_name
    yield self._container
  finally:
    self._container = original_container
    with ops.init_scope():
      ops.get_default_graph()._container = original_init_container
  # pylint: enable=protected-access

797 

798 # pylint: enable=g-doc-return-or-yield 

799 

def getvar(
    self,
    getter,
    name,
    shape=None,
    dtype=None,
    initializer=None,
    reuse=None,
    trainable=True,
    collections=None,  # pylint: disable=redefined-outer-name
    use_resource=None,
    **kwargs):
  """Custom variable getter that creates the variable in the outer graph.

  The default graph is switched to the outer graph and the variable scope
  recorded at construction is asked for the variable, which hoists the
  variable definition out of the function body. Every variable obtained is
  recorded in `extra_vars`. `getter` and `**kwargs` are accepted but not used.

  Returns:
    The variable itself, or its value when it is a resource variable and
    resource variables are captured by value.
  """
  with self._outer_graph.as_default():
    # pylint: disable=protected-access
    store = vs._get_default_variable_store()
    # pylint: enable=protected-access
    hoisted = self._vscope.get_variable(
        store,
        name,
        shape=shape,
        dtype=dtype,
        initializer=initializer,
        reuse=reuse,
        trainable=trainable,
        collections=collections,
        use_resource=use_resource)
    self.extra_vars.append(hoisted)

    is_resource_var = isinstance(
        hoisted, resource_variable_ops.BaseResourceVariable)
    if is_resource_var and self._capture_resource_var_by_value:
      # For resource-based variables read the variable outside the function
      # and pass in the value. This ensures that the function is pure and
      # differentiable. TODO(apassos) this may have performance problems if
      # the function will only do embedding lookups on the variable.
      return hoisted.value()
    return hoisted

841 

def _create_op_internal(
    self,
    op_type,
    inputs,
    dtypes=None,  # pylint: disable=redefined-outer-name
    input_types=None,
    name=None,
    attrs=None,
    op_def=None,
    compute_device=True):
  """Creates an op in this graph, capturing inputs that live elsewhere.

  Any entry of `inputs` that is an eager tensor, or that belongs to a graph
  other than this one, is replaced in place (inside the `inputs` list) by the
  result of `self.capture(...)`. The op is then created by the parent class
  with all remaining arguments forwarded unchanged.
  """
  for position, candidate in enumerate(inputs):
    is_foreign = (
        isinstance(candidate, ops.EagerTensor) or candidate.graph is not self)
    if is_foreign:
      inputs[position] = self.capture(candidate)

  parent = super(_FuncGraph, self)
  return parent._create_op_internal(
      op_type,
      inputs,
      dtypes=dtypes,
      input_types=input_types,
      name=name,
      attrs=attrs,
      op_def=op_def,
      compute_device=compute_device)

864 

def capture(self, tensor, name=None):
  """Brings `tensor` into this graph and returns its in-graph counterpart.

  A tensor that was captured before is served from `self._captured`.
  Otherwise it is either copied by value together with its producing ops
  (when `self._capture_by_value` is set) or exposed as an extra placeholder
  input, optionally named `name`.
  """
  key = tensor.ref()
  if key in self._captured:
    # Captured already.
    return self._captured[key]
  if self._capture_by_value:
    return self._add_tensor_and_parents(tensor)
  return self._capture_tensor_as_extra_input(tensor, name)

874 

@property
def captures(self):
  """List of `(original tensor, captured tensor)` pairs for this graph."""
  return [
      (tensor_ref.deref(), captured_tensor)
      for tensor_ref, captured_tensor in self._captured.items()
  ]

879 

def _capture_tensor_as_extra_input(self, tensor, name=None):
  """Captures `tensor` by routing it through a new placeholder input.

  The external tensor is appended to `self.extra_inputs`; a placeholder with
  the same dtype and shape stands in for it inside this graph and is recorded
  in `self.inputs`, `self._captured` and `self.extra_args`. Handle shape/type
  data, when the tensor has any, is copied onto the placeholder.

  Returns:
    The placeholder, wrapped in a `guarantee_const` op when `tensor` is
    guaranteed to be a constant.
  """
  self.extra_inputs.append(tensor)

  # Create the substitute placeholder outside of any control flow context
  # we're currently in.
  with ops.control_dependencies(None):
    placeholder = array_ops.placeholder(
        tensor.dtype, shape=tensor.get_shape(), name=name)

  # pylint: disable=protected-access
  if isinstance(tensor, ops.EagerTensor):
    handle_data = tensor._handle_data
    if handle_data:
      handle_data = handle_data.SerializeToString()
  else:
    with tensor.graph._c_graph.get() as source_c_graph:
      handle_data = c_api.GetHandleShapeAndType(source_c_graph,
                                                tensor._as_tf_output())

  if handle_data:
    with placeholder.graph._c_graph.get() as target_c_graph:
      c_api.SetHandleShapeAndType(target_c_graph,
                                  placeholder._as_tf_output(),
                                  compat.as_bytes(handle_data))
  # pylint: enable=protected-access

  self.inputs.append(placeholder)
  self._captured[tensor.ref()] = placeholder
  self.extra_args.append(placeholder)

  if not _is_guaranteed_const(tensor):
    return placeholder
  with ops.control_dependencies(None):
    return array_ops.guarantee_const(placeholder)

911 

912 def _add_tensor_and_parents(self, tensor): 

913 op = self._add_op_and_parents(tensor.op) 

914 return op.outputs[tensor.value_index] 

915 

def _add_op_and_parents(self, op):
  """Recreates `op`, and recursively all ops feeding it, in this graph.

  Each output of the original op is registered in `self._captured` so that
  it maps to the matching output of the copy.

  Returns:
    The newly created copy of `op`.

  Raises:
    ValueError: if `op` is stateful and not in
      `self._allowlisted_stateful_ops`, or if it is a placeholder.
  """
  # pylint: disable=protected-access
  op_def = graph_to_function_def._get_op_def(op)
  if op._is_stateful and op not in self._allowlisted_stateful_ops:
    raise ValueError(f"Cannot capture a stateful node (name:{op.name}, "
                     f"type:{op.type}) by value.")
  # pylint: enable=protected-access
  if op.type in ("Placeholder", "PlaceholderV2"):
    raise ValueError(f"Cannot capture a placeholder (name:{op.name}, "
                     f"type:{op.type}) by value.")

  # Copy the producers of every input first so the copy can consume them.
  copied_inputs = []
  for source_input in op.inputs:
    copied_inputs.append(self._add_tensor_and_parents(source_input))

  output_dtypes = [output.dtype for output in op.outputs]
  copied_op = self._create_op_internal(
      op.type,
      copied_inputs,
      output_dtypes,
      name=op.name,
      attrs=op.node_def.attr,
      op_def=op_def)

  for source_output, copied_output in zip(op.outputs, copied_op.outputs):
    self._captured[source_output.ref()] = copied_output

  return copied_op

940 

941 

def func_graph_from_py_func(func,
                            arg_names,
                            arg_types,
                            name=None,
                            capture_by_value=False,
                            device=None,
                            colocation_stack=None,
                            container=None,
                            collections_ref=None,
                            arg_shapes=None,
                            allowlisted_stateful_ops=None,
                            capture_resource_var_by_value=True):
  """Builds a _FuncGraph by tracing the Python callable `func`.

  Args:
    func: A Python callable which constructs a TF function body. The arguments
      must correspond to `arg_types`. Returns a value or list/tuple of values.
      No returned value can be None.
    arg_names: A sequence of strings for the function argument names.
    arg_types: A sequence of the function's argument types.
    name: The function name. If None, the name is derived from `func`.
    capture_by_value: boolean. If True, captured values will be copied into the
      function body.
    device: device name or function.
    colocation_stack: A colocation stack (list) the _FuncGraph should use.
    container: A container name the _FuncGraph should start with.
    collections_ref: A reference to a collections dict the _FuncGraph should
      use internally.
    arg_shapes: A sequence of the function's argument shapes. Defaults to an
      unknown shape for every argument.
    allowlisted_stateful_ops: A set of ops that if stateful we ignore and
      re-create.
    capture_resource_var_by_value: Boolean (defaults to True). If False,
      captured resource variable returns the handle instead of value.

  Returns:
    A _FuncGraph.

  Raises:
    ValueError: if `func` returns a list/tuple that contains None.
  """
  name = name or function_utils.get_func_name(func)
  func_graph = _FuncGraph(name, capture_by_value, allowlisted_stateful_ops,
                          capture_resource_var_by_value)

  with func_graph.as_default(), ops.device(device):
    # pylint: disable=protected-access
    if collections_ref is not None:
      func_graph._collections = collections_ref
    if container is not None:
      func_graph._container = container
    if colocation_stack is not None:
      func_graph._colocation_stack = colocation_stack
    # pylint: enable=protected-access

    if arg_shapes is None:
      arg_shapes = [None] * len(arg_types)

    # One placeholder per declared function argument.
    for arg_name, arg_type, arg_shape in zip(arg_names, arg_types, arg_shapes):
      func_graph.inputs.append(
          array_ops.placeholder(arg_type, shape=arg_shape, name=arg_name))

    # Trace the body; variables it requests are routed through the graph's
    # custom getter.
    with vs.variable_scope("", custom_getter=func_graph.getvar):
      outputs = func(*func_graph.inputs)

    # There is no way of distinguishing between a function not returning
    # anything and a function returning None in Python.
    # We need to allow the former and ideally want to forbid the latter as
    # it is most likely user error.
    # TODO(iga): Consider adding a @NoOutput decorator on top of @Defun to
    # allow users to explicitly mark the function as not returning anything.
    # For now, we allow a single None return and interpret it as a function
    # with no output.
    if outputs is None:
      outputs = []
    elif not isinstance(outputs, (list, tuple)):
      # A single (non-None) return value becomes a 1-tuple.
      outputs = (outputs,)
    elif any(item is None for item in outputs):
      raise ValueError(f"Function {name} can not return None.")

    # Make every output a Tensor, then make sure it lives in the function
    # graph (capturing it if it was produced elsewhere).
    as_tensors = [ops.convert_to_tensor(item) for item in outputs]
    outputs = [
        item if item.graph is func_graph else func_graph.capture(item)
        for item in as_tensors
    ]
  func_graph.outputs = outputs
  return func_graph

1030 

1031 

def _is_guaranteed_const(tensor):
  """Determines whether `tensor` is guaranteed to be a constant.

  A tensor is guaranteed to be a constant if either it was produced by
  a `GuaranteeConst` op or if its producing op has at least one input and
  all of those inputs are themselves guaranteed to be constants.

  Args:
    tensor: The tensor for which to determine const-ness.

  Returns:
    True if `tensor` is guaranteed to be a constant, False otherwise. Eager
    tensors always yield False.
  """
  if isinstance(tensor, ops.EagerTensor):
    return False

  const_ops = set()
  seen_ops = set()
  # Iterative depth-first walk over the producing ops. Each entry is
  # `(op, leaving)`: `leaving` is False the first time an op is reached and
  # True when it is revisited after its inputs have been examined.
  pending = [(tensor.op, False)]
  while pending:
    current_op, leaving = pending.pop()
    if leaving:
      # If all inputs of an op are guaranteed constants, then we can infer
      # that the op produces a constant as well.
      if current_op.inputs and all(
          inp.op in const_ops for inp in current_op.inputs):
        const_ops.add(current_op)
      continue

    seen_ops.add(current_op)
    if current_op.node_def.op == "GuaranteeConst":
      const_ops.add(current_op)
      continue

    # Revisit this op once all its inputs were checked for const-ness.
    pending.append((current_op, True))
    pending.extend(
        (inp.op, False) for inp in current_op.inputs
        if inp.op not in seen_ops)
  return tensor.op in const_ops

1081 

1082 

1083def _call(sig, *inputs, **kwargs): 

1084 """Adds a node calling a function. 

1085 

1086 This adds a `call` op to the default graph that calls the function 

1087 of signature `sig`, passing the tensors in `inputs` as arguments. 

1088 It returns the outputs of the call, which are one or more tensors. 

1089 

1090 `sig` is OpDefArg.a `_DefinedFunction` object. 

1091 

1092 You can pass an optional keyword parameter `name=string` to name the 

1093 added operation. 

1094 

1095 You can pass an optional keyword parameter `noinline=True|False` to 

1096 instruct the runtime not to inline the function body into the call 

1097 site. 

1098 

1099 Args: 

1100 sig: OpDefArg. The signature of the function. 

1101 *inputs: arguments to the function. 

1102 **kwargs: Optional keyword arguments. Can only contain 'name' or 

1103 'noinline'. 

1104 

1105 Returns: 

1106 A 2-element tuple. First element: a Tensor if the function returns a single 

1107 value; a list of Tensors if the function returns multiple value; the 

1108 Operation if the function returns no values. Second element: the Operation. 

1109 

1110 Raises: 

1111 ValueError: if the arguments are invalid. 

1112 """ 

1113 if len(inputs) != len(sig.input_arg): 

1114 raise ValueError(f"Expected {len(sig.input_arg):d} arguments, got " 

1115 f"{len(inputs):d}.") 

1116 name = kwargs.pop("name", None) 

1117 g = ops.get_default_graph() 

1118 func_name = sig.name 

1119 if name is None: 

1120 name = func_name 

1121 attrs = _parse_kwargs_as_attrs(func_name, **kwargs) 

1122 output_types = [dtypes.DType(x.type) for x in sig.output_arg] 

1123 op = g._create_op_internal( # pylint: disable=protected-access 

1124 func_name, list(inputs), output_types, name=name, attrs=attrs, op_def=sig) 

1125 if op.outputs: 

1126 if len(op.outputs) == 1: 

1127 ret = op.outputs[0] 

1128 else: 

1129 ret = tuple(op.outputs) 

1130 else: 

1131 ret = op 

1132 return ret, op 

1133 

1134 

def _from_definition(fdef, grad_func=None):
  """Builds a _DefinedFunction from an existing FunctionDef proto.

  Args:
    fdef: a FunctionDef
    grad_func: a _DefinedFunction or None

  Returns:
    A _DefinedFunction representing fdef
  """
  # TODO(iga): This method does major surgery on _DefinedFunction.
  # Make it a named constructor using @classmethod of _DefinedFunction.
  signature = fdef.signature
  func_name = signature.name
  arg_names = [arg.name for arg in signature.input_arg]
  arg_dtypes = tuple(dtypes.as_dtype(arg.type) for arg in signature.input_arg)
  out_names = [arg.name for arg in signature.output_arg]

  # The Python callable is only needed to create a FunctionDef. Since the
  # FunctionDef already exists, no callable is passed (nor is one available
  # here).
  #
  # Note: FunctionDefs do not include python gradient functions, so if the
  # original _DefinedFunction included one it will not be reflected here.
  result = _DefinedFunction(
      None,  # func
      arg_names,
      arg_dtypes,
      func_name,
      grad_func,
      None,  # python_grad_func
      out_names)

  # Populate the fields that would otherwise be derived from the callable.
  # pylint: disable=protected-access
  imported_c_func = c_api.TF_FunctionImportFunctionDef(
      fdef.SerializeToString())
  result._c_func = c_api_util.ScopedTFFunction(imported_c_func, func_name)
  result._extra_inputs = []
  result._op_def = fdef.signature
  # pylint: enable=protected-access

  return result

1171 

1172 

def from_library(lib):
  """Creates _DefinedFunctions initialized from a FunctionDefLibrary proto.

  This method handles assigning the correct gradient functions to each
  function: a function is only instantiated after its gradient function (if
  any) has been instantiated.

  Args:
    lib: a FunctionDefLibrary

  Returns:
    A list of _DefinedFunctions (empty if `lib` has neither functions nor
    gradients).

  Raises:
    ValueError: `lib` is invalid: a gradient entry references a function
      without a FunctionDef, or every function has a gradient function (so
      the gradient relation must be cyclic).
  """
  if not lib.function and not lib.gradient:
    return []

  # function name -> FunctionDef proto
  funcs = {fdef.signature.name: fdef for fdef in lib.function}

  # Validate that all referenced function names have function defs.
  for g in lib.gradient:
    if g.function_name not in funcs:
      raise ValueError(f"FunctionDefLibrary missing '{g.function_name}' "
                       f"FunctionDef\n{lib}")
    if g.gradient_func not in funcs:
      raise ValueError(f"FunctionDefLibrary missing '{g.gradient_func}' "
                       f"FunctionDef\n{lib}")

  # function name -> gradient function name (absent when there is none)
  func_to_grad = {}
  # gradient function name -> names of functions having that grad function
  grad_to_funcs = collections.defaultdict(list)

  for gdef in lib.gradient:
    func_to_grad[gdef.function_name] = gdef.gradient_func
    grad_to_funcs[gdef.gradient_func].append(gdef.function_name)

  # Start with functions without gradients.
  ready = [
      fdef for fdef in lib.function
      if func_to_grad.get(fdef.signature.name) is None
  ]
  if not ready:
    raise ValueError(
        f"FunctionDefLibrary contains cyclic gradient functions!\n{lib}")
  # function name -> _DefinedFunction
  initialized = {}

  # NOTE(review): functions whose gradient chain is cyclic while other
  # functions are gradient-free never become ready and are left out of the
  # result without an error.
  while ready:
    fdef = ready.pop()
    name = fdef.signature.name

    grad_name = func_to_grad.get(name)
    grad = initialized.get(grad_name)
    if grad_name:
      # A function only becomes ready after its gradient was initialized.
      assert grad
    initialized[name] = _from_definition(fdef, grad_func=grad)

    # Everything that was waiting on `name` as its gradient can go next.
    ready.extend(funcs[f] for f in grad_to_funcs.get(name, ()))

  # Return a real list, as documented, rather than a dict view.
  return list(initialized.values())

1235 

1236 

1237def _get_experimental_kwarg_as_attr(attr_name, value): 

1238 """Creates an AttrValue for a python object.""" 

1239 if isinstance(value, bool): 

1240 return attr_value_pb2.AttrValue(b=value) 

1241 elif isinstance(value, int): 

1242 return attr_value_pb2.AttrValue(i=value) 

1243 elif isinstance(value, float): 

1244 return attr_value_pb2.AttrValue(f=value) 

1245 elif isinstance(value, str): 

1246 return attr_value_pb2.AttrValue(s=compat.as_bytes(value)) 

1247 else: 

1248 raise ValueError(f"Attribute {attr_name} must be bool, int, float, or " 

1249 f"str. Got {type(value)}.") 

1250 

1251 

1252def _get_kwarg_as_str_attr(attr_name, value): 

1253 """Creates an AttrValue for a python object.""" 

1254 if isinstance(value, str): 

1255 return attr_value_pb2.AttrValue(s=compat.as_bytes(value)) 

1256 else: 

1257 raise ValueError(f"Attribute {attr_name} must be str. Got {type(value)}.") 

1258 

1259 

def _parse_kwargs_as_attrs(func_name, **kwargs):
  """Converts the keyword arguments of a function call into node attributes.

  Recognized keywords are `noinline`, `compiled`,
  `separate_compiled_gradients`, `_implements`, `_reference` and anything
  starting with `experimental_`.

  Args:
    func_name: Name of the function; used to build a default `_XlaScope`.
    **kwargs: The keyword arguments to convert.

  Returns:
    A dict mapping attribute names to AttrValue protos.

  Raises:
    ValueError: if an unrecognized keyword is passed, or a value has an
      unsupported type.
  """
  attrs = {}

  noinline = kwargs.pop("noinline", None)
  if noinline is not None:
    attrs["_noinline"] = attr_value_pb2.AttrValue(b=bool(noinline))

  # For compatibility with previous behavior, Defun does not perform shape
  # inference through its function call operations.
  attrs["_disable_call_shape_inference"] = attr_value_pb2.AttrValue(b=True)

  compiled = kwargs.pop("compiled", None)
  separate_compiled_gradients = kwargs.pop("separate_compiled_gradients", None)
  if compiled is not None:
    attrs["_XlaCompile"] = attr_value_pb2.AttrValue(b=bool(compiled))
    attrs["_XlaSeparateCompiledGradients"] = attr_value_pb2.AttrValue(
        b=bool(separate_compiled_gradients))
    # Forward _XlaScope from enclosing context (if set), otherwise create new.
    # pylint: disable=protected-access
    scope_map = ops.get_default_graph()._attr_scope_map
    # pylint: enable=protected-access
    if "_XlaScope" in scope_map:
      attrs["_XlaScope"] = scope_map["_XlaScope"]
    else:
      attrs["_XlaScope"] = attr_value_pb2.AttrValue(
          s=("function_%s" % func_name).encode())

  # Iterate over a snapshot of the keys because handled entries are removed.
  for key in list(kwargs):
    if key.startswith("experimental_"):
      attrs[key] = _get_experimental_kwarg_as_attr(key, kwargs.pop(key))
    # Support for https://github.com/tensorflow/community/pull/113/files.
    elif key in ("_implements", "_reference"):
      attrs[key] = _get_kwarg_as_str_attr(key, kwargs.pop(key))

  # Whatever is left was not recognized.
  if kwargs:
    raise ValueError(f"Unknown keyword arguments: {kwargs.keys()}.")
  return attrs

1299 

1300 

def get_extra_vars():
  """Returns the variables captured by the function being defined.

  Returns:
    The `extra_vars` list of the default graph when that graph is a
    `_FuncGraph` (i.e. a function is currently being defined); otherwise an
    empty list.
  """
  graph = ops.get_default_graph()
  return graph.extra_vars if isinstance(graph, _FuncGraph) else []

1314 

1315 

def get_extra_inputs():
  """Returns the input tensors captured by the function being defined.

  Returns:
    The `extra_inputs` list of the default graph when that graph is a
    `_FuncGraph`: the tensors accessed inside the function body but defined
    outside of it so far. Otherwise an empty list.
  """
  graph = ops.get_default_graph()
  return graph.extra_inputs if isinstance(graph, _FuncGraph) else []

1330 

1331 

def get_extra_args():
  """Returns the function arguments that stand in for the captured inputs.

  Returns:
    The `extra_args` list of the default graph when that graph is a
    `_FuncGraph`: the placeholders used inside the function body that
    correspond to the tensors returned by get_extra_inputs(). Otherwise an
    empty list.
  """
  graph = ops.get_default_graph()
  return graph.extra_args if isinstance(graph, _FuncGraph) else []

1346 

1347 

1348def _type_list_to_str(types): 

1349 if any(_ not in _DTYPE_TO_STR for _ in types): 

1350 unsupported_types = [type_ for type_ in types if type_ not in _DTYPE_TO_STR] 

1351 raise ValueError(f"Unsupported dtypes {unsupported_types} in " 

1352 "`types`. Supported dtypes are " 

1353 f"{_DTYPE_TO_STR.keys()}.") 

1354 return "".join(_DTYPE_TO_STR[_] for _ in types) 

1355 

1356 

1357# NOTE: The list needs to be extended when more data types are added. 

# Maps each supported dtype to the short code used by `_type_list_to_str`.
# Codes must be unique per dtype: `uint8` previously shared "i8" with `int8`,
# so type lists differing only in int8 vs. uint8 produced the same string.
_DTYPE_TO_STR = {
    dtypes.float16: "f16",
    dtypes.float32: "f32",
    dtypes.float64: "f64",
    dtypes.int32: "i32",
    dtypes.uint8: "u8",
    dtypes.uint16: "u16",
    dtypes.uint32: "u32",
    dtypes.uint64: "u64",
    dtypes.int16: "i16",
    dtypes.int8: "i8",
    dtypes.string: "s",
    dtypes.complex64: "c64",
    dtypes.complex128: "c128",
    dtypes.int64: "i64",
    dtypes.bool: "b",
    dtypes.qint8: "qi8",
    dtypes.quint8: "qu8",
    dtypes.qint16: "qi16",
    dtypes.quint16: "qu16",
    dtypes.qint32: "qi32",
    dtypes.bfloat16: "b16",
    dtypes.float8_e5m2: "f8e5m2",
    dtypes.float8_e4m3fn: "f8e4m3fn",
}