Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/framework/tensor_conversion_registry.py: 77%

70 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Registry for tensor conversion functions.""" 

16# pylint: disable=g-bad-name 

17import collections 

18import threading 

19 

20import numpy as np 

21 

22from tensorflow.python.framework import dtypes 

23from tensorflow.python.types import core 

24from tensorflow.python.util.tf_export import tf_export 

25 

26 

# Registered conversions, keyed by priority. Each value is the list of
# `(base_type, conversion_func)` pairs registered at that priority, in
# registration order.
_tensor_conversion_func_registry = collections.defaultdict(list)
# Memoized `get()` results, keyed by the queried type. Emptied whenever a new
# conversion function is registered.
_tensor_conversion_func_cache = {}
# Held while the registry or the cache is being modified.
_tensor_conversion_func_lock = threading.Lock()

# Instances of these types should only be converted by internally-registered
# conversion functions.
_CONSTANT_OP_CONVERTIBLES = (int, float, np.generic, np.ndarray)

39 

40 

# TODO(josh11b): Add ctx argument to conversion_func() signature.
def register_tensor_conversion_function_internal(base_type,
                                                 conversion_func,
                                                 priority=100):
  """Internal version of register_tensor_conversion_function.

  See docstring of `register_tensor_conversion_function` for details.

  Unlike the public function, this version does not reject the types listed
  in `_CONSTANT_OP_CONVERTIBLES`, so it can be used to register conversions
  for them.

  Args:
    base_type: The base type or tuple of base types for all objects that
      `conversion_func` accepts.
    conversion_func: A function that converts instances of `base_type` to
      `Tensor`.
    priority: Optional real number that indicates the priority for applying
      this conversion function. Conversion functions with smaller priority
      values run earlier than conversion functions with larger priority
      values. Defaults to 100.

  Raises:
    TypeError: If `base_type` is not a type or a tuple of types, if
      `conversion_func` is not callable, or if `priority` is not a real
      number.
  """
  import numbers  # Local import: only needed to validate `priority`.

  candidate_types = base_type if isinstance(base_type, tuple) else (base_type,)
  if not all(isinstance(t, type) for t in candidate_types):
    raise TypeError("Argument `base_type` must be a type or a tuple of types. "
                    f"Obtained: {base_type}")
  if not callable(conversion_func):
    raise TypeError("Argument `conversion_func` must be callable. Received "
                    f"{conversion_func}.")
  # `get()` sorts the registry by priority. A single key that cannot be
  # ordered against the others would make every later uncached lookup fail,
  # so reject it here, where the mistake is made.
  if not isinstance(priority, numbers.Real):
    raise TypeError("Argument `priority` must be a real number. Received "
                    f"{priority!r}.")

  with _tensor_conversion_func_lock:
    _tensor_conversion_func_registry[priority].append(
        (base_type, conversion_func))
    # Cached lookups may now be stale: the new function can apply to any
    # previously queried type.
    _tensor_conversion_func_cache.clear()

78 

79 

@tf_export("register_tensor_conversion_function")
def register_tensor_conversion_function(base_type,
                                        conversion_func,
                                        priority=100):
  """Registers a function for converting objects of `base_type` to `Tensor`.

  The conversion function must have the following signature:

  ```python
  def conversion_func(value, dtype=None, name=None, as_ref=False):
    # ...
  ```

  It must return a `Tensor` with the given `dtype` if specified. If the
  conversion function creates a new `Tensor`, it should use the given
  `name` if specified. All exceptions will be propagated to the caller.

  The conversion function may return `NotImplemented` for some
  inputs. In this case, the conversion process will continue to try
  subsequent conversion functions.

  If `as_ref` is true, the function must return a `Tensor` reference,
  such as a `Variable`.

  NOTE: The conversion functions will execute in order of priority,
  followed by order of registration. To ensure that a conversion function
  `F` runs before another conversion function `G`, ensure that `F` is
  registered with a smaller priority than `G`.

  Args:
    base_type: The base type or tuple of base types for all objects that
      `conversion_func` accepts.
    conversion_func: A function that converts instances of `base_type` to
      `Tensor`.
    priority: Optional integer that indicates the priority for applying this
      conversion function. Conversion functions with smaller priority values
      run earlier than conversion functions with larger priority values.
      Defaults to 100.

  Raises:
    TypeError: If the arguments do not have the appropriate type, or if any
      base type is a subclass of a Python numeric type, a NumPy scalar or a
      NumPy array.
  """
  candidates = base_type if isinstance(base_type, tuple) else (base_type,)

  # Every entry must be a class before `issubclass` can be applied to it.
  for candidate in candidates:
    if not isinstance(candidate, type):
      raise TypeError(
          "Argument `base_type` must be a type or a tuple of types. "
          f"Obtained: {base_type}")

  # Types handled by constant-op conversion are reserved for internal
  # registrations.
  for candidate in candidates:
    if issubclass(candidate, _CONSTANT_OP_CONVERTIBLES):
      raise TypeError(
          "Cannot register conversions for Python numeric types and "
          "NumPy scalars and arrays.")

  register_tensor_conversion_function_internal(
      base_type, conversion_func, priority)

132 

133 

def get(query):
  """Returns the conversion functions registered for the type `query`.

  Results are memoized per queried type; the memo is emptied whenever a new
  conversion function is registered.

  Args:
    query: The type to query for.

  Returns:
    A list of `(base_type, conversion_func)` pairs for which
    `issubclass(query, base_type)` holds, ordered by ascending priority value
    and, within one priority, by registration order. The list is the cached
    object itself and is shared between callers, so treat it as read-only.
  """
  matches = _tensor_conversion_func_cache.get(query)
  if matches is not None:
    return matches

  with _tensor_conversion_func_lock:
    # Has another thread populated the cache in the meantime?
    matches = _tensor_conversion_func_cache.get(query)
    if matches is None:
      matches = [
          (base_type, conversion_func)
          for _, registered in sorted(_tensor_conversion_func_registry.items())
          for base_type, conversion_func in registered
          if issubclass(query, base_type)
      ]
      _tensor_conversion_func_cache[query] = matches
  return matches

158 

159 

def _add_error_prefix(msg, *, name=None):
  """Returns `msg`, prefixed with "`name`: " when `name` is not None."""
  if name is None:
    return msg
  return f"{name}: {msg}"

162 

163 

def convert(value,
            dtype=None,
            name=None,
            as_ref=False,
            preferred_dtype=None,
            accepted_result_types=(core.Symbol,)):
  """Converts `value` to a `Tensor` using registered conversion functions.

  If `value` has a `__tf_tensor__` attribute, that attribute is called with
  `(dtype, name)` and its result is returned directly, without consulting the
  registry. Otherwise the conversion functions registered for `type(value)`
  are tried in order until one returns something other than `NotImplemented`.

  Args:
    value: An object whose type has a registered `Tensor` conversion function.
    dtype: Optional element type for the returned tensor. If missing, the type
      is inferred from the type of `value`.
    name: Optional name to use if a new `Tensor` is created.
    as_ref: Optional boolean specifying if the returned value should be a
      reference-type `Tensor` (e.g. Variable). Pass-through to the registered
      conversion function. Defaults to `False`.
    preferred_dtype: Optional element type for the returned tensor.
      Used when dtype is None. In some cases, a caller may not have a dtype
      in mind when converting to a tensor, so `preferred_dtype` can be used
      as a soft preference. If the conversion to `preferred_dtype` is not
      possible, this argument has no effect.
    accepted_result_types: Optional collection of types as an allow-list
      for the returned value. If a conversion function returns an object
      which is not an instance of some type in this collection, that value
      will not be returned.

  Returns:
    A `Tensor` converted from `value`.

  Raises:
    ValueError: Propagated from a conversion function, e.g. if `value` is a
      `Tensor` and conversion is requested to a `Tensor` with an incompatible
      `dtype`.
    TypeError: If no registered conversion function handles `value`.
    RuntimeError: If a registered conversion function returns an invalid
      value: one that is not an instance of `accepted_result_types`, or one
      whose dtype does not match the requested dtype.
  """

  if dtype is not None:
    dtype = dtypes.as_dtype(dtype)
  if preferred_dtype is not None:
    preferred_dtype = dtypes.as_dtype(preferred_dtype)

  overload = getattr(value, "__tf_tensor__", None)
  if overload is not None:
    return overload(dtype, name)  # pylint: disable=not-callable

  for base_type, conversion_func in get(type(value)):
    # If dtype is None but preferred_dtype is not None, we try to
    # cast to preferred_dtype first.
    used_preferred_dtype = False
    ret = None
    if dtype is None and preferred_dtype is not None:
      try:
        ret = conversion_func(
            value, dtype=preferred_dtype, name=name, as_ref=as_ref)
      except (TypeError, ValueError):
        # Could not coerce the conversion to use the preferred dtype.
        pass
      else:
        used_preferred_dtype = True

    if not used_preferred_dtype:
      ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)

    if ret is NotImplemented:
      continue

    # Validate the result type before touching `ret.dtype`, so that a
    # conversion function returning a non-Tensor (e.g. None) is reported as
    # the documented RuntimeError rather than an AttributeError.
    if not isinstance(ret, accepted_result_types):
      raise RuntimeError(
          _add_error_prefix(
              f"Conversion function {conversion_func!r} for type "
              f"{base_type} returned non-Tensor: {ret!r}",
              name=name))
    if (used_preferred_dtype and
        ret.dtype.base_dtype != preferred_dtype.base_dtype):
      raise RuntimeError(
          _add_error_prefix(
              f"Conversion function {conversion_func!r} for type "
              f"{base_type} returned incompatible dtype: requested = "
              f"{preferred_dtype.base_dtype.name}, "
              f"actual = {ret.dtype.base_dtype.name}",
              name=name))
    if dtype and not dtype.is_compatible_with(ret.dtype):
      raise RuntimeError(
          _add_error_prefix(
              f"Conversion function {conversion_func} for type {base_type} "
              f"returned incompatible dtype: requested = {dtype.name}, "
              f"actual = {ret.dtype.name}",
              name=name))
    return ret
  raise TypeError(
      _add_error_prefix(
          f"Cannot convert {value!r} with type {type(value)} to Tensor: "
          f"no conversion function registered.",
          name=name))