Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/framework/tensor_conversion_registry.py: 77%
70 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Registry for tensor conversion functions."""
16# pylint: disable=g-bad-name
17import collections
18import threading
20import numpy as np
22from tensorflow.python.framework import dtypes
23from tensorflow.python.types import core
24from tensorflow.python.util.tf_export import tf_export
# Registered conversion functions, keyed by integer priority. Each value is a
# list of (base_type, conversion_func) pairs in registration order.
_tensor_conversion_func_registry = collections.defaultdict(list)
# Memoized results of `get(query_type)`; cleared whenever a new conversion
# function is registered.
_tensor_conversion_func_cache = dict()
# Serializes registry updates and cache population.
_tensor_conversion_func_lock = threading.Lock()

# Python numbers and NumPy scalars/arrays may only be converted by
# internally-registered conversion functions; the public registration API
# rejects these types.
_CONSTANT_OP_CONVERTIBLES = (int, float, np.generic, np.ndarray)
# TODO(josh11b): Add ctx argument to conversion_func() signature.
def register_tensor_conversion_function_internal(base_type,
                                                 conversion_func,
                                                 priority=100):
  """Internal version of register_tensor_conversion_function.

  See docstring of `register_tensor_conversion_function` for details.

  Unlike the public version, this function does not reject the types listed
  in the `_CONSTANT_OP_CONVERTIBLES` tuple (Python `int` and `float`, NumPy
  scalars and arrays). Conversions for those types are meant to be
  registered only through this internal entry point.

  Registering a function clears the lookup cache used by `get`, so later
  lookups see the new entry.

  Args:
    base_type: The base type or tuple of base types for all objects that
      `conversion_func` accepts.
    conversion_func: A function that converts instances of `base_type` to
      `Tensor`.
    priority: Optional integer that indicates the priority for applying this
      conversion function. Conversion functions with smaller priority values run
      earlier than conversion functions with larger priority values. Defaults to
      100.

  Raises:
    TypeError: If `base_type` is not a type or a tuple of types, or if
      `conversion_func` is not callable.
  """
  base_types = base_type if isinstance(base_type, tuple) else (base_type,)
  if any(not isinstance(x, type) for x in base_types):
    raise TypeError("Argument `base_type` must be a type or a tuple of types. "
                    f"Obtained: {base_type}")
  del base_types  # Only needed for validation.
  if not callable(conversion_func):
    raise TypeError("Argument `conversion_func` must be callable. Received "
                    f"{conversion_func}.")

  with _tensor_conversion_func_lock:
    _tensor_conversion_func_registry[priority].append(
        (base_type, conversion_func))
    # Cached lookups may now be stale; `get` rebuilds them on demand.
    _tensor_conversion_func_cache.clear()
@tf_export("register_tensor_conversion_function")
def register_tensor_conversion_function(base_type,
                                        conversion_func,
                                        priority=100):
  """Registers a function for converting objects of `base_type` to `Tensor`.

  The conversion function must have the following signature:

  ```python
    def conversion_func(value, dtype=None, name=None, as_ref=False):
      # ...
  ```

  It must return a `Tensor` with the given `dtype` if specified. If the
  conversion function creates a new `Tensor`, it should use the given
  `name` if specified. All exceptions will be propagated to the caller.

  The conversion function may return `NotImplemented` for some
  inputs. In this case, the conversion process will continue to try
  subsequent conversion functions.

  If `as_ref` is true, the function must return a `Tensor` reference,
  such as a `Variable`.

  NOTE: The conversion functions will execute in order of priority,
  followed by order of registration. To ensure that a conversion function
  `F` runs before another conversion function `G`, ensure that `F` is
  registered with a smaller priority than `G`.

  Args:
    base_type: The base type or tuple of base types for all objects that
      `conversion_func` accepts.
    conversion_func: A function that converts instances of `base_type` to
      `Tensor`.
    priority: Optional integer that indicates the priority for applying this
      conversion function. Conversion functions with smaller priority values run
      earlier than conversion functions with larger priority values. Defaults to
      100.

  Raises:
    TypeError: If the arguments do not have the appropriate type, or if
      `base_type` is (a subclass of) a Python numeric type, a NumPy scalar
      type, or `np.ndarray`. Conversions for those types are reserved for
      internal registration.
  """
  candidate_types = (
      base_type if isinstance(base_type, tuple) else (base_type,))
  if not all(isinstance(t, type) for t in candidate_types):
    raise TypeError("Argument `base_type` must be a type or a tuple of types. "
                    f"Obtained: {base_type}")
  for t in candidate_types:
    if issubclass(t, _CONSTANT_OP_CONVERTIBLES):
      raise TypeError("Cannot register conversions for Python numeric types and "
                      "NumPy scalars and arrays.")
  # Remaining validation and the actual registration happen in the internal
  # entry point.
  register_tensor_conversion_function_internal(
      base_type, conversion_func, priority)
def get(query):
  """Returns the registered conversion functions applicable to `query`.

  Args:
    query: The type to query for. A registered entry matches when
      `issubclass(query, base_type)` holds for that entry's `base_type`.

  Returns:
    A list of `(base_type, conversion_func)` pairs. The list is ordered by
    increasing priority value and, within one priority, by registration
    order. The same list object is stored in the lookup cache and returned to
    every caller, so callers must not mutate it.
  """
  conversion_funcs = _tensor_conversion_func_cache.get(query)
  if conversion_funcs is None:
    with _tensor_conversion_func_lock:
      # Has another thread populated the cache in the meantime?
      conversion_funcs = _tensor_conversion_func_cache.get(query)
      if conversion_funcs is None:
        conversion_funcs = []
        # Priorities are unique dict keys, so sorting the items orders by
        # priority alone.
        for _, funcs_at_priority in sorted(
            _tensor_conversion_func_registry.items()):
          conversion_funcs.extend(
              (base_type, conversion_func)
              for base_type, conversion_func in funcs_at_priority
              if issubclass(query, base_type))
        _tensor_conversion_func_cache[query] = conversion_funcs
  return conversion_funcs
160def _add_error_prefix(msg, *, name=None):
161 return msg if name is None else f"{name}: {msg}"
def convert(value,
            dtype=None,
            name=None,
            as_ref=False,
            preferred_dtype=None,
            accepted_result_types=(core.Symbol,)):
  """Converts `value` to a `Tensor` using registered conversion functions.

  If `value` has a `__tf_tensor__` attribute, it is called as
  `value.__tf_tensor__(dtype, name)` and its result is returned directly,
  without the result checks described below. Otherwise the conversion
  functions returned by `get(type(value))` are tried in order. The first one
  that does not return `NotImplemented` produces the result.

  Args:
    value: An object whose type has a registered `Tensor` conversion function.
    dtype: Optional element type for the returned tensor. If missing, the type
      is inferred from the type of `value`.
    name: Optional name to use if a new `Tensor` is created.
    as_ref: Optional boolean specifying if the returned value should be a
      reference-type `Tensor` (e.g. Variable). Pass-through to the registered
      conversion function. Defaults to `False`.
    preferred_dtype: Optional element type for the returned tensor.
      Used when dtype is None. In some cases, a caller may not have a dtype
      in mind when converting to a tensor, so `preferred_dtype` can be used
      as a soft preference. If the conversion to `preferred_dtype` is not
      possible, this argument has no effect.
    accepted_result_types: Optional collection of types as an allow-list
      for the returned value. If a conversion function returns an object
      which is not an instance of some type in this collection, a
      `RuntimeError` is raised.

  Returns:
    A `Tensor` converted from `value`.

  Raises:
    ValueError: Propagated from a registered conversion function, e.g. when
      `value` is a `Tensor` and conversion is requested to a `Tensor` with an
      incompatible `dtype`.
    TypeError: If no registered conversion function accepts `value`.
    RuntimeError: If a registered conversion function returns a value that is
      not an instance of `accepted_result_types`, or whose dtype does not
      match the requested `dtype` or `preferred_dtype`.
  """

  if dtype is not None:
    dtype = dtypes.as_dtype(dtype)
  if preferred_dtype is not None:
    preferred_dtype = dtypes.as_dtype(preferred_dtype)

  # Objects implementing `__tf_tensor__` convert themselves; note that only
  # `dtype` and `name` are forwarded.
  overload = getattr(value, "__tf_tensor__", None)
  if overload is not None:
    return overload(dtype, name)  # pylint: disable=not-callable

  for base_type, conversion_func in get(type(value)):
    # If dtype is None but preferred_dtype is not None, we try to
    # cast to preferred_dtype first.
    ret = None
    if dtype is None and preferred_dtype is not None:
      try:
        ret = conversion_func(
            value, dtype=preferred_dtype, name=name, as_ref=as_ref)
      except (TypeError, ValueError):
        # Could not coerce the conversion to use the preferred dtype; fall
        # back to converting without a dtype below.
        pass
      else:
        # Only inspect `.dtype` on allow-listed results. Anything else is
        # reported by the non-Tensor check below, instead of failing here
        # with an AttributeError.
        if (ret is not NotImplemented and
            isinstance(ret, accepted_result_types) and
            ret.dtype.base_dtype != preferred_dtype.base_dtype):
          raise RuntimeError(
              _add_error_prefix(
                  f"Conversion function {conversion_func!r} for type "
                  f"{base_type} returned incompatible dtype: requested = "
                  f"{preferred_dtype.base_dtype.name}, "
                  f"actual = {ret.dtype.base_dtype.name}",
                  name=name))

    if ret is None:
      ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)

    if ret is NotImplemented:
      # This function declined; try the next registered one.
      continue

    if not isinstance(ret, accepted_result_types):
      raise RuntimeError(
          _add_error_prefix(
              f"Conversion function {conversion_func!r} for type "
              f"{base_type} returned non-Tensor: {ret!r}",
              name=name))
    if dtype is not None and not dtype.is_compatible_with(ret.dtype):
      raise RuntimeError(
          _add_error_prefix(
              f"Conversion function {conversion_func} for type {base_type} "
              f"returned incompatible dtype: requested = {dtype.name}, "
              f"actual = {ret.dtype.name}",
              name=name))
    return ret
  raise TypeError(
      _add_error_prefix(
          f"Cannot convert {value!r} with type {type(value)} to Tensor: "
          f"no conversion function registered.",
          name=name))