Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/eager/polymorphic_function/function_spec.py: 51%
214 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Defines an input type specification for tf.function."""
17import functools
18import inspect
19from typing import Any, Dict, Tuple
21import numpy as np
22import six
24from tensorflow.core.function import trace_type
25from tensorflow.core.function.polymorphism import function_type as function_type_lib
26from tensorflow.python.eager.polymorphic_function import composite_tensor_utils
27from tensorflow.python.framework import composite_tensor
28from tensorflow.python.framework import constant_op
29from tensorflow.python.framework import ops
30from tensorflow.python.framework import tensor_spec
31from tensorflow.python.framework import type_spec
32from tensorflow.python.ops import resource_variable_ops
33from tensorflow.python.util import nest
# Sentinel value used by ConcreteFunction's structured signature to
# indicate that a non-tensor parameter should use the value that was
# specified when the concrete function was created.  Compared by identity
# (`is BOUND_VALUE`), so it never collides with a real argument value.
BOUND_VALUE = object()
def to_fullargspec(function_type: function_type_lib.FunctionType,
                   default_values: Dict[str, Any]) -> inspect.FullArgSpec:
  """Generates backwards compatible FullArgSpec from FunctionType.

  Args:
    function_type: The FunctionType whose parameters are translated.
    default_values: Mapping from parameter name to its default value; only
      consulted for parameters whose `default` is not `Parameter.empty`.

  Returns:
    An `inspect.FullArgSpec` mirroring `function_type` (positional-only
    parameters are folded into `args`, annotations are left empty).
  """
  positional_names = []
  positional_defaults = []
  keyword_only_names = []
  keyword_only_defaults = {}
  varargs_name = None
  varkw_name = None

  for param in function_type.parameters.values():
    kind = param.kind
    has_default = param.default is not inspect.Parameter.empty
    if kind in (inspect.Parameter.POSITIONAL_ONLY,
                inspect.Parameter.POSITIONAL_OR_KEYWORD):
      positional_names.append(param.name)
      if has_default:
        positional_defaults.append(default_values[param.name])
    elif kind is inspect.Parameter.KEYWORD_ONLY:
      keyword_only_names.append(param.name)
      if has_default:
        keyword_only_defaults[param.name] = default_values[param.name]
    elif kind is inspect.Parameter.VAR_POSITIONAL:
      varargs_name = param.name
    elif kind is inspect.Parameter.VAR_KEYWORD:
      varkw_name = param.name

  # FullArgSpec represents "no defaults" as None rather than empty containers.
  return inspect.FullArgSpec(
      positional_names,
      varargs_name,
      varkw_name,
      tuple(positional_defaults) or None,
      keyword_only_names,
      keyword_only_defaults or None,
      annotations={})
def _to_default_values(fullargspec):
  """Returns default values from the function's inspected fullargspec.

  Positional defaults (which FullArgSpec aligns with the tail of `args`) and
  keyword-only defaults are merged into one dict keyed by sanitized name.
  """
  defaults = {}

  if fullargspec.defaults is not None:
    # fullargspec.defaults lines up with the last len(defaults) entries of
    # fullargspec.args.
    tail = fullargspec.args[-len(fullargspec.defaults):]
    defaults.update(zip(tail, fullargspec.defaults))

  if fullargspec.kwonlydefaults is not None:
    defaults.update(fullargspec.kwonlydefaults)

  return {
      function_type_lib.sanitize_arg_name(name): value
      for name, value in defaults.items()
  }
def to_function_type(fullargspec):
  """Generates FunctionType and default values from fullargspec.

  Returns:
    A `(FunctionType, default_values)` pair where `default_values` maps
    sanitized parameter names to their defaults.
  """
  default_values = _to_default_values(fullargspec)
  param_cls = function_type_lib.Parameter
  parameters = []

  for arg in fullargspec.args:
    sanitized = function_type_lib.sanitize_arg_name(arg)
    parameters.append(
        param_cls(sanitized, param_cls.POSITIONAL_OR_KEYWORD,
                  sanitized in default_values, None))

  if fullargspec.varargs is not None:
    parameters.append(
        param_cls(fullargspec.varargs, param_cls.VAR_POSITIONAL, False, None))

  for kwarg in fullargspec.kwonlyargs:
    # NOTE: the optional-ness check uses the unsanitized name, matching the
    # keys produced for kwonlydefaults.
    parameters.append(
        param_cls(
            function_type_lib.sanitize_arg_name(kwarg),
            param_cls.KEYWORD_ONLY, kwarg in default_values, None))

  if fullargspec.varkw is not None:
    parameters.append(
        param_cls(fullargspec.varkw, param_cls.VAR_KEYWORD, False, None))

  return function_type_lib.FunctionType(parameters), default_values
def to_input_signature(function_type):
  """Extracts an input_signature from function_type instance.

  Returns:
    A tuple of legacy constraints, an empty tuple when there is nothing to
    constrain (or all constraints were auto-generated), or None when no
    constraint was specified at all.
  """
  constrained_parameters = list(function_type.parameters.keys())

  # self does not have a constraint in input_signature
  if "self" in constrained_parameters:
    constrained_parameters.pop(0)

  # There are no parameters to constrain.
  if not constrained_parameters:
    return tuple()

  constraints = []
  is_auto_constrained = False

  for parameter_name in constrained_parameters:
    parameter = function_type.parameters[parameter_name]
    constraint = None
    if parameter.type_constraint:
      # Generate legacy constraint representation.
      constraint = parameter.type_constraint.placeholder_value(
          trace_type.InternalPlaceholderContext(unnest_only=True)
      )
      # NOTE(review): the break abandons constraint collection entirely as
      # soon as one constraint contains a non-TensorSpec leaf — such
      # constraints are assumed to come from automatic tracing, not from a
      # user-supplied input_signature.
      if any(
          not isinstance(arg, tensor_spec.TensorSpec)
          for arg in nest.flatten([constraint], expand_composites=True)):
        # input_signature only supports contiguous TensorSpec composites
        is_auto_constrained = True
        break
      else:
        constraints.append(constraint)

  # All constraints were generated by FunctionType
  if is_auto_constrained and not constraints:
    return tuple()

  # If the list is empty then there was no input_signature specified.
  return tuple(constraints) if constraints else None
173# TODO(b/214462107): Clean up and migrate to core/function when unblocked.
class FunctionSpec(object):
  """Specification of how to bind arguments to a function."""

  @classmethod
  def from_function_and_signature(cls,
                                  python_function,
                                  input_signature,
                                  is_pure=False,
                                  jit_compile=None):
    """Creates a FunctionSpec instance given a python function and signature.

    Args:
      python_function: a function to inspect
      input_signature: a signature of the function (None, if variable)
      is_pure: if True all input arguments (including variables and constants)
        will be converted to tensors and no variable changes allowed.
      jit_compile: see `tf.function`

    Returns:
      instance of FunctionSpec
    """
    _validate_signature(input_signature)

    function_type = function_type_lib.FunctionType.from_callable(
        python_function)
    default_values = function_type_lib.FunctionType.get_default_values(
        python_function)

    if input_signature is not None:
      input_signature = tuple(input_signature)
      function_type = function_type_lib.add_type_constraints(
          function_type, input_signature, default_values)

    # Get the function's name. Remove functools.partial wrappers if necessary.
    while isinstance(python_function, functools.partial):
      python_function = python_function.func
    name = getattr(python_function, "__name__", "f")

    return FunctionSpec(
        function_type,
        default_values,
        is_pure=is_pure,
        jit_compile=jit_compile,
        name=name)

  @classmethod
  def from_fullargspec_and_signature(cls,
                                     fullargspec,
                                     input_signature,
                                     is_pure=False,
                                     name=None,
                                     jit_compile=None):
    """Construct FunctionSpec from legacy FullArgSpec format."""
    function_type, default_values = to_function_type(fullargspec)
    if input_signature:
      input_signature = tuple(input_signature)
      _validate_signature(input_signature)
      function_type = function_type_lib.add_type_constraints(
          function_type, input_signature, default_values)

    return FunctionSpec(function_type, default_values, is_pure,
                        name, jit_compile)

  def __init__(self,
               function_type,
               default_values,
               is_pure=False,
               name=None,
               jit_compile=None):
    """Constructs a FunctionSpec describing a python function.

    Args:
      function_type: A FunctionType describing the python function signature.
      default_values: Dictionary mapping parameter names to default values.
      is_pure: if True all input arguments (including variables and constants)
        will be converted to tensors and no variable changes allowed.
      name: Name of the function
      jit_compile: see `tf.function`.
    """
    self._function_type = function_type
    self._default_values = default_values
    # Legacy FullArgSpec view, derived from the FunctionType.
    self._fullargspec = to_fullargspec(function_type, default_values)
    self._is_pure = is_pure
    self._jit_compile = jit_compile

    # TODO(edloper): Include name when serializing for SavedModel?
    self._name = name or "f"
    self._input_signature = to_input_signature(function_type)

  @property
  def default_values(self):
    """Returns dict mapping parameter names to default values."""
    return self._default_values

  @property
  def function_type(self):
    """Returns a FunctionType representing the Python function signature."""
    return self._function_type

  @property
  def fullargspec(self):
    """Returns the legacy FullArgSpec representation of the signature."""
    return self._fullargspec

  # TODO(fmuham): Replace usages with FunctionType and remove.
  @property
  def input_signature(self):
    return self._input_signature

  # TODO(fmuham): Replace usages with FunctionType and remove.
  @property
  def flat_input_signature(self):
    return tuple(nest.flatten(self.input_signature, expand_composites=True))

  @property
  def is_pure(self):
    return self._is_pure

  @property
  def jit_compile(self):
    return self._jit_compile

  # TODO(fmuham): Replace usages and remove.
  @property
  def arg_names(self):
    """Names of the positional (including positional-or-keyword) parameters."""
    return list(
        p.name
        for p in self.function_type.parameters.values()
        if (
            p.kind is function_type_lib.Parameter.POSITIONAL_ONLY
            or p.kind is function_type_lib.Parameter.POSITIONAL_OR_KEYWORD
        )
    )

  def make_canonicalized_monomorphic_type(
      self,
      args: Any,
      kwargs: Any,
      captures: Any = None,
  ) -> Tuple[function_type_lib.FunctionType,
             trace_type.InternalTracingContext]:
    """Generates function type given the function arguments."""
    if captures is None:
      captures = dict()

    # Keyword names must be sanitized the same way parameter names were.
    kwargs = {
        function_type_lib.sanitize_arg_name(name): value
        for name, value in kwargs.items()
    }

    _, function_type, type_context = (
        function_type_lib.canonicalize_to_monomorphic(
            args, kwargs, self.default_values, captures, self.function_type
        )
    )

    return function_type, type_context

  def signature_summary(self, default_values=False):
    """Returns a string summarizing this function's signature.

    Args:
      default_values: If true, then include default values in the signature.

    Returns:
      A `string`.
    """
    # BUG FIX: this method previously read `self._arg_names` and
    # `self._arg_indices_to_default_values`, neither of which is ever set on
    # this class, so any call raised AttributeError. Use the `arg_names`
    # property and the `_default_values` mapping instead.
    args = list(self.arg_names)
    if default_values:
      for i, arg_name in enumerate(args):
        if arg_name in self._default_values:
          args[i] += "={}".format(self._default_values[arg_name])
    if self._fullargspec.kwonlyargs:
      args.append("*")
      for arg_name in self._fullargspec.kwonlyargs:
        args.append(arg_name)
        # kwonlydefaults is None when no keyword-only default exists.
        if (default_values and self._fullargspec.kwonlydefaults and
            arg_name in self._fullargspec.kwonlydefaults):
          args[-1] += "={}".format(self._fullargspec.kwonlydefaults[arg_name])
    return f"{self._name}({', '.join(args)})"

  def canonicalize_function_inputs(self, args, kwargs):
    """Canonicalizes `args` and `kwargs`.

    Canonicalize the inputs to the Python function using a `FunctionSpec`
    instance. In particular, we parse the varargs and kwargs that the
    original function was called with into a tuple corresponding to the
    Python function's positional (named) arguments and a dictionary
    corresponding to its kwargs. Missing default arguments are added.

    If this `FunctionSpec` is pure, all inputs are first converted to
    tensors; additionally, any inputs containing numpy arrays are converted
    to Tensors.

    Args:
      args: The varargs this object was called with.
      kwargs: The keyword args this function was called with.

    Returns:
      A canonicalized ordering of the inputs, as well as the filtered
      (Tensors and Variables only) flattened representation of their values,
      as a tuple in the form (args, kwargs, filtered_flat_args). Here:
      `args` is a full list of bound arguments, and `kwargs` contains only
      true keyword arguments, as opposed to named arguments called in a
      keyword-like fashion.

    Raises:
      ValueError: If a keyword in `kwargs` cannot be matched with a positional
        argument when an input signature is specified, or when the inputs
        do not conform to the input signature.
    """
    if self.is_pure:
      args, kwargs = _convert_variables_to_tensors(args, kwargs)
    args, kwargs = self.bind_function_inputs(args, kwargs)
    filtered_flat_args = filter_function_inputs(args, kwargs)

    return args, kwargs, filtered_flat_args

  def bind_function_inputs(self, args, kwargs):
    """Bind `args` and `kwargs` into a canonicalized signature args, kwargs."""
    sanitized_kwargs = {
        function_type_lib.sanitize_arg_name(k): v for k, v in kwargs.items()
    }
    # Two distinct caller keywords must not collapse to one sanitized name.
    if len(kwargs) != len(sanitized_kwargs):
      raise ValueError(f"Name collision after sanitization. Please rename "
                       f"tf.function input parameters. Original: "
                       f"{sorted(kwargs.keys())}, Sanitized: "
                       f"{sorted(sanitized_kwargs.keys())}")

    try:
      bound_arguments = self.function_type.bind_with_defaults(
          args, sanitized_kwargs, self.default_values)
    except Exception as e:
      raise TypeError(
          f"Binding inputs to tf.function `{self._name}` failed due to `{e}`. "
          f"Received args: {args} and kwargs: {sanitized_kwargs} for signature:"
          f" {self.function_type}."
      ) from e
    return bound_arguments.args, bound_arguments.kwargs
414def _validate_signature(signature):
415 """Checks the input_signature to be valid."""
416 if signature is None:
417 return
419 if not isinstance(signature, (tuple, list)):
420 raise TypeError("input_signature must be either a tuple or a list, got "
421 f"{type(signature)}.")
423 # TODO(xjun): Allow VariableSpec once we figure out API for de-aliasing.
424 variable_specs = _get_variable_specs(signature)
425 if variable_specs:
426 raise TypeError(
427 f"input_signature doesn't support VariableSpec, got {variable_specs}")
429 if any(not isinstance(arg, tensor_spec.TensorSpec)
430 for arg in nest.flatten(signature, expand_composites=True)):
431 bad_args = [
432 arg for arg in nest.flatten(signature, expand_composites=True)
433 if not isinstance(arg, tensor_spec.TensorSpec)
434 ]
435 raise TypeError("input_signature must be a possibly nested sequence of "
436 f"TensorSpec objects, got invalid args {bad_args} with "
437 f"types {list(six.moves.map(type, bad_args))}.")
def _to_tensor_or_tensor_spec(x):
  """Returns `x` unchanged if it is a Tensor/TensorSpec, else converts it."""
  if isinstance(x, (ops.Tensor, tensor_spec.TensorSpec)):
    return x
  return ops.convert_to_tensor(x)
def _convert_variables_to_tensors(args, kwargs):
  """Converts every positional arg and kwarg value to a Tensor/TensorSpec."""
  converted_args = tuple(_to_tensor_or_tensor_spec(arg) for arg in args)
  converted_kwargs = {
      name: _to_tensor_or_tensor_spec(value)
      for name, value in kwargs.items()
  }
  return converted_args, converted_kwargs
451# TODO(fmuham): Migrate to use TraceType/FunctionType _to_tensors.
def filter_function_inputs(args, kwargs):
  """Filters and flattens args and kwargs.

  numpy-array-like leaves are converted to constant Tensors; the result keeps
  only Tensors and resource variables.
  """
  flat_inputs = (
      composite_tensor_utils.flatten_with_variables(args)
      + composite_tensor_utils.flatten_with_variables(kwargs))

  # Types that must NOT be converted even though they may expose __array__.
  non_convertible_types = (
      ops.Tensor,
      resource_variable_ops.BaseResourceVariable,
      np.str_,
      type,
      composite_tensor.CompositeTensor,
  )

  for index, flat_input in enumerate(flat_inputs):
    convertible = (
        hasattr(flat_input, "__array__")
        and not hasattr(flat_input, "_should_act_as_resource_variable")
        and not isinstance(flat_input, non_convertible_types))
    if convertible:
      ndarray = flat_input.__array__()
      if not isinstance(ndarray, np.ndarray):
        raise TypeError(f"The output of __array__ must be an np.ndarray, "
                        f"got {type(ndarray)} from {flat_input}.")
      flat_inputs[index] = constant_op.constant(ndarray)

  return [
      t for t in flat_inputs
      if isinstance(t,
                    (ops.Tensor, resource_variable_ops.BaseResourceVariable))
  ]
def _get_variable_specs(args):
  """Returns `VariableSpecs` from `args` (recursing into composite specs)."""
  variable_specs = []
  for arg in nest.flatten(args):
    if isinstance(arg, type_spec.TypeSpec):
      if isinstance(arg, resource_variable_ops.VariableSpec):
        variable_specs.append(arg)
      elif not isinstance(arg, tensor_spec.TensorSpec):
        # arg is a CompositeTensor spec; look inside its components.
        variable_specs.extend(_get_variable_specs(arg._component_specs))  # pylint: disable=protected-access
  return variable_specs
499# TODO(fmuham): Replace usages with TraceType and remove.
def is_same_structure(structure1, structure2, check_values=False):
  """Check two structures for equality, optionally of types and of values."""
  try:
    nest.assert_same_structure(structure1, structure2, expand_composites=True)
  except (ValueError, TypeError):
    return False

  if not check_values:
    return True

  flattened1 = nest.flatten(structure1, expand_composites=True)
  flattened2 = nest.flatten(structure2, expand_composites=True)
  # First check the types to avoid AttributeErrors.
  for leaf1, leaf2 in zip(flattened1, flattened2):
    if type(leaf1) is not type(leaf2):
      return False
  return flattened1 == flattened2
513 return True